Skip to content

Commit

Permalink
Merge pull request #29 from yoshuawuyts/remove-maybe-done-vec-join
Browse files Browse the repository at this point in the history
Remove `MaybeDone` from `impl Join for Vec`
  • Loading branch information
yoshuawuyts authored Nov 7, 2022
2 parents 2c117be + 1dd997c commit 1b422ee
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 52 deletions.
8 changes: 2 additions & 6 deletions src/future/into_future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ impl<Fut: Future> IntoFuture for Vec<Fut> {
type IntoFuture = crate::future::join::vec::Join<Fut>;

fn into_future(self) -> Self::IntoFuture {
let elems = self
.into_iter()
.map(|fut| MaybeDone::new(core::future::IntoFuture::into_future(fut)))
.collect::<Box<_>>()
.into();
crate::future::join::vec::Join::new(elems)
use crate::future::join::vec::Join;
Join::new(self.into_iter().collect())
}
}

Expand Down
142 changes: 107 additions & 35 deletions src/future/join/vec.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,59 @@
use super::Join as JoinTrait;
use crate::utils::iter_pin_mut;
use crate::utils::MaybeDone;
use crate::utils::{iter_pin_mut_vec, PollState};

use core::fmt;
use core::future::{Future, IntoFuture};
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::boxed::Box;
use std::mem::{self, MaybeUninit};
use std::vec::Vec;

impl<Fut> JoinTrait for Vec<Fut>
where
Fut: IntoFuture,
{
type Output = Vec<Fut::Output>;
type Future = Join<Fut::IntoFuture>;

fn join(self) -> Self::Future {
let elems = self
.into_iter()
.map(|fut| MaybeDone::new(fut.into_future()))
.collect::<Box<_>>()
.into();
Join::new(elems)
}
}
use pin_project::{pin_project, pinned_drop};

/// Waits for two similarly-typed futures to complete.
///
/// Awaits multiple futures simultaneously, returning the output of the
/// futures once both complete.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[pin_project(PinnedDrop)]
pub struct Join<Fut>
where
Fut: Future,
{
elems: Pin<Box<[MaybeDone<Fut>]>>,
consumed: bool,
pending: usize,
items: Vec<MaybeUninit<<Fut as Future>::Output>>,
state: Vec<PollState>,
#[pin]
futures: Vec<Fut>,
}

impl<Fut> Join<Fut>
where
Fut: Future,
{
pub(crate) fn new(elems: Pin<Box<[MaybeDone<Fut>]>>) -> Self {
Self { elems }
pub(crate) fn new(futures: Vec<Fut>) -> Self {
Join {
consumed: false,
pending: futures.len(),
items: std::iter::repeat_with(|| MaybeUninit::uninit())
.take(futures.len())
.collect(),
state: vec![PollState::default(); futures.len()],
futures,
}
}
}

impl<Fut> JoinTrait for Vec<Fut>
where
Fut: IntoFuture,
{
type Output = Vec<Fut::Output>;
type Future = Join<Fut::IntoFuture>;

fn join(self) -> Self::Future {
Join::new(self.into_iter().map(IntoFuture::into_future).collect())
}
}

Expand All @@ -54,7 +63,7 @@ where
Fut::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Join").field("elems", &self.elems).finish()
f.debug_list().entries(self.state.iter()).finish()
}
}

