Skip to content

Commit

Permalink
Implement perfect waking for array/vec Join
Browse files Browse the repository at this point in the history
Tries to implement yoshuawuyts#21 for array and vec Join.
  • Loading branch information
Swatinem committed Nov 30, 2022
1 parent 058da68 commit dba2e92
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 21 deletions.
39 changes: 30 additions & 9 deletions src/future/join/array.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Join as JoinTrait;
use crate::utils::{self, PollArray};
use crate::utils::{self, PollArray, WakerArray};

use core::array;
use core::fmt;
Expand All @@ -26,6 +26,7 @@ where
consumed: bool,
pending: usize,
items: [MaybeUninit<<Fut as Future>::Output>; N],
wakers: WakerArray<N>,
state: PollArray<N>,
#[pin]
futures: [Fut; N],
Expand All @@ -41,6 +42,7 @@ where
consumed: false,
pending: N,
items: array::from_fn(|_| MaybeUninit::uninit()),
wakers: WakerArray::new(),
state: PollArray::new(),
futures,
}
Expand Down Expand Up @@ -85,12 +87,32 @@ where
"Futures must not be polled after completing"
);

// Poll all futures
// Mark futures as ready according to the wakers
{
let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());

if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
}

for (i, state) in this.state.iter_mut().enumerate() {
if !state.is_consumed() && readiness.clear_ready(i) {
state.set_ready();
}
}
}

// Poll all ready futures
for (i, fut) in utils::iter_pin_mut(this.futures.as_mut()).enumerate() {
if this.state[i].is_pending() {
if let Poll::Ready(value) = fut.poll(cx) {
if this.state[i].is_ready() {
// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(i).unwrap());

if let Poll::Ready(value) = fut.poll(&mut cx) {
this.items[i] = MaybeUninit::new(value);
this.state[i].set_ready();
this.state[i].set_consumed();
*this.pending -= 1;
}
}
Expand All @@ -102,10 +124,9 @@ where
*this.consumed = true;
for state in this.state.iter_mut() {
debug_assert!(
state.is_ready(),
"Future should have reached a `Ready` state"
state.is_consumed(),
"Future should have reached a `Consumed` state"
);
state.set_consumed();
}

let mut items = array::from_fn(|_| MaybeUninit::uninit());
Expand Down Expand Up @@ -135,7 +156,7 @@ where
.state
.iter_mut()
.enumerate()
.filter(|(_, state)| state.is_ready())
.filter(|(_, state)| state.is_consumed())
.map(|(i, _)| i);

// Drop each value at the index.
Expand Down
46 changes: 34 additions & 12 deletions src/future/join/vec.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Join as JoinTrait;
use crate::utils::{iter_pin_mut_vec, PollVec};
use crate::utils::{iter_pin_mut_vec, PollVec, WakerVec};

use core::fmt;
use core::future::{Future, IntoFuture};
Expand All @@ -26,6 +26,7 @@ where
consumed: bool,
pending: usize,
items: Vec<MaybeUninit<<Fut as Future>::Output>>,
wakers: WakerVec,
state: PollVec,
#[pin]
futures: Vec<Fut>,
Expand All @@ -36,13 +37,15 @@ where
Fut: Future,
{
pub(crate) fn new(futures: Vec<Fut>) -> Self {
let len = futures.len();
Join {
consumed: false,
pending: futures.len(),
pending: len,
items: std::iter::repeat_with(MaybeUninit::uninit)
.take(futures.len())
.take(len)
.collect(),
state: PollVec::new(futures.len()),
wakers: WakerVec::new(len),
state: PollVec::new(len),
futures,
}
}
Expand Down Expand Up @@ -84,14 +87,34 @@ where
"Futures must not be polled after completing"
);

// Poll all futures
// Mark futures as ready according to the wakers
{
let mut readiness = this.wakers.readiness().lock().unwrap();
readiness.set_waker(cx.waker());

if !readiness.any_ready() {
// Nothing is ready yet
return Poll::Pending;
}

for (i, state) in this.state.iter_mut().enumerate() {
if !state.is_consumed() && readiness.clear_ready(i) {
state.set_ready();
}
}
}

// Poll all ready futures
let futures = this.futures.as_mut();
let states = &mut this.state[..];
for (i, fut) in iter_pin_mut_vec(futures).enumerate() {
if states[i].is_pending() {
if let Poll::Ready(value) = fut.poll(cx) {
if states[i].is_ready() {
// Obtain the intermediate waker.
let mut cx = Context::from_waker(this.wakers.get(i).unwrap());

if let Poll::Ready(value) = fut.poll(&mut cx) {
this.items[i] = MaybeUninit::new(value);
states[i].set_ready();
states[i].set_consumed();
*this.pending -= 1;
}
}
Expand All @@ -103,10 +126,9 @@ where
*this.consumed = true;
this.state.iter_mut().for_each(|state| {
debug_assert!(
state.is_ready(),
"Future should have reached a `Ready` state"
state.is_consumed(),
"Future should have reached a `Consumed` state"
);
state.set_consumed();
});

// SAFETY: we've checked with the state that all of our outputs have been
Expand Down Expand Up @@ -136,7 +158,7 @@ where
.state
.iter_mut()
.enumerate()
.filter(|(_, state)| state.is_ready())
.filter(|(_, state)| state.is_consumed())
.map(|(i, _)| i);

// Drop each value at the index.
Expand Down
1 change: 1 addition & 0 deletions src/utils/poll_state/poll_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ impl PollState {
/// Returns `true` if the metadata is [`Pending`][Self::Pending].
#[must_use]
#[inline]
#[allow(unused)]
pub(crate) fn is_pending(&self) -> bool {
matches!(self, Self::Pending)
}
Expand Down

0 comments on commit dba2e92

Please sign in to comment.