From 552db21ce17f2f78618fbef01ecab0c5bb8a608e Mon Sep 17 00:00:00 2001 From: Vecna Date: Mon, 29 Apr 2024 11:51:54 -0400 Subject: [PATCH] Add mock server to serve extra-infos for tests and simulation --- Cargo.toml | 3 + src/lib.rs | 57 ++++++---- src/simulation/extra_infos_server.rs | 163 +++++++++++++++++++++++++++ src/tests.rs | 44 +++++++- 4 files changed, 243 insertions(+), 24 deletions(-) create mode 100644 src/simulation/extra_infos_server.rs diff --git a/Cargo.toml b/Cargo.toml index e03ae01..98edbf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,3 +40,6 @@ x25519-dalek = { version = "2", features = ["serde", "static_secrets"] } [dev-dependencies] base64 = "0.21.7" + +[features] +simulation = [] diff --git a/src/lib.rs b/src/lib.rs index 4bc14f5..c8eab58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,11 @@ pub mod negative_report; pub mod positive_report; pub mod request_handler; +#[cfg(any(test, feature = "simulation"))] +pub mod simulation { + pub mod extra_infos_server; +} + use analysis::Analyzer; use extra_info::*; use negative_report::*; @@ -205,27 +210,41 @@ pub fn add_bridge_to_db(db: &Db, fingerprint: [u8; 20]) { // Download a webpage and return it as a string pub async fn download(url: &str) -> Result> { - let https = hyper_rustls::HttpsConnectorBuilder::new() - .with_native_roots() - .expect("no native root CA certificates found") - .https_only() - .enable_http1() - .build(); - - let client: hyper_util::client::legacy::Client<_, Empty> = - hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(https); - - println!("Downloading {}", url); - let mut res = client.get(url.parse()?).await?; - assert_eq!(res.status(), StatusCode::OK); - let mut body_str = String::default(); - while let Some(next) = res.frame().await { - let frame = next?; - if let Some(chunk) = frame.data_ref() { - body_str.push_str(&String::from_utf8(chunk.to_vec())?); + if url.starts_with("https://") { + let https = hyper_rustls::HttpsConnectorBuilder::new() + .with_native_roots() + .expect("no native root CA certificates found") + .https_only() + .enable_http1() + .build(); + let client: hyper_util::client::legacy::Client<_, Empty> = + hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(https); + println!("Downloading {}", url); + let mut res = client.get(url.parse()?).await?; + assert_eq!(res.status(), StatusCode::OK); + let mut body_str = String::default(); + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + body_str.push_str(&String::from_utf8(chunk.to_vec())?); + } } + Ok(body_str) + } else { + let client: hyper_util::client::legacy::Client<_, Empty> = + hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build_http(); + println!("Downloading {}", url); + let mut res = client.get(url.parse()?).await?; + assert_eq!(res.status(), StatusCode::OK); + let mut body_str = String::default(); + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + body_str.push_str(&String::from_utf8(chunk.to_vec())?); + } + } + Ok(body_str) } - Ok(body_str) } // Process extra-infos diff --git a/src/simulation/extra_infos_server.rs b/src/simulation/extra_infos_server.rs new file mode 100644 index 0000000..a585f35 --- /dev/null +++ b/src/simulation/extra_infos_server.rs @@ -0,0 +1,163 @@ +use crate::extra_info::ExtraInfo; + +use hyper::{ + body::{self, Bytes}, + header::HeaderValue, + server::conn::AddrStream, + service::{make_service_fn, service_fn}, + Body, Method, Request, Response, Server, +}; +use serde_json::json; +use std::{collections::HashSet, convert::Infallible, net::SocketAddr, time::Duration}; +use tokio::{ + spawn, + sync::{broadcast, mpsc, oneshot}, + time::sleep, +}; + +async fn serve_extra_infos( + extra_infos_pages: &mut Vec, + req: Request, +) -> Result, Infallible> { + match req.method() { + &Method::OPTIONS => Ok(Response::builder() + .header("Access-Control-Allow-Origin", HeaderValue::from_static("*")) + .header("Access-Control-Allow-Headers", "accept, content-type") + .header("Access-Control-Allow-Methods", "POST") + .status(200) + .body(Body::from("Allow POST")) + .unwrap()), + _ => match req.uri().path() { + "/" => Ok::<_, Infallible>(serve_index(&extra_infos_pages)), + "/add" => Ok::<_, Infallible>({ + let bytes = body::to_bytes(req.into_body()).await.unwrap(); + add_extra_infos(extra_infos_pages, bytes) + }), + path => Ok::<_, Infallible>({ + // Serve the requested file + serve_extra_infos_file(&extra_infos_pages, path) + }), + }, + } +} + +pub async fn server() { + let (context_tx, context_rx) = mpsc::channel(32); + let request_tx = context_tx.clone(); + let shutdown_cmd_tx = context_tx.clone(); + let (shutdown_tx, mut shutdown_rx) = broadcast::channel(16); + let kill_context = shutdown_tx.subscribe(); + + let context_manager = + spawn(async move { create_context_manager(context_rx, kill_context).await }); + + let addr = SocketAddr::from(([127, 0, 0, 1], 8004)); + let make_svc = make_service_fn(move |_conn: &AddrStream| { + let request_tx = request_tx.clone(); + let service = service_fn(move |req| { + let request_tx = request_tx.clone(); + let (response_tx, response_rx) = oneshot::channel(); + let cmd = Command::Request { + req, + sender: response_tx, + }; + async move { + request_tx.send(cmd).await.unwrap(); + response_rx.await.unwrap() + } + }); + async move { Ok::<_, Infallible>(service) } + }); + let server = Server::bind(&addr).serve(make_svc); + println!("Listening on localhost:8004"); + if let Err(e) = server.await { + eprintln!("server error: {}", e); + } +} + +async fn create_context_manager( + context_rx: mpsc::Receiver, + mut kill: broadcast::Receiver<()>, +) { + tokio::select! { + create_context = context_manager(context_rx) => create_context, + _ = kill.recv() => {println!("Shut down context_manager");}, + } +} + +async fn context_manager(mut context_rx: mpsc::Receiver) { + let mut extra_infos_pages = Vec::::new(); + + while let Some(cmd) = context_rx.recv().await { + use Command::*; + match cmd { + Request { req, sender } => { + let response = serve_extra_infos(&mut extra_infos_pages, req).await; + if let Err(e) = sender.send(response) { + eprintln!("Server Response Error: {:?}", e); + } + sleep(Duration::from_millis(1)).await; + } + Shutdown { shutdown_sig } => { + drop(shutdown_sig); + } + } + } +} + +#[derive(Debug)] +enum Command { + Request { + req: Request, + sender: oneshot::Sender, Infallible>>, + }, + Shutdown { + shutdown_sig: broadcast::Sender<()>, + }, +} + +fn add_extra_infos(extra_infos_pages: &mut Vec, request: Bytes) -> Response { + let extra_infos: HashSet = match serde_json::from_slice(&request) { + Ok(req) => req, + Err(e) => { + let response = json!({"error": e.to_string()}); + let val = serde_json::to_string(&response).unwrap(); + return prepare_header(val); + } + }; + + let mut extra_infos_file = String::new(); + for extra_info in extra_infos { + extra_infos_file.push_str(extra_info.to_string().as_str()); + } + extra_infos_pages.push(extra_infos_file); + prepare_header("OK".to_string()) +} + +fn serve_index(extra_infos_pages: &Vec) -> Response { + let mut body_str = String::new(); + for i in 0..extra_infos_pages.len() { + body_str + .push_str(format!("{}-extra-infos\n", i, i).as_str()); + } + prepare_header(body_str) +} + +fn serve_extra_infos_file(extra_infos_pages: &Vec, path: &str) -> Response { + if path.ends_with("-extra-infos") { + if let Ok(index) = &path[1..(path.len() - "-extra-infos".len())].parse::() { + if extra_infos_pages.len() > *index { + return prepare_header(extra_infos_pages[*index].clone()); + } + } + } + prepare_header("Not a valid file".to_string()) +} + +// Prepare HTTP Response for successful Server Request +fn prepare_header(response: String) -> Response { + let mut resp = Response::new(Body::from(response)); + resp.headers_mut() + .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); + resp +} diff --git a/src/tests.rs b/src/tests.rs index 97c3248..eacf575 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -3,6 +3,7 @@ use crate::{ analysis::{blocked_in, Analyzer}, bridge_verification_info::BridgeVerificationInfo, + simulation::extra_infos_server, *, }; use lox_library::{ @@ -19,7 +20,9 @@ use sha1::{Digest, Sha1}; use std::{ collections::{BTreeMap, HashMap, HashSet}, sync::{Arc, Mutex}, + time::Duration, }; +use tokio::{spawn, time::sleep}; use x25519_dalek::{PublicKey, StaticSecret}; struct TestHarness { @@ -180,8 +183,8 @@ async fn test_download_extra_infos() { bincode::deserialize(&db.get(bridge_to_test).unwrap().unwrap()).unwrap(); } -#[test] -fn test_simulate_extra_infos() { +#[tokio::test] +async fn test_simulate_extra_infos() { let extra_info_str = r#"@type bridge-extra-info 1.3 extra-info ElephantBridgeDE2 72E12B89136B45BBC81D1EF0AC7DDDBB91B148DB master-key-ed25519 eWxjRwAWW7n8BGG9fNa6rApmBFbe3f0xcD7dqwOICW8 @@ -240,10 +243,41 @@ router-digest F30B38390C375E1EE74BFED844177804442569E0"#; assert!(!db.contains_key("bridges").unwrap()); assert!(!db.contains_key(bridge_to_test).unwrap()); - // TODO: Run local web server and change this to update_extra_infos - add_extra_info_to_db(&db, extra_info_2); + // Start web server + spawn(async move { + extra_infos_server::server().await; + }); - // Check that DB contains information on a bridge with high uptime + // Give server time to start + sleep(Duration::new(1, 0)).await; + + // Update extra-infos (no new data to add) + update_extra_infos(&db, "http://localhost:8004/") + .await + .unwrap(); + + // Check that database is still empty + assert!(!db.contains_key("bridges").unwrap()); + assert!(!db.contains_key(bridge_to_test).unwrap()); + + // Add our extra-info to the server's records + { + use hyper::{Body, Client, Method, Request}; + let client = Client::new(); + let req = Request::builder() + .method(Method::POST) + .uri("http://localhost:8004/add".parse::().unwrap()) + .body(Body::from(serde_json::to_string(&extra_info_set).unwrap())) + .unwrap(); + client.request(req).await.unwrap(); + } + + // Update extra-infos (add new record) + update_extra_infos(&db, "http://localhost:8004/") + .await + .unwrap(); + + // Check that DB now contains information on this bridge assert!(db.contains_key("bridges").unwrap()); let bridges: HashSet<[u8; 20]> = bincode::deserialize(&db.get("bridges").unwrap().unwrap()).unwrap();