Skip to content

Commit

Permalink
Impl "perfect" waking for tuple::merge
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus-consoli committed Nov 18, 2022
1 parent 4ea3425 commit 288bcce
Showing 1 changed file with 153 additions and 55 deletions.
208 changes: 153 additions & 55 deletions src/stream/merge/tuple.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,49 @@
use super::Merge as MergeTrait;
use crate::stream::IntoStream;
use crate::utils;
use crate::utils::{self, PollArray, WakerArray};

use core::fmt;
use futures_core::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};

// TODO: handle none case
macro_rules! poll_stream {
($stream_idx:tt, $iteration:ident, $this:ident, $streams:ident, $cx:ident, $len_streams:ident) => {
if $stream_idx == $iteration {
match unsafe { Pin::new_unchecked(&mut $streams.$stream_idx) }.poll_next(&mut $cx) {
Poll::Ready(Some(item)) => {
// Mark ourselves as ready again because we need to poll for the next item.
$this
.wakers
.readiness()
.lock()
.unwrap()
.set_ready($stream_idx);
return Poll::Ready(Some(item));
}
Poll::Ready(None) => {
*$this.completed += 1;
$this.state[$stream_idx].set_consumed();
if *$this.completed == $len_streams {
return Poll::Ready(None);
}
}
Poll::Pending => {}
}
}
};
}

