Skip to content

Commit

Permalink
use internal message_channel instead of tokio channels
Browse files Browse the repository at this point in the history
  • Loading branch information
temeddix committed Sep 9, 2024
1 parent c709f9e commit 8d845a7
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 30 deletions.
2 changes: 1 addition & 1 deletion documentation/docs/frequently-asked-questions.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ onPressed: () async {
pub async fn respond() -> Result<()> {
use messages::tutorial_resource::*;
let mut receiver = MyUniqueInput::get_dart_signal_receiver()?;
let receiver = MyUniqueInput::get_dart_signal_receiver()?;
while let Some(dart_signal) = receiver.recv().await {
let my_unique_input = dart_signal.message;
MyUniqueOutput {
Expand Down
4 changes: 2 additions & 2 deletions documentation/docs/messaging.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ MyDataInput( ... ).sendSignalToRust();
```

```rust title="Rust"
let mut receiver = MyDataInput::get_dart_signal_receiver()?;
let receiver = MyDataInput::get_dart_signal_receiver()?;
while let Some(dart_signal) = receiver.recv().await {
let message: MyDataInput = dart_signal.message;
// Custom Rust logic here
Expand All @@ -88,7 +88,7 @@ MyDataInput( ... ).sendSignalToRust(binary);
```

```rust title="Rust"
let mut receiver = MyDataInput::get_dart_signal_receiver()?;
let receiver = MyDataInput::get_dart_signal_receiver()?;
while let Some(dart_signal) = receiver.recv().await {
let message: MyDataInput = dart_signal.message;
let binary: Vec<u8> = dart_signal.binary;
Expand Down
4 changes: 2 additions & 2 deletions documentation/docs/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ use rinf::debug_print;
pub async fn calculate_precious_data() -> Result<()> {
use messages::tutorial_messages::*;

let mut receiver = MyPreciousData::get_dart_signal_receiver()?; // GENERATED
let receiver = MyPreciousData::get_dart_signal_receiver()?; // GENERATED
while let Some(dart_signal) = receiver.recv().await {
let my_precious_data = dart_signal.message;

Expand Down Expand Up @@ -222,7 +222,7 @@ pub async fn tell_treasure() -> Result<()> {

let mut current_value: i32 = 1;

let mut receiver = MyTreasureInput::get_dart_signal_receiver()?; // GENERATED
let receiver = MyTreasureInput::get_dart_signal_receiver()?; // GENERATED
while let Some(_) = receiver.recv().await {
MyTreasureOutput { current_value }.send_signal_to_dart(); // GENERATED
current_value += 1;
Expand Down
26 changes: 13 additions & 13 deletions flutter_ffi_plugin/bin/src/message.dart
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,13 @@ import 'package:rinf/rinf.dart';
'''
#![allow(unused_imports)]
use crate::tokio;
use prost::Message;
use rinf::{debug_print, send_rust_signal, DartSignal, RinfError};
use rinf::{
debug_print, message_channel, send_rust_signal,
DartSignal, MessageReceiver, MessageSender,
RinfError,
};
use std::sync::Mutex;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
''',
atFront: true,
Expand All @@ -288,21 +290,21 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
rustPath,
'''
type ${messageName}Cell = Mutex<Option<(
UnboundedSender<DartSignal<${normalizePascal(messageName)}>>,
Option<UnboundedReceiver<DartSignal<${normalizePascal(messageName)}>>>,
MessageSender<DartSignal<${normalizePascal(messageName)}>>,
Option<MessageReceiver<DartSignal<${normalizePascal(messageName)}>>>,
)>>;
pub static ${snakeName.toUpperCase()}_CHANNEL: ${messageName}Cell =
Mutex::new(None);
impl ${normalizePascal(messageName)} {
pub fn get_dart_signal_receiver()
-> Result<UnboundedReceiver<DartSignal<Self>>, RinfError>
-> Result<MessageReceiver<DartSignal<Self>>, RinfError>
{
let mut guard = ${snakeName.toUpperCase()}_CHANNEL
.lock()
.map_err(|_| RinfError::LockMessageChannel)?;
if guard.is_none() {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = message_channel();
guard.replace((sender, Some(receiver)));
}
let (mut sender, mut receiver_option) = guard
Expand All @@ -313,7 +315,7 @@ impl ${normalizePascal(messageName)} {
// which is now closed.
if sender.is_closed() {
let receiver;
(sender, receiver) = unbounded_channel();
(sender, receiver) = message_channel();
receiver_option = Some(receiver);
}
let receiver = receiver_option.ok_or(RinfError::MessageReceiverTaken)?;
Expand Down Expand Up @@ -421,13 +423,11 @@ impl ${normalizePascal(messageName)} {
#![allow(unused_imports)]
#![allow(unused_mut)]
use crate::tokio;
use prost::Message;
use rinf::{debug_print, DartSignal, RinfError};
use rinf::{debug_print, message_channel, DartSignal, RinfError};
use std::collections::HashMap;
use std::error::Error;
use std::sync::OnceLock;
use tokio::sync::mpsc::unbounded_channel;
type Handler = dyn Fn(&[u8], &[u8]) -> Result<(), RinfError> + Send + Sync;
type DartSignalHandlers = HashMap<i32, Box<Handler>>;
Expand Down Expand Up @@ -471,7 +471,7 @@ new_hash_map.insert(
.lock()
.map_err(|_| RinfError::LockMessageChannel)?;
if guard.is_none() {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = message_channel();
guard.replace((sender, Some(receiver)));
}
let mut pair = guard
Expand All @@ -481,7 +481,7 @@ new_hash_map.insert(
// a sender from the previous run already exists
// which is now closed.
if pair.0.is_closed() {
let (sender, receiver) = unbounded_channel();
let (sender, receiver) = message_channel();
guard.replace((sender, Some(receiver)));
pair = guard
.as_ref()
Expand Down
9 changes: 2 additions & 7 deletions flutter_ffi_plugin/example/native/hub/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ crate-type = ["lib", "cdylib", "staticlib"]
[dependencies]
rinf = "6.15.0"
prost = "0.13.0"
tokio = { version = "1", features = ["rt", "sync", "macros", "time"] }
tokio_with_wasm = { version = "0.6.3", features = [
"rt",
"sync",
"macros",
"time",
] }
tokio = { version = "1", features = ["rt", "sync", "time"] }
tokio_with_wasm = { version = "0.6.3", features = ["rt", "sync", "time"] }
wasm-bindgen = "0.2.93"
sample_crate = { path = "../sample_crate" }
6 changes: 4 additions & 2 deletions flutter_ffi_plugin/example/native/hub/src/sample_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub async fn tell_numbers() -> Result<()> {
use messages::counter_number::*;

// Stream getter is generated from a marked Protobuf message.
let mut receiver = SampleNumberInput::get_dart_signal_receiver()?;
let receiver = SampleNumberInput::get_dart_signal_receiver()?;
while let Some(dart_signal) = receiver.recv().await {
// Extract values from the message received from Dart.
// This message is a type that's declared in its Protobuf file.
Expand Down Expand Up @@ -154,7 +154,9 @@ pub async fn run_debug_tests() -> Result<()> {
tokio::time::sleep(Duration::from_secs(3)).await;
debug_print!("Third future finished.");
};
tokio::join!(join_first, join_second, join_third);
join_first.await;
join_second.await;
join_third.await;

// Avoid blocking the async event loop by yielding.
let mut last_time = sample_crate::get_current_time();
Expand Down
4 changes: 2 additions & 2 deletions flutter_ffi_plugin/template/native/hub/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ crate-type = ["lib", "cdylib", "staticlib"]
[dependencies]
rinf = "6.15.0"
prost = "0.12.6"
tokio = { version = "1", features = ["sync", "rt"] }
tokio = { version = "1", features = ["rt"] }

# Uncomment below to target the web.
# tokio_with_wasm = { version = "0.6.0", features = ["sync", "rt"] }
# tokio_with_wasm = { version = "0.6.0", features = ["rt"] }
# wasm-bindgen = "0.2.92"
2 changes: 1 addition & 1 deletion flutter_ffi_plugin/template/native/hub/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async fn communicate() -> Result<()> {
// Send signals to Dart like below.
SmallNumber { number: 7 }.send_signal_to_dart();
// Get receivers that listen to Dart signals like below.
let mut receiver = SmallText::get_dart_signal_receiver()?;
let receiver = SmallText::get_dart_signal_receiver()?;
while let Some(dart_signal) = receiver.recv().await {
let message: SmallText = dart_signal.message;
rinf::debug_print!("{message:?}");
Expand Down
124 changes: 124 additions & 0 deletions rust_crate/src/channel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use crate::error::RinfError;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

pub struct MessageSender<T> {
inner: Arc<Mutex<MessageChannel<T>>>,
}

pub struct MessageReceiver<T> {
inner: Arc<Mutex<MessageChannel<T>>>,
}

struct MessageChannel<T> {
queue: VecDeque<T>, // Message queue for storing multiple messages
waker: Option<Waker>,
sender_dropped: bool, // Track whether the sender has been dropped
receiver_dropped: bool, // Track whether the receiver has been dropped
}

impl<T> MessageSender<T> {
// Send a message and store it in the queue
pub fn send(&self, msg: T) -> Result<(), RinfError> {
let mut inner = match self.inner.lock() {
Ok(inner) => inner,
Err(_) => return Err(RinfError::BrokenMessageChannel),
};

// Return an error if the receiver has been dropped
if inner.receiver_dropped {
return Err(RinfError::ClosedMessageChannel);
}

// Enqueue the message
inner.queue.push_back(msg);
if let Some(waker) = inner.waker.take() {
waker.wake(); // Wake the receiver if it's waiting
}
Ok(())
}

// Check if the receiver is still alive
pub fn is_closed(&self) -> bool {
let inner = self.inner.lock().unwrap();
inner.receiver_dropped
}
}

impl<T> Drop for MessageSender<T> {
fn drop(&mut self) {
let mut inner = self.inner.lock().unwrap();
inner.sender_dropped = true; // Mark that the sender has been dropped
if let Some(waker) = inner.waker.take() {
waker.wake(); // Wake the receiver in case it's waiting
}
}
}

impl<T> MessageReceiver<T> {
// Receive the next message from the queue asynchronously
pub async fn recv(&self) -> Option<T> {
RecvFuture {
inner: self.inner.clone(),
}
.await
}
}

impl<T> Drop for MessageReceiver<T> {
fn drop(&mut self) {
let mut inner = self.inner.lock().unwrap();
inner.receiver_dropped = true; // Mark that the receiver has been dropped
if let Some(waker) = inner.waker.take() {
waker.wake(); // Wake any waiting sender
}
}
}

// Future implementation for receiving a message
struct RecvFuture<T> {
inner: Arc<Mutex<MessageChannel<T>>>,
}

impl<T> Future for RecvFuture<T> {
type Output = Option<T>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut inner = match self.inner.lock() {
Ok(inner) => inner,
Err(_) => return Poll::Ready(None), // Return None on poisoned mutex
};

// Check if there are any messages in the queue
if let Some(msg) = inner.queue.pop_front() {
return Poll::Ready(Some(msg)); // Return the next message
}

// If no messages and the sender is dropped, return None
if inner.sender_dropped && inner.queue.is_empty() {
Poll::Ready(None)
} else {
inner.waker = Some(cx.waker().clone()); // Set the waker for later notification
Poll::Pending // No message available, wait
}
}
}

// Create the message channel with a message queue
pub fn message_channel<T>() -> (MessageSender<T>, MessageReceiver<T>) {
let channel = Arc::new(Mutex::new(MessageChannel {
queue: VecDeque::new(), // Initialize an empty message queue
waker: None,
sender_dropped: false, // Initially, the sender is not dropped
receiver_dropped: false, // Initially, the receiver is not dropped
}));
(
MessageSender {
inner: channel.clone(),
},
MessageReceiver { inner: channel },
)
}
8 changes: 8 additions & 0 deletions rust_crate/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub enum RinfError {
NoDartIsolate,
BuildRuntime,
LockMessageChannel,
BrokenMessageChannel,
ClosedMessageChannel,
NoMessageChannel,
MessageReceiverTaken,
DecodeMessage,
Expand All @@ -28,6 +30,12 @@ impl fmt::Display for RinfError {
RinfError::LockMessageChannel => {
write!(f, "Could not acquire the message channel lock.")
}
RinfError::BrokenMessageChannel => {
write!(f, "Message channel is broken.",)
}
RinfError::ClosedMessageChannel => {
write!(f, "Message channel is closed.",)
}
RinfError::NoMessageChannel => {
write!(f, "Message channel was not created.",)
}
Expand Down
2 changes: 2 additions & 0 deletions rust_crate/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod channel;
mod error;
mod macros;

Expand All @@ -7,5 +8,6 @@ mod interface_os;
#[cfg(target_family = "wasm")]
mod interface_web;

pub use channel::*;
pub use error::*;
pub use interface::*;

0 comments on commit 8d845a7

Please sign in to comment.