From 0b0324c4874c882e748d5f0d96ef450dc91d247f Mon Sep 17 00:00:00 2001 From: Cecylia Bocovich Date: Tue, 21 Mar 2023 17:37:15 -0400 Subject: [PATCH] Change ResourceStream from an Iterator into a Stream --- crates/rdsys-backend/Cargo.toml | 5 +- crates/rdsys-backend/src/lib.rs | 182 ++++++++++++++++++-------------- 2 files changed, 106 insertions(+), 81 deletions(-) diff --git a/crates/rdsys-backend/Cargo.toml b/crates/rdsys-backend/Cargo.toml index df6fbc6..cdaa5a0 100644 --- a/crates/rdsys-backend/Cargo.toml +++ b/crates/rdsys-backend/Cargo.toml @@ -10,6 +10,7 @@ serde_json = "1" futures-util = { version = "0.3"} serde = { version = "1", features = ["derive"]} bytes = "1" -tokio = "1" - +tokio = { version = "1", features = ["macros"]} reqwest = { version = "0.11", features = ["stream"]} +tokio-stream = "0.1.12" +futures = "0.3.27" diff --git a/crates/rdsys-backend/src/lib.rs b/crates/rdsys-backend/src/lib.rs index b54a67c..7ac1c52 100644 --- a/crates/rdsys-backend/src/lib.rs +++ b/crates/rdsys-backend/src/lib.rs @@ -4,10 +4,13 @@ //! https://gitlab.torproject.org/tpo/anti-censorship/rdsys/-/blob/main/doc/backend-api.md use bytes::{self, Buf, Bytes}; -use futures_util::StreamExt; +use core::pin::Pin; +use core::task::Poll; +use futures_util::{Stream, StreamExt}; use reqwest::Client; use std::io::{self, BufRead}; -use std::sync::mpsc; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; pub mod proto; @@ -39,27 +42,37 @@ impl From for Error { /// An iterable wrapper of ResourceDiff items for the streamed chunks of Bytes /// received from the connection to the rdsys backend pub struct ResourceStream { - rx: mpsc::Receiver, + rx: ReceiverStream, buf: Vec, partial: Option>, } -impl Iterator for ResourceStream { +impl ResourceStream { + pub fn new(rx: mpsc::Receiver) -> ResourceStream { + ResourceStream { + rx: ReceiverStream::new(rx), + buf: vec![], + partial: None, + } + } +} + +impl Stream for ResourceStream { type Item = proto::ResourceDiff; - fn next(&mut self) -> Option { - let mut parse = - |buffer: &mut bytes::buf::Reader| -> Result, Error> { - match buffer.read_until(b'\r', &mut self.buf) { - Ok(_) => match self.buf.pop() { - Some(b'\r') => match serde_json::from_slice(&self.buf) { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut core::task::Context) -> Poll> { + let parse = + |buffer: &mut bytes::buf::Reader, buf: &mut Vec| -> Result, Error> { + match buffer.read_until(b'\r', buf) { + Ok(_) => match buf.pop() { + Some(b'\r') => match serde_json::from_slice(buf) { Ok(diff) => { - self.buf.clear(); + buf.clear(); Ok(Some(diff)) } Err(e) => Err(Error::JSON(e)), }, Some(n) => { - self.buf.push(n); + buf.push(n); Ok(None) } None => Ok(None), @@ -67,26 +80,33 @@ impl Iterator for ResourceStream { Err(e) => Err(Error::Io(e)), } }; - + // This clone is here to avoid having multiple mutable references to self + // it's not optimal performance-wise but given that these resource streams aren't large + // this feels like an acceptable trade-off to the complexity of interior mutability + let mut buf = self.buf.clone(); if let Some(p) = &mut self.partial { - match parse(p) { - Ok(Some(diff)) => return Some(diff), + match parse(p, &mut buf) { + Ok(Some(diff)) => return Poll::Ready(Some(diff)), Ok(None) => self.partial = None, - Err(_) => return None, + Err(_) => return Poll::Ready(None), } } - for chunk in &self.rx { - let mut buffer = chunk.reader(); - match parse(&mut buffer) { - Ok(Some(diff)) => { - self.partial = Some(buffer); - return Some(diff); + self.buf = buf; + match Pin::new(&mut self.rx).poll_next(cx) { + Poll::Ready(Some(chunk)) => { + let mut buffer = chunk.reader(); + match parse(&mut buffer, &mut self.buf) { + Ok(Some(diff)) => { + self.partial = Some(buffer); + return Poll::Ready(Some(diff)); + } + Ok(None) => Poll::Pending, //maybe loop here? + Err(_) => return Poll::Ready(None), } - Ok(None) => continue, - Err(_) => return None, - }; + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } - None } } @@ -94,72 +114,76 @@ impl Iterator for ResourceStream { mod tests { use super::*; - #[test] - fn parse_resource() { + #[tokio::test] + async fn parse_resource() { + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); let chunk = Bytes::from_static( b"{\"new\": null,\"changed\": null,\"gone\": null,\"full_update\": true}\r", ); - let (tx, rx) = mpsc::channel(); - tx.send(chunk).unwrap(); - let mut diffs = ResourceStream { - rx: rx, - partial: None, - buf: vec![], - }; - let res = diffs.next(); - assert_ne!(res, None); - if let Some(diff) = res { + let (tx, rx) = mpsc::channel(100); + tx.send(chunk).await.unwrap(); + let mut diffs = ResourceStream::new(rx); + let res = Pin::new(&mut diffs).poll_next(&mut cx); + assert_ne!(res, Poll::Ready(None)); + assert_ne!(res, Poll::Pending); + if let Poll::Ready(Some(diff)) = res { assert_eq!(diff.new, None); assert_eq!(diff.full_update, true); } } - #[test] - fn parse_across_chunks() { + #[tokio::test] + async fn parse_across_chunks() { + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); let chunk1 = Bytes::from_static(b"{\"new\": null,\"changed\": null,"); let chunk2 = Bytes::from_static(b"\"gone\": null,\"full_update\": true}\r"); - let (tx, rx) = mpsc::channel(); - tx.send(chunk1).unwrap(); - tx.send(chunk2).unwrap(); - let mut diffs = ResourceStream { - rx: rx, - partial: None, - buf: vec![], - }; - let res = diffs.next(); - assert_ne!(res, None); - if let Some(diff) = res { + let (tx, rx) = mpsc::channel(100); + tx.send(chunk1).await.unwrap(); + tx.send(chunk2).await.unwrap(); + let mut diffs = ResourceStream::new(rx); + let mut res = Pin::new(&mut diffs).poll_next(&mut cx); + while let Poll::Pending = res { + res = Pin::new(&mut diffs).poll_next(&mut cx); + } + assert_ne!(res, Poll::Ready(None)); + assert_ne!(res, Poll::Pending); + if let Poll::Ready(Some(diff)) = res { assert_eq!(diff.new, None); assert_eq!(diff.full_update, true); } } - #[test] - fn parse_multi_diff_partial_chunks() { + #[tokio::test] + async fn parse_multi_diff_partial_chunks() { + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); let chunk1 = Bytes::from_static(b"{\"new\": null,\"changed\": null,"); let chunk2 = Bytes::from_static(b"\"gone\": null,\"full_update\": true}\r{\"new\": null,\"changed"); let chunk3 = Bytes::from_static(b"\": null,\"gone\": null,\"full_update\": true}"); let chunk4 = Bytes::from_static(b"\r"); - let (tx, rx) = mpsc::channel(); - tx.send(chunk1).unwrap(); - tx.send(chunk2).unwrap(); - tx.send(chunk3).unwrap(); - tx.send(chunk4).unwrap(); - let mut diffs = ResourceStream { - rx: rx, - partial: None, - buf: vec![], - }; - let mut res = diffs.next(); - assert_ne!(res, None); - if let Some(diff) = res { + let (tx, rx) = mpsc::channel(100); + tx.send(chunk1).await.unwrap(); + tx.send(chunk2).await.unwrap(); + tx.send(chunk3).await.unwrap(); + tx.send(chunk4).await.unwrap(); + let mut diffs = ResourceStream::new(rx); + let mut res = Pin::new(&mut diffs).poll_next(&mut cx); + while let Poll::Pending = res { + res = Pin::new(&mut diffs).poll_next(&mut cx); + } + assert_ne!(res, Poll::Ready(None)); + assert_ne!(res, Poll::Pending); + if let Poll::Ready(Some(diff)) = res { assert_eq!(diff.new, None); assert_eq!(diff.full_update, true); } - res = diffs.next(); - assert_ne!(res, None); - if let Some(diff) = res { + res = Pin::new(&mut diffs).poll_next(&mut cx); + while let Poll::Pending = res { + res = Pin::new(&mut diffs).poll_next(&mut cx); + } + assert_ne!(res, Poll::Ready(None)); + assert_ne!(res, Poll::Pending); + if let Poll::Ready(Some(diff)) = res { assert_eq!(diff.new, None); assert_eq!(diff.full_update, true); } @@ -178,9 +202,13 @@ mod tests { /// let name = String::from("https"); /// let token = String::from("HttpsApiTokenPlaceholder"); /// let types = vec![String::from("obfs2"), String::from("scramblesuit")]; -/// let rx = start_stream(endpoint, name, token, types).await.unwrap(); -/// for diff in rx { -/// println!("Received diff: {:?}", diff); +/// let stream = start_stream(endpoint, name, token, types).await.unwrap(); +/// loop { +/// match Pin::new(&mut stream).poll_next(&mut cx) { +/// Poll::Ready(Some(diff)) => println!("Received diff: {:?}", diff), +/// Poll::Ready(None) => break, +/// Poll::Pending => continue, +/// } /// } /// ``` /// @@ -190,7 +218,7 @@ pub async fn start_stream( token: String, resource_types: Vec, ) -> Result { - let (tx, rx) = mpsc::channel(); + let (tx, rx) = mpsc::channel(100); let req = proto::ResourceRequest { request_origin: name, @@ -218,12 +246,8 @@ pub async fn start_stream( return; } }; - tx.send(bytes).unwrap(); + tx.send(bytes).await.unwrap(); } }); - Ok(ResourceStream { - rx, - buf: vec![], - partial: None, - }) + Ok(ResourceStream::new(rx)) }