macro_rules! impl_merge_tuple {
($StructName:ident) => {
($ignore:ident $StructName:ident) => {
/// A stream that merges multiple streams into a single stream.
///
/// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its
/// documentation for more.
///
/// [`merge`]: trait.Merge.html#method.merge
/// [`Merge`]: trait.Merge.html
#[pin_project::pin_project]
pub struct $StructName {
rng: utils::RandomGenerator,
}
pub struct $StructName {}

impl fmt::Debug for $StructName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand All @@ -29,25 +52,29 @@ macro_rules! impl_merge_tuple {
}

impl Stream for $StructName {
type Item = std::convert::Infallible; // TODO: convert to `never` type in the stdlib
type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib

fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(None)
}
}

impl MergeTrait for () {
type Item = std::convert::Infallible; // TODO: convert to `never` type in the stdlib
type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib
type Stream = $StructName;

fn merge(self) -> Self::Stream {
$StructName {
rng: utils::RandomGenerator::new(),
}
$StructName { }
}
}
};
($StructName:ident $($F:ident)+) => {
($mod_name:ident $StructName:ident $($F:ident)+) => {
mod $mod_name {
#[derive(Debug)]
#[pin_project::pin_project]
pub(super) struct Streams<$($F,)+>($(#[pin] pub(super) $F,)+);
}

/// A stream that merges multiple streams into a single stream.
///
/// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its
Expand All @@ -60,9 +87,12 @@ macro_rules! impl_merge_tuple {
where $(
$F: Stream<Item = T>,
)* {
done: bool,
$(#[pin] $F: $F,)*
#[pin] streams: $mod_name::Streams<$($F,)+>,
rng: utils::RandomGenerator,
wakers: WakerArray<{utils::tuple_len!($($F,)+)}>,
state: PollArray<{utils::tuple_len!($($F,)+)}>,
completed: u8,
done: bool,
}

impl<T, $($F),*> fmt::Debug for $StructName<T, $($F),*>
Expand All @@ -72,7 +102,7 @@ macro_rules! impl_merge_tuple {
)* {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Merge")
$(.field(&self.$F))*
.field(&self.streams)
.finish()
}
}
Expand All @@ -84,33 +114,41 @@ macro_rules! impl_merge_tuple {
type Item = T;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
let this = self.project();

// Return early in case we're polled again after completion.
if *this.done {
return Poll::Ready(None);
}
let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());

const LEN: u32 = utils::tuple_len!($($F,)*);
const PERMUTATIONS: u32 = utils::permutations(LEN);
let r = this.rng.generate(PERMUTATIONS);
let mut pending = false;
for i in 0..LEN {
utils::gen_conditions!(LEN, i, r, this, cx, poll_next, {
Poll::Ready(Some(value)) => return Poll::Ready(Some(value)),
Poll::Ready(None) => continue,
Poll::Pending => {
pending = true;
continue
},
}, $($F,)*);
}
if pending {
Poll::Pending
} else {
*this.done = true;
Poll::Ready(None)
const LEN: u8 = utils::tuple_len!($($F,)*);
let r = this.rng.generate(LEN as u32) as u8;

let mut streams = this.streams.project();

// Iterate over our streams one-by-one. If a stream yields a value,
// we exit early. By default we'll return `Poll::Ready(None)`, but
// this changes if we encounter a `Poll::Pending`.
for index in (0..LEN).map(|n| (r + n).wrapping_rem(LEN) as usize) {
if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
} else if !readiness.clear_ready(index) || this.state[index].is_consumed() {
continue;
}

// unlock readiness so we don't deadlock when polling
drop(readiness);

// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

// poll the `streams.{index}` stream
utils::tuple_for_each!(poll_stream (index, this, streams, cx, LEN) $($F)*);

// Lock readiness so we can use it again
readiness = this.wakers.readiness().lock().unwrap();
}

Poll::Pending
}
}

Expand All @@ -124,32 +162,36 @@ macro_rules! impl_merge_tuple {
fn merge(self) -> Self::Stream {
let ($($F,)*): ($($F,)*) = self;
$StructName {
done: false,
streams: $mod_name::Streams($($F.into_stream(),)+),
rng: utils::RandomGenerator::new(),
$($F: $F.into_stream()),*
wakers: WakerArray::new(),
state: PollArray::new(),
completed: 0,
done: false,
}
}
}
};
}

impl_merge_tuple! { Merge0 }
impl_merge_tuple! { Merge1 A }
impl_merge_tuple! { Merge2 A B }
impl_merge_tuple! { Merge3 A B C }
impl_merge_tuple! { Merge4 A B C D }
impl_merge_tuple! { Merge5 A B C D E }
impl_merge_tuple! { Merge6 A B C D E F }
impl_merge_tuple! { Merge7 A B C D E F G }
impl_merge_tuple! { Merge8 A B C D E F G H }
impl_merge_tuple! { Merge9 A B C D E F G H I }
impl_merge_tuple! { Merge10 A B C D E F G H I J }
impl_merge_tuple! { Merge11 A B C D E F G H I J K }
impl_merge_tuple! { Merge12 A B C D E F G H I J K L }
impl_merge_tuple! { merge0 Merge0 }
impl_merge_tuple! { merge1 Merge1 A }
impl_merge_tuple! { merge2 Merge2 A B }
impl_merge_tuple! { merge3 Merge3 A B C }
impl_merge_tuple! { merge4 Merge4 A B C D }
impl_merge_tuple! { merge5 Merge5 A B C D E }
impl_merge_tuple! { merge6 Merge6 A B C D E F }
impl_merge_tuple! { merge7 Merge7 A B C D E F G }
impl_merge_tuple! { merge8 Merge8 A B C D E F G H }
impl_merge_tuple! { merge9 Merge9 A B C D E F G H I }
impl_merge_tuple! { merge10 Merge10 A B C D E F G H I J }
impl_merge_tuple! { merge11 Merge11 A B C D E F G H I J K }
impl_merge_tuple! { merge12 Merge12 A B C D E F G H I J K L }

#[cfg(test)]
mod tests {
use super::*;
use futures::task::LocalSpawnExt;
use futures_lite::future::block_on;
use futures_lite::prelude::*;
use futures_lite::stream;
Expand Down Expand Up @@ -228,4 +270,60 @@ mod tests {
assert_eq!(counter, 10);
})
}

/// This test case uses channels so we'll have streams that return Pending from time to time.
///
/// The purpose of this test is to make sure we have the waking logic working.
#[test]
fn merge_channels() {
use std::cell::RefCell;
use std::rc::Rc;

use futures::executor::LocalPool;

use crate::future::Join;
use crate::utils::channel::local_channel;

let mut pool = LocalPool::new();

let done = Rc::new(RefCell::new(false));
let done2 = done.clone();

pool.spawner()
.spawn_local(async move {
let (send1, receive1) = local_channel();
let (send2, receive2) = local_channel();
let (send3, receive3) = local_channel();

let (count, ()) = (
async {
(receive1, receive2, receive3)
.merge()
.fold(0, |a, b| a + b)
.await
},
async {
for i in 1..=4 {
send1.send(i);
send2.send(i);
send3.send(i);
}
drop(send1);
drop(send2);
drop(send3);
},
)
.join()
.await;

assert_eq!(count, 30);

*done2.borrow_mut() = true;
})
.unwrap();

while !*done.borrow() {
pool.run_until_stalled()
}
}
}

0 comments on commit 288bcce

Please sign in to comment.