Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

add a helper for running futures concurrently #541

Merged
merged 4 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions rust/xaynet-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ async-trait = "0.1.42"
base64 = "0.13.0"
bincode = "1.3.1"
derive_more = { version = "0.99.11", default-features = false, features = ["from"] }
# TODO: remove once concurrent_futures.rs was moved to the e2e package
futures = "0.3.12"
paste = "1.0.4"
serde = { version = "1.0.119", features = ["derive"] }
sodiumoxide = "0.2.6"
thiserror = "1.0.23"
# TODO (XN-1372): upgrade
# TODO: move to dev-dependencies once concurrent_futures.rs was moved to the e2e package
tokio = { version = "0.2.24", features = ["rt-core", "macros"] }
tracing = "0.1.22"
url = "2.2.0"
xaynet-core = { path = "../xaynet-core", version = "0.1.0" }
Expand All @@ -39,8 +44,6 @@ rand = "0.8.2"
mockall = "0.9.0"
num = { version = "0.3.1", features = ["serde"] }
serde_json = "1.0.61"
# TODO (XN-1372): upgrade
tokio = { version = "0.2.24", features = ["rt-core", "macros"] }
# TODO (XN-1372): can't upgrade yet because of tokio
tokio-test = "0.2.1"
xaynet-core = { path = "../xaynet-core", features = ["testutils"] }
Expand Down
10 changes: 4 additions & 6 deletions rust/xaynet-sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,12 @@
//! ```

pub mod client;

mod message_encoder;
pub(crate) use self::message_encoder::MessageEncoder;

pub mod settings;

mod state_machine;
pub use state_machine::{LocalModelConfig, SerializableState, StateMachine, TransitionOutcome};

mod traits;
pub(crate) mod utils;

pub(crate) use self::message_encoder::MessageEncoder;
pub use self::traits::{ModelStore, Notify, XaynetClient};
pub use state_machine::{LocalModelConfig, SerializableState, StateMachine, TransitionOutcome};
138 changes: 138 additions & 0 deletions rust/xaynet-sdk/src/utils/concurrent_futures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
#![allow(dead_code)]

use std::{
collections::VecDeque,
pin::Pin,
task::{Context, Poll},
};

use futures::{
stream::{FuturesUnordered, Stream},
Future,
};
use tokio::task::{JoinError, JoinHandle};

/// `ConcurrentFutures` can keep a capped number of futures running concurrently, and yield their
/// result as they finish. When the max number of concurrent futures is reached, new tasks are
/// queued until some in-flight futures finish.
pub struct ConcurrentFutures<T>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
/// In-flight futures.
running: FuturesUnordered<JoinHandle<T::Output>>,
/// Buffered tasks.
queued: VecDeque<T>,
/// Max number of concurrent futures.
max_in_flight: usize,
}

impl<T> ConcurrentFutures<T>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
pub fn new(max_in_flight: usize) -> Self {
Self {
running: FuturesUnordered::new(),
queued: VecDeque::new(),
max_in_flight,
}
}

pub fn push(&mut self, task: T) {
self.queued.push_back(task)
}
}

impl<T> Unpin for ConcurrentFutures<T>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
}

impl<T> Stream for ConcurrentFutures<T>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
type Item = Result<T::Output, JoinError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
while this.running.len() < this.max_in_flight {
if let Some(queued) = this.queued.pop_front() {
let handle = tokio::spawn(queued);
this.running.push(handle);
} else {
break;
}
}
Pin::new(&mut this.running).poll_next(cx)
}
}

#[cfg(test)]
mod tests {
use std::time::Duration;

use futures::stream::StreamExt;
use tokio::time::delay_for;

use super::*;

#[tokio::test]
async fn test() {
let mut stream =
ConcurrentFutures::<Pin<Box<dyn Future<Output = u8> + Send + 'static>>>::new(2);

stream.push(Box::pin(async {
delay_for(Duration::from_millis(10_u64)).await;
1_u8
}));

stream.push(Box::pin(async {
delay_for(Duration::from_millis(28_u64)).await;
2_u8
}));

stream.push(Box::pin(async {
delay_for(Duration::from_millis(8_u64)).await;
3_u8
}));

stream.push(Box::pin(async {
delay_for(Duration::from_millis(2_u64)).await;
4_u8
}));

// poll_next() hasn't been called yet so all futures are queued
assert_eq!(stream.running.len(), 0);
assert_eq!(stream.queued.len(), 4);

// future 1 and 2 are spawned, then future 1 is ready
assert_eq!(stream.next().await.unwrap().unwrap(), 1);

// future 2 is pending, futures 3 and 4 are queued
assert_eq!(stream.running.len(), 1);
assert_eq!(stream.queued.len(), 2);

// future 3 is spawned, then future 3 is ready
assert_eq!(stream.next().await.unwrap().unwrap(), 3);

// future 2 is pending, future 4 is queued
assert_eq!(stream.running.len(), 1);
assert_eq!(stream.queued.len(), 1);

// future 4 is spawned, then future 4 is ready
assert_eq!(stream.next().await.unwrap().unwrap(), 4);

// future 2 is pending, then future 2 is ready
assert_eq!(stream.next().await.unwrap().unwrap(), 2);

// all futures have been resolved
assert_eq!(stream.running.len(), 0);
assert_eq!(stream.queued.len(), 0);
}
}
2 changes: 2 additions & 0 deletions rust/xaynet-sdk/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// TODO: move to the e2e package
pub mod concurrent_futures;
2 changes: 1 addition & 1 deletion rust/xaynet-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ derive_more = { version = "0.99.11", default-features = false, features = [
"into",
] }
displaydoc = "0.1.7"
futures = "0.3.11"
futures = "0.3.12"
hex = "0.4.2"
http = "0.2.3"
influxdb = "0.3.0"
Expand Down