From dba2e9254c260cba82f08d8d2e25aabd8f09cd17 Mon Sep 17 00:00:00 2001 From: Arpad Borsos Date: Wed, 30 Nov 2022 11:20:21 +0100 Subject: [PATCH] Implement perfect waking for array/vec Join Tries to implement #21 for array and vec Join. --- src/future/join/array.rs | 39 +++++++++++++++++++------ src/future/join/vec.rs | 46 ++++++++++++++++++++++-------- src/utils/poll_state/poll_state.rs | 1 + 3 files changed, 65 insertions(+), 21 deletions(-) diff --git a/src/future/join/array.rs b/src/future/join/array.rs index 50529f5..a673603 100644 --- a/src/future/join/array.rs +++ b/src/future/join/array.rs @@ -1,5 +1,5 @@ use super::Join as JoinTrait; -use crate::utils::{self, PollArray}; +use crate::utils::{self, PollArray, WakerArray}; use core::array; use core::fmt; @@ -26,6 +26,7 @@ where consumed: bool, pending: usize, items: [MaybeUninit<::Output>; N], + wakers: WakerArray, state: PollArray, #[pin] futures: [Fut; N], @@ -41,6 +42,7 @@ where consumed: false, pending: N, items: array::from_fn(|_| MaybeUninit::uninit()), + wakers: WakerArray::new(), state: PollArray::new(), futures, } @@ -85,12 +87,32 @@ where "Futures must not be polled after completing" ); - // Poll all futures + // Mark futures as ready according to the wakers + { + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } + + for (i, state) in this.state.iter_mut().enumerate() { + if !state.is_consumed() && readiness.clear_ready(i) { + state.set_ready(); + } + } + } + + // Poll all ready futures for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() { - if this.state[i].is_pending() { - if let Poll::Ready(value) = fut.poll(cx) { + if this.state[i].is_ready() { + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); + + if let Poll::Ready(value) = fut.poll(&mut cx) { this.items[i] = MaybeUninit::new(value); - this.state[i].set_ready(); + this.state[i].set_consumed(); *this.pending -= 1; } } @@ -102,10 +124,9 @@ where *this.consumed = true; for state in this.state.iter_mut() { debug_assert!( - state.is_ready(), - "Future should have reached a `Ready` state" + state.is_consumed(), + "Future should have reached a `Consumed` state" ); - state.set_consumed(); } let mut items = array::from_fn(|_| MaybeUninit::uninit()); @@ -135,7 +156,7 @@ where .state .iter_mut() .enumerate() - .filter(|(_, state)| state.is_ready()) + .filter(|(_, state)| state.is_consumed()) .map(|(i, _)| i); // Drop each value at the index. diff --git a/src/future/join/vec.rs b/src/future/join/vec.rs index d11a48e..1073ff0 100644 --- a/src/future/join/vec.rs +++ b/src/future/join/vec.rs @@ -1,5 +1,5 @@ use super::Join as JoinTrait; -use crate::utils::{iter_pin_mut_vec, PollVec}; +use crate::utils::{iter_pin_mut_vec, PollVec, WakerVec}; use core::fmt; use core::future::{Future, IntoFuture}; @@ -26,6 +26,7 @@ where consumed: bool, pending: usize, items: Vec::Output>>, + wakers: WakerVec, state: PollVec, #[pin] futures: Vec, @@ -36,13 +37,15 @@ where Fut: Future, { pub(crate) fn new(futures: Vec) -> Self { + let len = futures.len(); Join { consumed: false, - pending: futures.len(), + pending: len, items: std::iter::repeat_with(MaybeUninit::uninit) - .take(futures.len()) + .take(len) .collect(), - state: PollVec::new(futures.len()), + wakers: WakerVec::new(len), + state: PollVec::new(len), futures, } } @@ -84,14 +87,34 @@ where "Futures must not be polled after completing" ); - // Poll all futures + // Mark futures as ready according to the wakers + { + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); + + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } + + for (i, state) in this.state.iter_mut().enumerate() { + if !state.is_consumed() && readiness.clear_ready(i) { + state.set_ready(); + } + } + } + + // Poll all ready futures let futures = this.futures.as_mut(); let states = &mut this.state[..]; for (i, fut) in iter_pin_mut_vec(futures).enumerate() { - if states[i].is_pending() { - if let Poll::Ready(value) = fut.poll(cx) { + if states[i].is_ready() { + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(i).unwrap()); + + if let Poll::Ready(value) = fut.poll(&mut cx) { this.items[i] = MaybeUninit::new(value); - states[i].set_ready(); + states[i].set_consumed(); *this.pending -= 1; } } @@ -103,10 +126,9 @@ where *this.consumed = true; this.state.iter_mut().for_each(|state| { debug_assert!( - state.is_ready(), - "Future should have reached a `Ready` state" + state.is_consumed(), + "Future should have reached a `Consumed` state" ); - state.set_consumed(); }); // SAFETY: we've checked with the state that all of our outputs have been @@ -136,7 +158,7 @@ where .state .iter_mut() .enumerate() - .filter(|(_, state)| state.is_ready()) + .filter(|(_, state)| state.is_consumed()) .map(|(i, _)| i); // Drop each value at the index. diff --git a/src/utils/poll_state/poll_state.rs b/src/utils/poll_state/poll_state.rs index 35be7fe..1843ef4 100644 --- a/src/utils/poll_state/poll_state.rs +++ b/src/utils/poll_state/poll_state.rs @@ -17,6 +17,7 @@ impl PollState { /// Returns `true` if the metadata is [`Pending`][Self::Pending]. #[must_use] #[inline] + #[allow(unused)] pub(crate) fn is_pending(&self) -> bool { matches!(self, Self::Pending) }