Skip to content

Commit

Permalink
start pinning APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshuawuyts committed Mar 20, 2024
1 parent 4017779 commit f6e63b4
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 73 deletions.
23 changes: 15 additions & 8 deletions src/concurrent_stream/enumerate.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use pin_project::pin_project;

use super::{ConcurrentStream, Consumer};
use core::future::Future;
use core::num::NonZeroUsize;
Expand Down Expand Up @@ -47,7 +49,9 @@ impl<CS: ConcurrentStream> ConcurrentStream for Enumerate<CS> {
}
}

#[pin_project]
struct EnumerateConsumer<C> {
#[pin]
inner: C,
count: usize,
}
Expand All @@ -58,18 +62,21 @@ where
{
type Output = C::Output;

async fn send(&mut self, future: Fut) -> super::ConsumerState {
let count = self.count;
self.count += 1;
self.inner.send(EnumerateFuture::new(future, count)).await
async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState {
let this = self.project();
let count = *this.count;
*this.count += 1;
this.inner.send(EnumerateFuture::new(future, count)).await
}

async fn progress(&mut self) -> super::ConsumerState {
self.inner.progress().await
async fn progress(self: Pin<&mut Self>) -> super::ConsumerState {
let this = self.project();
this.inner.progress().await
}

async fn flush(&mut self) -> Self::Output {
self.inner.flush().await
async fn flush(self: Pin<&mut Self>) -> Self::Output {
let this = self.project();
this.inner.flush().await
}
}

Expand Down
27 changes: 15 additions & 12 deletions src/concurrent_stream/for_each.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ where
// NOTE: we can remove the `Arc` here if we're willing to make this struct self-referential
count: Arc<AtomicUsize>,
#[pin]
group: Pin<Box<FutureGroup<ForEachFut<F, FutT, T, FutB>>>>,
group: FutureGroup<ForEachFut<F, FutT, T, FutB>>,
limit: usize,
f: F,
_phantom: PhantomData<(T, FutB)>,
Expand All @@ -45,7 +45,7 @@ where
f,
_phantom: PhantomData,
count: Arc::new(AtomicUsize::new(0)),
group: Box::pin(FutureGroup::new()),
group: FutureGroup::new(),
}
}
}
Expand All @@ -60,30 +60,33 @@ where
{
type Output = ();

async fn send(&mut self, future: FutT) -> super::ConsumerState {
async fn send(mut self: Pin<&mut Self>, future: FutT) -> super::ConsumerState {
let mut this = self.project();
// If we have no space, we're going to provide backpressure until we have space
while self.count.load(Ordering::Relaxed) >= self.limit {
self.group.next().await;
while this.count.load(Ordering::Relaxed) >= *this.limit {
this.group.next().await;
}

// Space was available! - insert the item for posterity
self.count.fetch_add(1, Ordering::Relaxed);
let fut = ForEachFut::new(self.f.clone(), future, self.count.clone());
self.group.as_mut().insert_pinned(fut);
this.count.fetch_add(1, Ordering::Relaxed);
let fut = ForEachFut::new(this.f.clone(), future, this.count.clone());
this.group.as_mut().insert_pinned(fut);

ConsumerState::Continue
}

async fn progress(&mut self) -> super::ConsumerState {
while let Some(_) = self.group.next().await {}
async fn progress(self: Pin<&mut Self>) -> super::ConsumerState {
let mut this = self.project();
while let Some(_) = this.group.next().await {}
ConsumerState::Empty
}

async fn flush(&mut self) -> Self::Output {
async fn flush(self: Pin<&mut Self>) -> Self::Output {
let mut this = self.project();
// 4. We will no longer receive any additional futures from the
// underlying stream; wait until all the futures in the group have
// resolved.
while let Some(_) = self.group.next().await {}
while let Some(_) = this.group.next().await {}
}
}

Expand Down
26 changes: 16 additions & 10 deletions src/concurrent_stream/from_concurrent_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use alloc::vec::Vec;
use core::future::Future;
use core::pin::Pin;
use futures_lite::StreamExt;
use pin_project::pin_project;

/// Conversion from a [`ConcurrentStream`]
#[allow(async_fn_in_trait)]
Expand All @@ -28,15 +29,17 @@ impl<T> FromConcurrentStream<T> for Vec<T> {
}

// TODO: replace this with a generalized `fold` operation
#[pin_project]
pub(crate) struct VecConsumer<'a, Fut: Future> {
group: Pin<Box<FutureGroup<Fut>>>,
#[pin]
group: FutureGroup<Fut>,
output: &'a mut Vec<Fut::Output>,
}

impl<'a, Fut: Future> VecConsumer<'a, Fut> {
pub(crate) fn new(output: &'a mut Vec<Fut::Output>) -> Self {
Self {
group: Box::pin(FutureGroup::new()),
group: FutureGroup::new(),
output,
}
}
Expand All @@ -48,21 +51,24 @@ where
{
type Output = ();

async fn send(&mut self, future: Fut) -> super::ConsumerState {
async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState {
let mut this = self.project();
// unbounded concurrency, so we just goooo
self.group.as_mut().insert_pinned(future);
this.group.as_mut().insert_pinned(future);
ConsumerState::Continue
}

async fn progress(&mut self) -> super::ConsumerState {
while let Some(item) = self.group.next().await {
self.output.push(item);
async fn progress(self: Pin<&mut Self>) -> super::ConsumerState {
let mut this = self.project();
while let Some(item) = this.group.next().await {
this.output.push(item);
}
ConsumerState::Empty
}
async fn flush(&mut self) -> Self::Output {
while let Some(item) = self.group.next().await {
self.output.push(item);
async fn flush(self: Pin<&mut Self>) -> Self::Output {
let mut this = self.project();
while let Some(item) = this.group.next().await {
this.output.push(item);
}
}
}
Expand Down
20 changes: 14 additions & 6 deletions src/concurrent_stream/limit.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use pin_project::pin_project;

use super::{ConcurrentStream, Consumer};
use core::future::Future;
use core::num::NonZeroUsize;
use core::pin::Pin;

/// A concurrent iterator that limits the amount of concurrency applied.
///
Expand Down Expand Up @@ -43,7 +46,9 @@ impl<CS: ConcurrentStream> ConcurrentStream for Limit<CS> {
}
}

#[pin_project]
struct LimitConsumer<C> {
#[pin]
inner: C,
}
impl<C, Item, Fut> Consumer<Item, Fut> for LimitConsumer<C>
Expand All @@ -53,15 +58,18 @@ where
{
type Output = C::Output;

async fn send(&mut self, future: Fut) -> super::ConsumerState {
self.inner.send(future).await
async fn send(self: Pin<&mut Self>, future: Fut) -> super::ConsumerState {
let this = self.project();
this.inner.send(future).await
}

async fn progress(&mut self) -> super::ConsumerState {
self.inner.progress().await
async fn progress(self: Pin<&mut Self>) -> super::ConsumerState {
let this = self.project();
this.inner.progress().await
}

async fn flush(&mut self) -> Self::Output {
self.inner.flush().await
async fn flush(self: Pin<&mut Self>) -> Self::Output {
let this = self.project();
this.inner.flush().await
}
}
22 changes: 14 additions & 8 deletions src/concurrent_stream/map.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use pin_project::pin_project;

use super::{ConcurrentStream, Consumer};
use core::num::NonZeroUsize;
use core::{
Expand Down Expand Up @@ -71,7 +73,7 @@ where
}
}

// OK: validated! - all bounds should check out
#[pin_project]
pub struct MapConsumer<C, F, FutT, T, FutB, B>
where
FutT: Future<Output = T>,
Expand All @@ -80,6 +82,7 @@ where
F: Clone,
FutB: Future<Output = B>,
{
#[pin]
inner: C,
f: F,
_phantom: PhantomData<(FutT, T, FutB, B)>,
Expand All @@ -95,17 +98,20 @@ where
{
type Output = C::Output;

async fn progress(&mut self) -> super::ConsumerState {
self.inner.progress().await
async fn progress(self: Pin<&mut Self>) -> super::ConsumerState {
let this = self.project();
this.inner.progress().await
}

async fn send(&mut self, future: FutT) -> super::ConsumerState {
let fut = MapFuture::new(self.f.clone(), future);
self.inner.send(fut).await
async fn send(self: Pin<&mut Self>, future: FutT) -> super::ConsumerState {
let this = self.project();
let fut = MapFuture::new(this.f.clone(), future);
this.inner.send(fut).await
}

async fn flush(&mut self) -> Self::Output {
self.inner.flush().await
async fn flush(self: Pin<&mut Self>) -> Self::Output {
let this = self.project();
this.inner.flush().await
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/concurrent_stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod try_for_each;

use core::future::Future;
use core::num::NonZeroUsize;
use core::pin::Pin;
use for_each::ForEachConsumer;
use try_for_each::TryForEachConsumer;

Expand All @@ -37,18 +38,18 @@ where
type Output;

/// Send an item down to the next step in the processing queue.
async fn send(&mut self, fut: Fut) -> ConsumerState;
async fn send(self: Pin<&mut Self>, fut: Fut) -> ConsumerState;

/// Make progress on the consumer while doing something else.
///
/// It should always be possible to drop the future returned by this
/// function. This is solely intended to keep work going on the `Consumer`
/// while doing e.g. waiting for new futures from a stream.
async fn progress(&mut self) -> ConsumerState;
async fn progress(self: Pin<&mut Self>) -> ConsumerState;

/// We have no more data left to send to the `Consumer`; wait for its
/// output.
async fn flush(&mut self) -> Self::Output;
async fn flush(self: Pin<&mut Self>) -> Self::Output;
}

/// Concurrently operate over items in a stream
Expand Down
24 changes: 16 additions & 8 deletions src/concurrent_stream/take.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use pin_project::pin_project;

use super::{ConcurrentStream, Consumer, ConsumerState};
use core::future::Future;
use core::num::NonZeroUsize;
use core::pin::Pin;

/// A concurrent iterator that only iterates over the first `n` iterations of `iter`.
///
Expand Down Expand Up @@ -49,7 +52,9 @@ impl<CS: ConcurrentStream> ConcurrentStream for Take<CS> {
}
}

#[pin_project]
struct TakeConsumer<C> {
#[pin]
inner: C,
count: usize,
limit: usize,
Expand All @@ -61,22 +66,25 @@ where
{
type Output = C::Output;

async fn send(&mut self, future: Fut) -> ConsumerState {
self.count += 1;
let state = self.inner.send(future).await;
if self.count >= self.limit {
async fn send(self: Pin<&mut Self>, future: Fut) -> ConsumerState {
let this = self.project();
*this.count += 1;
let state = this.inner.send(future).await;
if this.count >= this.limit {
ConsumerState::Break
} else {
state
}
}

async fn progress(&mut self) -> ConsumerState {
self.inner.progress().await
async fn progress(self: Pin<&mut Self>) -> ConsumerState {
let this = self.project();
this.inner.progress().await
}

async fn flush(&mut self) -> Self::Output {
self.inner.flush().await
async fn flush(self: Pin<&mut Self>) -> Self::Output {
let this = self.project();
this.inner.flush().await
}
}

Expand Down
Loading

0 comments on commit f6e63b4

Please sign in to comment.