diff --git a/task/src/event_stream.rs b/task/src/event_stream.rs index 107b55ee04..72dfe09b8d 100644 --- a/task/src/event_stream.rs +++ b/task/src/event_stream.rs @@ -151,3 +151,97 @@ impl EventStream for ChannelStream { inner.subscribers.remove(&uid); } } + +pub mod test { + use crate::*; + use async_compatibility_layer::art::{async_sleep, async_spawn}; + use std::time::Duration; + #[derive(Clone, Debug, PartialEq, Eq)] + pub enum TestMessage { + One, + Two, + Three, + } + + impl PassType for TestMessage {} + + #[cfg(test)] + #[cfg_attr( + feature = "tokio-executor", + tokio::test(flavor = "multi_thread", worker_threads = 20) + )] + #[cfg_attr(feature = "async-std-executor", async_std::test)] + async fn test_channel_stream_basic() { + use crate::task::FilterEvent; + + use super::ChannelStream; + + let channel_stream = ChannelStream::::new(); + let (mut stream, _) = channel_stream.subscribe(FilterEvent::default()).await; + let dup_channel_stream = channel_stream.clone(); + + let dup_dup_channel_stream = channel_stream.clone(); + + async_spawn(async move { + let (mut stream, _) = dup_channel_stream.subscribe(FilterEvent::default()).await; + assert!(stream.next().await.unwrap() == TestMessage::Three); + assert!(stream.next().await.unwrap() == TestMessage::One); + assert!(stream.next().await.unwrap() == TestMessage::Two); + }); + + async_spawn(async move { + dup_dup_channel_stream.publish(TestMessage::Three).await; + dup_dup_channel_stream.publish(TestMessage::One).await; + dup_dup_channel_stream.publish(TestMessage::Two).await; + }); + async_sleep(Duration::new(3, 0)).await; + + assert!(stream.next().await.unwrap() == TestMessage::Three); + assert!(stream.next().await.unwrap() == TestMessage::One); + assert!(stream.next().await.unwrap() == TestMessage::Two); + } + + #[cfg(test)] + #[cfg_attr( + feature = "tokio-executor", + tokio::test(flavor = "multi_thread", worker_threads = 1) + )] + #[cfg_attr(feature = "async-std-executor", async_std::test)] + async fn test_channel_stream_xtreme() { + use crate::task::FilterEvent; + + use super::ChannelStream; + + let channel_stream = ChannelStream::::new(); + let (mut stream, _) = channel_stream.subscribe(FilterEvent::default()).await; + + let mut streams = Vec::new(); + + for _i in 0..1000 { + let dup_channel_stream = channel_stream.clone(); + let (mut stream, _) = dup_channel_stream.subscribe(FilterEvent::default()).await; + streams.push(stream); + } + + let dup_dup_channel_stream = channel_stream.clone(); + + for _i in 0..1000 { + let mut stream = streams.pop().unwrap(); + async_spawn(async move { + for event in [TestMessage::One, TestMessage::Two, TestMessage::Three] { + for _ in 0..100 { + assert!(stream.next().await.unwrap() == event); + } + } + }); + } + + async_spawn(async move { + for event in [TestMessage::One, TestMessage::Two, TestMessage::Three] { + for _ in 0..100 { + dup_dup_channel_stream.publish(event.clone()).await; + } + } + }); + } +} diff --git a/task/src/task_impls.rs b/task/src/task_impls.rs index 21f0465bc8..2e455dfb2a 100644 --- a/task/src/task_impls.rs +++ b/task/src/task_impls.rs @@ -257,6 +257,11 @@ pub mod test { #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct State {} + #[derive(Clone, Debug, Eq, PartialEq, Hash, Default)] + pub struct CounterState { + num_events_recved: u64 + } + #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub enum Event { Finished, @@ -268,6 +273,9 @@ pub mod test { impl TS for State {} impl PassType for State {} + impl TS for CounterState {} + impl PassType for CounterState {} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum Message { Finished, @@ -279,6 +287,7 @@ pub mod test { // TODO fill in generics for stream pub type AppliedHSTWithEvent = HSTWithEvent, State>; + pub type AppliedHSTWithEventCounterState = HSTWithEvent, CounterState>; pub type AppliedHSTWithMessage = HSTWithMessage, State>; pub type AppliedHSTWithEventMessage = HSTWithEventAndMessage< @@ -327,6 +336,66 @@ pub mod test { )] #[cfg_attr(feature = "async-std-executor", async_std::test)] async fn test_task_with_event_stream() { + + setup_logging(); + let event_stream: event_stream::ChannelStream = event_stream::ChannelStream::new(); + let mut registry = GlobalRegistry::new(); + + let mut task_runner = crate::task_launcher::TaskRunner::default(); + + for i in 0..10000 { + let state = CounterState::default(); + let event_handler = HandleEvent(Arc::new(move |event, mut state: CounterState| { + async move { + + if let Event::Dummy = event { + state.num_events_recved += 1; + } + + + if state.num_events_recved == 100 { + (Some(HotShotTaskCompleted::ShutDown), state) + } else { + (None, state) + } + } + .boxed() + })); + let name = format!("Test Task {:?}", i).to_string(); + let built_task = TaskBuilder::::new(name.clone()) + .register_event_stream(event_stream.clone(), FilterEvent::default()) + .await + .register_registry(&mut registry) + .await + .register_state(state) + .register_event_handler(event_handler); + let id = built_task.get_task_id().unwrap(); + let result = AppliedHSTWithEventCounterState::build(built_task).launch(); + task_runner = task_runner.add_task(id, name, result); + } + + + async_spawn(async move { + for _ in 0..100 { + event_stream.publish(Event::Dummy).await; + } + }); + + let results = task_runner.launch().await; + for result in results { + assert!(result.1 == HotShotTaskCompleted::ShutDown); + } + + } + + + #[cfg(test)] + #[cfg_attr( + feature = "tokio-executor", + tokio::test(flavor = "multi_thread", worker_threads = 2) + )] + #[cfg_attr(feature = "async-std-executor", async_std::test)] + async fn test_task_with_event_stream_xtreme() { setup_logging(); let event_stream: event_stream::ChannelStream = event_stream::ChannelStream::new();