From 8d845a7cab2b89969393d79e7386f0833d8eeee4 Mon Sep 17 00:00:00 2001 From: Donghyun Kim Date: Mon, 9 Sep 2024 20:46:31 +0900 Subject: [PATCH] use internal `message_channel` instead of tokio channels --- .../docs/frequently-asked-questions.md | 2 +- documentation/docs/messaging.md | 4 +- documentation/docs/tutorial.md | 4 +- flutter_ffi_plugin/bin/src/message.dart | 26 ++-- .../example/native/hub/Cargo.toml | 9 +- .../native/hub/src/sample_functions.rs | 6 +- .../template/native/hub/Cargo.toml | 4 +- .../template/native/hub/src/lib.rs | 2 +- rust_crate/src/channel.rs | 124 ++++++++++++++++++ rust_crate/src/error.rs | 8 ++ rust_crate/src/lib.rs | 2 + 11 files changed, 161 insertions(+), 30 deletions(-) create mode 100644 rust_crate/src/channel.rs diff --git a/documentation/docs/frequently-asked-questions.md b/documentation/docs/frequently-asked-questions.md index 8c3a6a28..ed2507c9 100644 --- a/documentation/docs/frequently-asked-questions.md +++ b/documentation/docs/frequently-asked-questions.md @@ -236,7 +236,7 @@ onPressed: () async { pub async fn respond() -> Result<()> { use messages::tutorial_resource::*; - let mut receiver = MyUniqueInput::get_dart_signal_receiver()?; + let receiver = MyUniqueInput::get_dart_signal_receiver()?; while let Some(dart_signal) = receiver.recv().await { let my_unique_input = dart_signal.message; MyUniqueOutput { diff --git a/documentation/docs/messaging.md b/documentation/docs/messaging.md index d0ffa260..1f85e31f 100644 --- a/documentation/docs/messaging.md +++ b/documentation/docs/messaging.md @@ -68,7 +68,7 @@ MyDataInput( ... ).sendSignalToRust(); ``` ```rust title="Rust" -let mut receiver = MyDataInput::get_dart_signal_receiver()?; +let receiver = MyDataInput::get_dart_signal_receiver()?; while let Some(dart_signal) = receiver.recv().await { let message: MyDataInput = dart_signal.message; // Custom Rust logic here @@ -88,7 +88,7 @@ MyDataInput( ... ).sendSignalToRust(binary); ``` ```rust title="Rust" -let mut receiver = MyDataInput::get_dart_signal_receiver()?; +let receiver = MyDataInput::get_dart_signal_receiver()?; while let Some(dart_signal) = receiver.recv().await { let message: MyDataInput = dart_signal.message; let binary: Vec = dart_signal.binary; diff --git a/documentation/docs/tutorial.md b/documentation/docs/tutorial.md index 3ccd0618..68d871a8 100644 --- a/documentation/docs/tutorial.md +++ b/documentation/docs/tutorial.md @@ -65,7 +65,7 @@ use rinf::debug_print; pub async fn calculate_precious_data() -> Result<()> { use messages::tutorial_messages::*; - let mut receiver = MyPreciousData::get_dart_signal_receiver()?; // GENERATED + let receiver = MyPreciousData::get_dart_signal_receiver()?; // GENERATED while let Some(dart_signal) = receiver.recv().await { let my_precious_data = dart_signal.message; @@ -222,7 +222,7 @@ pub async fn tell_treasure() -> Result<()> { let mut current_value: i32 = 1; - let mut receiver = MyTreasureInput::get_dart_signal_receiver()?; // GENERATED + let receiver = MyTreasureInput::get_dart_signal_receiver()?; // GENERATED while let Some(_) = receiver.recv().await { MyTreasureOutput { current_value }.send_signal_to_dart(); // GENERATED current_value += 1; diff --git a/flutter_ffi_plugin/bin/src/message.dart b/flutter_ffi_plugin/bin/src/message.dart index 78e4a8f0..7cfbe841 100644 --- a/flutter_ffi_plugin/bin/src/message.dart +++ b/flutter_ffi_plugin/bin/src/message.dart @@ -266,11 +266,13 @@ import 'package:rinf/rinf.dart'; ''' #![allow(unused_imports)] -use crate::tokio; use prost::Message; -use rinf::{debug_print, send_rust_signal, DartSignal, RinfError}; +use rinf::{ + debug_print, message_channel, send_rust_signal, + DartSignal, MessageReceiver, MessageSender, + RinfError, +}; use std::sync::Mutex; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; ''', atFront: true, @@ -288,21 +290,21 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; rustPath, ''' type ${messageName}Cell = Mutex>, - Option>>, + MessageSender>, + Option>>, )>>; pub static ${snakeName.toUpperCase()}_CHANNEL: ${messageName}Cell = Mutex::new(None); impl ${normalizePascal(messageName)} { pub fn get_dart_signal_receiver() - -> Result>, RinfError> + -> Result>, RinfError> { let mut guard = ${snakeName.toUpperCase()}_CHANNEL .lock() .map_err(|_| RinfError::LockMessageChannel)?; if guard.is_none() { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = message_channel(); guard.replace((sender, Some(receiver))); } let (mut sender, mut receiver_option) = guard @@ -313,7 +315,7 @@ impl ${normalizePascal(messageName)} { // which is now closed. if sender.is_closed() { let receiver; - (sender, receiver) = unbounded_channel(); + (sender, receiver) = message_channel(); receiver_option = Some(receiver); } let receiver = receiver_option.ok_or(RinfError::MessageReceiverTaken)?; @@ -421,13 +423,11 @@ impl ${normalizePascal(messageName)} { #![allow(unused_imports)] #![allow(unused_mut)] -use crate::tokio; use prost::Message; -use rinf::{debug_print, DartSignal, RinfError}; +use rinf::{debug_print, message_channel, DartSignal, RinfError}; use std::collections::HashMap; use std::error::Error; use std::sync::OnceLock; -use tokio::sync::mpsc::unbounded_channel; type Handler = dyn Fn(&[u8], &[u8]) -> Result<(), RinfError> + Send + Sync; type DartSignalHandlers = HashMap>; @@ -471,7 +471,7 @@ new_hash_map.insert( .lock() .map_err(|_| RinfError::LockMessageChannel)?; if guard.is_none() { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = message_channel(); guard.replace((sender, Some(receiver))); } let mut pair = guard @@ -481,7 +481,7 @@ new_hash_map.insert( // a sender from the previous run already exists // which is now closed. if pair.0.is_closed() { - let (sender, receiver) = unbounded_channel(); + let (sender, receiver) = message_channel(); guard.replace((sender, Some(receiver))); pair = guard .as_ref() diff --git a/flutter_ffi_plugin/example/native/hub/Cargo.toml b/flutter_ffi_plugin/example/native/hub/Cargo.toml index 35a09f04..a0dc4368 100644 --- a/flutter_ffi_plugin/example/native/hub/Cargo.toml +++ b/flutter_ffi_plugin/example/native/hub/Cargo.toml @@ -14,12 +14,7 @@ crate-type = ["lib", "cdylib", "staticlib"] [dependencies] rinf = "6.15.0" prost = "0.13.0" -tokio = { version = "1", features = ["rt", "sync", "macros", "time"] } -tokio_with_wasm = { version = "0.6.3", features = [ - "rt", - "sync", - "macros", - "time", -] } +tokio = { version = "1", features = ["rt", "sync", "time"] } +tokio_with_wasm = { version = "0.6.3", features = ["rt", "sync", "time"] } wasm-bindgen = "0.2.93" sample_crate = { path = "../sample_crate" } diff --git a/flutter_ffi_plugin/example/native/hub/src/sample_functions.rs b/flutter_ffi_plugin/example/native/hub/src/sample_functions.rs index eaa8f4d7..14ff51c7 100755 --- a/flutter_ffi_plugin/example/native/hub/src/sample_functions.rs +++ b/flutter_ffi_plugin/example/native/hub/src/sample_functions.rs @@ -22,7 +22,7 @@ pub async fn tell_numbers() -> Result<()> { use messages::counter_number::*; // Stream getter is generated from a marked Protobuf message. - let mut receiver = SampleNumberInput::get_dart_signal_receiver()?; + let receiver = SampleNumberInput::get_dart_signal_receiver()?; while let Some(dart_signal) = receiver.recv().await { // Extract values from the message received from Dart. // This message is a type that's declared in its Protobuf file. @@ -154,7 +154,9 @@ pub async fn run_debug_tests() -> Result<()> { tokio::time::sleep(Duration::from_secs(3)).await; debug_print!("Third future finished."); }; - tokio::join!(join_first, join_second, join_third); + join_first.await; + join_second.await; + join_third.await; // Avoid blocking the async event loop by yielding. let mut last_time = sample_crate::get_current_time(); diff --git a/flutter_ffi_plugin/template/native/hub/Cargo.toml b/flutter_ffi_plugin/template/native/hub/Cargo.toml index 95ebfc3d..3cd67d70 100644 --- a/flutter_ffi_plugin/template/native/hub/Cargo.toml +++ b/flutter_ffi_plugin/template/native/hub/Cargo.toml @@ -14,8 +14,8 @@ crate-type = ["lib", "cdylib", "staticlib"] [dependencies] rinf = "6.15.0" prost = "0.12.6" -tokio = { version = "1", features = ["sync", "rt"] } +tokio = { version = "1", features = ["rt"] } # Uncomment below to target the web. -# tokio_with_wasm = { version = "0.6.0", features = ["sync", "rt"] } +# tokio_with_wasm = { version = "0.6.0", features = ["rt"] } # wasm-bindgen = "0.2.92" diff --git a/flutter_ffi_plugin/template/native/hub/src/lib.rs b/flutter_ffi_plugin/template/native/hub/src/lib.rs index 60f4aa22..981f304e 100644 --- a/flutter_ffi_plugin/template/native/hub/src/lib.rs +++ b/flutter_ffi_plugin/template/native/hub/src/lib.rs @@ -24,7 +24,7 @@ async fn communicate() -> Result<()> { // Send signals to Dart like below. SmallNumber { number: 7 }.send_signal_to_dart(); // Get receivers that listen to Dart signals like below. - let mut receiver = SmallText::get_dart_signal_receiver()?; + let receiver = SmallText::get_dart_signal_receiver()?; while let Some(dart_signal) = receiver.recv().await { let message: SmallText = dart_signal.message; rinf::debug_print!("{message:?}"); diff --git a/rust_crate/src/channel.rs b/rust_crate/src/channel.rs new file mode 100644 index 00000000..6ce09129 --- /dev/null +++ b/rust_crate/src/channel.rs @@ -0,0 +1,124 @@ +use crate::error::RinfError; +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; + +pub struct MessageSender { + inner: Arc>>, +} + +pub struct MessageReceiver { + inner: Arc>>, +} + +struct MessageChannel { + queue: VecDeque, // Message queue for storing multiple messages + waker: Option, + sender_dropped: bool, // Track whether the sender has been dropped + receiver_dropped: bool, // Track whether the receiver has been dropped +} + +impl MessageSender { + // Send a message and store it in the queue + pub fn send(&self, msg: T) -> Result<(), RinfError> { + let mut inner = match self.inner.lock() { + Ok(inner) => inner, + Err(_) => return Err(RinfError::BrokenMessageChannel), + }; + + // Return an error if the receiver has been dropped + if inner.receiver_dropped { + return Err(RinfError::ClosedMessageChannel); + } + + // Enqueue the message + inner.queue.push_back(msg); + if let Some(waker) = inner.waker.take() { + waker.wake(); // Wake the receiver if it's waiting + } + Ok(()) + } + + // Check if the receiver is still alive + pub fn is_closed(&self) -> bool { + let inner = self.inner.lock().unwrap(); + inner.receiver_dropped + } +} + +impl Drop for MessageSender { + fn drop(&mut self) { + let mut inner = self.inner.lock().unwrap(); + inner.sender_dropped = true; // Mark that the sender has been dropped + if let Some(waker) = inner.waker.take() { + waker.wake(); // Wake the receiver in case it's waiting + } + } +} + +impl MessageReceiver { + // Receive the next message from the queue asynchronously + pub async fn recv(&self) -> Option { + RecvFuture { + inner: self.inner.clone(), + } + .await + } +} + +impl Drop for MessageReceiver { + fn drop(&mut self) { + let mut inner = self.inner.lock().unwrap(); + inner.receiver_dropped = true; // Mark that the receiver has been dropped + if let Some(waker) = inner.waker.take() { + waker.wake(); // Wake any waiting sender + } + } +} + +// Future implementation for receiving a message +struct RecvFuture { + inner: Arc>>, +} + +impl Future for RecvFuture { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut inner = match self.inner.lock() { + Ok(inner) => inner, + Err(_) => return Poll::Ready(None), // Return None on poisoned mutex + }; + + // Check if there are any messages in the queue + if let Some(msg) = inner.queue.pop_front() { + return Poll::Ready(Some(msg)); // Return the next message + } + + // If no messages and the sender is dropped, return None + if inner.sender_dropped && inner.queue.is_empty() { + Poll::Ready(None) + } else { + inner.waker = Some(cx.waker().clone()); // Set the waker for later notification + Poll::Pending // No message available, wait + } + } +} + +// Create the message channel with a message queue +pub fn message_channel() -> (MessageSender, MessageReceiver) { + let channel = Arc::new(Mutex::new(MessageChannel { + queue: VecDeque::new(), // Initialize an empty message queue + waker: None, + sender_dropped: false, // Initially, the sender is not dropped + receiver_dropped: false, // Initially, the receiver is not dropped + })); + ( + MessageSender { + inner: channel.clone(), + }, + MessageReceiver { inner: channel }, + ) +} diff --git a/rust_crate/src/error.rs b/rust_crate/src/error.rs index c1d99107..e0377050 100644 --- a/rust_crate/src/error.rs +++ b/rust_crate/src/error.rs @@ -7,6 +7,8 @@ pub enum RinfError { NoDartIsolate, BuildRuntime, LockMessageChannel, + BrokenMessageChannel, + ClosedMessageChannel, NoMessageChannel, MessageReceiverTaken, DecodeMessage, @@ -28,6 +30,12 @@ impl fmt::Display for RinfError { RinfError::LockMessageChannel => { write!(f, "Could not acquire the message channel lock.") } + RinfError::BrokenMessageChannel => { + write!(f, "Message channel is broken.",) + } + RinfError::ClosedMessageChannel => { + write!(f, "Message channel is closed.",) + } RinfError::NoMessageChannel => { write!(f, "Message channel was not created.",) } diff --git a/rust_crate/src/lib.rs b/rust_crate/src/lib.rs index fae10f70..92319676 100644 --- a/rust_crate/src/lib.rs +++ b/rust_crate/src/lib.rs @@ -1,3 +1,4 @@ +mod channel; mod error; mod macros; @@ -7,5 +8,6 @@ mod interface_os; #[cfg(target_family = "wasm")] mod interface_web; +pub use channel::*; pub use error::*; pub use interface::*;