Skip to content

Commit

Permalink
Implement FromConcurrentStream for Result<Vec<T>, E>
Browse files Browse the repository at this point in the history
  • Loading branch information
tyilo committed Sep 20, 2024
1 parent a79d3df commit 224562a
Showing 1 changed file with 86 additions and 0 deletions.
86 changes: 86 additions & 0 deletions src/concurrent_stream/from_concurrent_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ impl<T> FromConcurrentStream<T> for Vec<T> {
}
}

impl<T, E> FromConcurrentStream<Result<T, E>> for Result<Vec<T>, E> {
async fn from_concurrent_stream<S>(iter: S) -> Self
where
S: IntoConcurrentStream<Item = Result<T, E>>,
{
let stream = iter.into_co_stream();
let mut output = Ok(Vec::with_capacity(stream.size_hint().1.unwrap_or_default()));
stream.drive(ResultVecConsumer::new(&mut output)).await;
output
}
}

// TODO: replace this with a generalized `fold` operation
#[pin_project]
pub(crate) struct VecConsumer<'a, Fut: Future> {
Expand Down Expand Up @@ -73,6 +85,60 @@ where
}
}

#[pin_project]
pub(crate) struct ResultVecConsumer<'a, Fut: Future, T, E> {
#[pin]
group: FuturesUnordered<Fut>,
output: &'a mut Result<Vec<T>, E>,
}

impl<'a, Fut: Future, T, E> ResultVecConsumer<'a, Fut, T, E> {
pub(crate) fn new(output: &'a mut Result<Vec<T>, E>) -> Self {
Self {
group: FuturesUnordered::new(),
output,
}
}
}

impl<'a, Fut, T, E> Consumer<Result<T, E>, Fut> for ResultVecConsumer<'a, Fut, T, E>
where
Fut: Future<Output = Result<T, E>>,
{
type Output = ();

async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState {
let mut this = self.project();
// unbounded concurrency, so we just goooo
this.group.as_mut().push(future);
ConsumerState::Continue
}

async fn progress(self: Pin<&mut Self>) -> super::ConsumerState {
let mut this = self.project();

while let Some(item) = this.group.next().await {
match item {
Ok(item) => {
let Ok(items) = this.output else {
panic!("progress called after returning ConsumerState::Break");
};
items.push(item);
}
Err(e) => {
**this.output = Err(e);
return ConsumerState::Break;
}
}
}
ConsumerState::Empty
}

async fn flush(self: Pin<&mut Self>) -> Self::Output {
self.progress().await;
}
}

#[cfg(test)]
mod test {
use crate::prelude::*;
Expand All @@ -85,4 +151,24 @@ mod test {
assert_eq!(v, &[1, 1, 1, 1, 1]);
});
}

#[test]
fn collect_to_result_ok() {
futures_lite::future::block_on(async {
let v: Result<Vec<_>, ()> = stream::repeat(Ok(1)).co().take(5).collect().await;
assert_eq!(v, Ok(vec![1, 1, 1, 1, 1]));
});
}

#[test]
fn collect_to_result_err() {
futures_lite::future::block_on(async {
let v: Result<Vec<_>, _> = stream::repeat(Err::<u8, _>(()))
.co()
.take(5)
.collect()
.await;
assert_eq!(v, Err(()));
});
}
}

0 comments on commit 224562a

Please sign in to comment.