Expand All @@ -64,23 +73,86 @@ where
{
type Output = Vec<Fut::Output>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut all_done = true;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

for elem in iter_pin_mut(self.elems.as_mut()) {
if elem.poll(cx).is_pending() {
all_done = false;
assert!(
!*this.consumed,
"Futures must not be polled after completing"
);

// Poll all futures
let futures = this.futures.as_mut();
for (i, fut) in iter_pin_mut_vec(futures).enumerate() {
if this.state[i].is_pending() {
if let Poll::Ready(value) = fut.poll(cx) {
this.items[i] = MaybeUninit::new(value);
this.state[i] = PollState::Done;
*this.pending -= 1;
}
}
}

if all_done {
let mut elems = mem::replace(&mut self.elems, Box::pin([]));
let result = iter_pin_mut(elems.as_mut())
.map(|e| e.take().unwrap())
.collect();
Poll::Ready(result)
// Check whether we're all done now or need to keep going.
if *this.pending == 0 {
// Mark all data as "consumed" before we take it
*this.consumed = true;
this.state.iter_mut().for_each(|state| {
debug_assert!(state.is_done(), "Future should have reached a `Done` state");
*state = PollState::Consumed;
});

// SAFETY: we've checked with the state that all of our outputs have been
// filled, which means we're ready to take the data and assume it's initialized.
let items = unsafe {
let items = mem::take(this.items);
mem::transmute::<_, Vec<Fut::Output>>(items)
};
Poll::Ready(items)
} else {
Poll::Pending
}
}
}

/// Drop the already initialized values on cancellation.
#[pinned_drop]
impl<Fut> PinnedDrop for Join<Fut>
where
Fut: Future,
{
fn drop(self: Pin<&mut Self>) {
let this = self.project();

// Get the indexes of the initialized values.
let indexes = this
.state
.iter_mut()
.enumerate()
.filter(|(_, state)| state.is_done())
.map(|(i, _)| i);

// Drop each value at the index.
for i in indexes {
// SAFETY: we've just filtered down to *only* the initialized values.
// We can assume they're initialized, and this is where we drop them.
unsafe { this.items[i].assume_init_drop() };
}
}
}

#[cfg(test)]
mod test {
use super::*;
use std::future;

#[test]
fn smoke() {
futures_lite::future::block_on(async {
let res = vec![future::ready("hello"), future::ready("world")]
.join()
.await;
assert_eq!(res, vec!["hello", "world"]);
});
}
}
4 changes: 3 additions & 1 deletion src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
mod fuse;
mod maybe_done;
mod pin;
mod poll_state;
mod rng;

pub(crate) use fuse::Fuse;
pub(crate) use maybe_done::MaybeDone;
pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut};
pub(crate) use pin::{get_pin_mut, get_pin_mut_from_vec, iter_pin_mut, iter_pin_mut_vec};
pub(crate) use poll_state::PollState;
pub(crate) use rng::random;
19 changes: 9 additions & 10 deletions src/utils/pin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@ pub(crate) fn iter_pin_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<
.map(|t| unsafe { Pin::new_unchecked(t) })
}

// From: Yosh made this one up, hehehe
// #[cfg(feature = "unstable")]
// pub(crate) fn pin_project_array<T, const N: usize>(slice: Pin<&mut [T; N]>) -> [Pin<&mut T>; N] {
// // SAFETY: `std` _could_ make this unsound if it were to decide Pin's
// // invariants aren't required to transmit through arrays. Otherwise this has
// // the same safety as a normal field pin projection.
// unsafe { slice.get_unchecked_mut() }
// .each_mut()
// .map(|t| unsafe { Pin::new_unchecked(t) })
// }
// From: `futures_rs::join_all!` -- https://github.com/rust-lang/futures-rs/blob/b48eb2e9a9485ef7388edc2f177094a27e08e28b/futures-util/src/future/join_all.rs#L18-L23
pub(crate) fn iter_pin_mut_vec<T>(slice: Pin<&mut Vec<T>>) -> impl Iterator<Item = Pin<&mut T>> {
// SAFETY: `std` _could_ make this unsound if it were to decide Pin's
// invariants aren't required to transmit through slices. Otherwise this has
// the same safety as a normal field pin projection.
unsafe { slice.get_unchecked_mut() }
.iter_mut()
.map(|t| unsafe { Pin::new_unchecked(t) })
}

/// Returns a pinned mutable reference to an element or subslice depending on the
/// type of index (see `get`) or `None` if the index is out of bounds.
Expand Down
39 changes: 39 additions & 0 deletions src/utils/poll_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/// Enumerate the current poll state.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) enum PollState {
/// Polling the underlying future.
#[default]
Pending,
/// Data has been written to the output structure
/// and the future should no longer be polled.
Done,
/// Data has been consumed from the output structure,
/// and we should no longer reason about it.
Consumed,
}

impl PollState {
/// Returns `true` if the metadata is [`Pending`].
///
/// [`Pending`]: Metadata::Pending
#[must_use]
pub(crate) fn is_pending(&self) -> bool {
matches!(self, Self::Pending)
}

/// Returns `true` if the poll state is [`Done`].
///
/// [`Done`]: PollState::Done
#[must_use]
pub(crate) fn is_done(&self) -> bool {
matches!(self, Self::Done)
}

/// Returns `true` if the poll state is [`Consumed`].
///
/// [`Consumed`]: PollState::Consumed
#[must_use]
pub(crate) fn is_consumed(&self) -> bool {
matches!(self, Self::Consumed)
}
}

0 comments on commit 1b422ee

Please sign in to comment.