diff --git a/src/bin/bridgedb.rs b/src/bin/bridgedb.rs index 4a9409a..c2799e5 100644 --- a/src/bin/bridgedb.rs +++ b/src/bin/bridgedb.rs @@ -52,6 +52,8 @@ async fn main() { listen(addr, to_uppercase).await; } -fn to_uppercase(str: String) -> String { - str.to_uppercase() +// This function assumes the byte vector is a valid string. +fn to_uppercase(str_vec: Vec) -> Vec { + let str = std::str::from_utf8(&str_vec).unwrap(); + str.to_uppercase().into() } diff --git a/src/bin/lox_auth.rs b/src/bin/lox_auth.rs index 943cd22..431d8e1 100644 --- a/src/bin/lox_auth.rs +++ b/src/bin/lox_auth.rs @@ -68,6 +68,8 @@ async fn main() { listen(addr, reverse_string).await; } -fn reverse_string(str: String) -> String { - str.trim().chars().rev().collect::() + "\n" +// This function assumes the byte vector is a valid string. +fn reverse_string(str_vec: Vec) -> Vec { + let str = std::str::from_utf8(&str_vec).unwrap(); + str.trim().chars().rev().collect::().into() } diff --git a/src/bin/lox_client.rs b/src/bin/lox_client.rs index e097afb..75d07d6 100644 --- a/src/bin/lox_client.rs +++ b/src/bin/lox_client.rs @@ -35,7 +35,7 @@ async fn main() { let reachability_pub = &lox_auth_pubkeys[3]; let invitation_pub = &lox_auth_pubkeys[4]; - let s = send(addr, msg).await; + let s = send(addr, msg.into()).await; - println!("{}", s); + println!("{}", std::str::from_utf8(&s).unwrap()); } diff --git a/src/client_net.rs b/src/client_net.rs index 0afe21c..04d5724 100644 --- a/src/client_net.rs +++ b/src/client_net.rs @@ -5,32 +5,43 @@ to a server process and sending it data. */ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -use std::str; - // may need to change strings to byte vectors in the future -pub async fn send(addr: String, str: String) -> String { +pub async fn send(addr: String, payload: Vec) -> Vec { let mut stream = TcpStream::connect(&addr) .await .expect("Failed to create TcpStream"); + // send number of bytes in payload + let payload_size = usize::to_be_bytes(payload.len()); + stream + .write_all(&payload_size) + .await + .expect("Failed to write number of bytes to listen for"); + // send data stream - .write_all(str.as_bytes()) + .write_all(&payload) .await .expect("Failed to write data to stream"); - // read response - let mut buf = vec![0; 1024]; - let n = stream + // get number of bytes in response + let mut nbuf: [u8; 8] = [0; 8]; + stream + .read(&mut nbuf) + .await + .expect("Failed to get number of bytes to read"); + let n = usize::from_be_bytes(nbuf); + + if n == 0 { + return vec![0; 0]; + } + + // receive response + let mut buf = vec![0; n]; + stream .read(&mut buf) .await .expect("Failed to read data from socket"); - if n == 0 { - return "".to_string(); - } - - str::from_utf8(&buf[0..n]) - .expect("Invalid UTF-8 sequence") - .to_string() + buf } diff --git a/src/server_net.rs b/src/server_net.rs index 1edb4ac..d3bb604 100644 --- a/src/server_net.rs +++ b/src/server_net.rs @@ -9,9 +9,7 @@ these work. */ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; -use std::str; - -pub async fn listen(addr: String, fun: fn(String) -> String) { +pub async fn listen(addr: String, fun: fn(Vec) -> Vec) { let listener = TcpListener::bind(&addr) .await .expect("Failed to create TcpListener"); @@ -22,11 +20,23 @@ pub async fn listen(addr: String, fun: fn(String) -> String) { let (mut socket, _) = listener.accept().await.expect("Failed to create socket"); tokio::spawn(async move { - let mut buf = vec![0; 1024]; - - // read data, perform function on it, write result loop { - let n = socket + // get number of bytes to receive + let mut nbuf: [u8; 8] = [0; 8]; + socket + .read(&mut nbuf) + .await + .expect("Failed to get number of bytes to read"); + let n = usize::from_be_bytes(nbuf); + + if n == 0 { + return; + } + + let mut buf = vec![0; n]; + + // receive data + socket .read(&mut buf) .await .expect("Failed to read data from socket"); @@ -35,15 +45,18 @@ pub async fn listen(addr: String, fun: fn(String) -> String) { return; } - // I think this is a problem if there's more data than fits in the buffer... - // But that's a problem for future me. - let s = str::from_utf8(&buf[0..n]) - .expect("Invalid UTF-8 sequence") - .to_string(); - let response = fun(s); + let response = fun(buf); + // send number of bytes in response + let response_size = usize::to_be_bytes(response.len()); socket - .write_all(response.as_bytes()) + .write_all(&response_size) + .await + .expect("Failed to write number of bytes to listen for"); + + // send response + socket + .write_all(&response) .await .expect("Failed to write data to socket"); }