Skip to content

Commit

Permalink
Use AtomicU64 instead of u64 for the position field in Receiver
Browse files Browse the repository at this point in the history
The `recv()`, `recv_direct()`, `recv_blocking()`, and `try_recv()`
methods currently require `&mut self` to modify the value of the
`position` field. By using AtomicU64 for the `position` field eliminates
the need for mutability.

Fixes issue #66
  • Loading branch information
hozan23 committed Jul 4, 2024
1 parent de420a3 commit 37d644a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 55 deletions.
94 changes: 49 additions & 45 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ use std::fmt;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
};
use std::task::{Context, Poll};

use event_listener::{Event, EventListener};
Expand All @@ -135,8 +138,8 @@ use pin_project_lite::pin_project;
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, TryRecvError, TrySendError};
///
/// let (s, mut r1) = broadcast(1);
/// let mut r2 = r1.clone();
/// let (s, r1) = broadcast(1);
/// let r2 = r1.clone();
///
/// assert_eq!(s.broadcast(10).await, Ok(None));
/// assert_eq!(s.try_broadcast(20), Err(TrySendError::Full(20)));
Expand Down Expand Up @@ -169,7 +172,7 @@ pub fn broadcast<T>(cap: usize) -> (Sender<T>, Receiver<T>) {
};
let r = Receiver {
inner,
pos: 0,
pos: AtomicU64::new(0),
listener: None,
};

Expand Down Expand Up @@ -203,21 +206,22 @@ impl<T> Inner<T> {
/// Try receiving at the given position, returning either the element or a reference to it.
///
/// Result is used here instead of Cow because we don't have a Clone bound on T.
fn try_recv_at(&mut self, pos: &mut u64) -> Result<Result<T, &T>, TryRecvError> {
let i = match pos.checked_sub(self.head_pos) {
fn try_recv_at(&mut self, pos: &AtomicU64) -> Result<Result<T, &T>, TryRecvError> {
let i = pos.load(Ordering::Acquire);
let i = match i.checked_sub(self.head_pos) {
Some(i) => i
.try_into()
.expect("Head position more than usize::MAX behind a receiver"),
None => {
let count = self.head_pos - *pos;
*pos = self.head_pos;
let count = self.head_pos - pos.load(Ordering::Relaxed);
pos.store(self.head_pos, Ordering::Release);
return Err(TryRecvError::Overflowed(count));
}
};

let last_waiter;
if let Some((_elt, waiters)) = self.queue.get_mut(i) {
*pos += 1;
pos.fetch_add(1, Ordering::Release);
*waiters -= 1;
last_waiter = *waiters == 0;
} else {
Expand Down Expand Up @@ -331,7 +335,7 @@ impl<T> Sender<T> {
/// ```
/// use async_broadcast::{broadcast, TrySendError, TryRecvError};
///
/// let (mut s, mut r) = broadcast::<i32>(3);
/// let (mut s, r) = broadcast::<i32>(3);
/// assert_eq!(s.capacity(), 3);
/// s.try_broadcast(1).unwrap();
/// s.try_broadcast(2).unwrap();
Expand Down Expand Up @@ -378,7 +382,7 @@ impl<T> Sender<T> {
/// ```
/// use async_broadcast::{broadcast, TrySendError, TryRecvError};
///
/// let (mut s, mut r) = broadcast::<i32>(2);
/// let (mut s, r) = broadcast::<i32>(2);
/// s.try_broadcast(1).unwrap();
/// s.try_broadcast(2).unwrap();
/// assert_eq!(s.try_broadcast(3), Err(TrySendError::Full(3)));
Expand Down Expand Up @@ -423,7 +427,7 @@ impl<T> Sender<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::broadcast;
///
/// let (mut s, mut r) = broadcast::<i32>(2);
/// let (mut s, r) = broadcast::<i32>(2);
/// s.broadcast(1).await.unwrap();
///
/// let _ = r.deactivate();
Expand All @@ -447,7 +451,7 @@ impl<T> Sender<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r) = broadcast(1);
/// let (s, r) = broadcast(1);
/// s.broadcast(1).await.unwrap();
/// assert!(s.close());
///
Expand Down Expand Up @@ -611,11 +615,11 @@ impl<T> Sender<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r1) = broadcast(2);
/// let (s, r1) = broadcast(2);
///
/// assert_eq!(s.broadcast(1).await, Ok(None));
///
/// let mut r2 = s.new_receiver();
/// let r2 = s.new_receiver();
///
/// assert_eq!(s.broadcast(2).await, Ok(None));
/// drop(s);
Expand All @@ -633,7 +637,7 @@ impl<T> Sender<T> {
inner.receiver_count += 1;
Receiver {
inner: self.inner.clone(),
pos: inner.head_pos + inner.queue.len() as u64,
pos: AtomicU64::new(inner.head_pos + inner.queue.len() as u64),
listener: None,
}
}
Expand Down Expand Up @@ -816,7 +820,7 @@ impl<T> Clone for Sender<T> {
#[derive(Debug)]
pub struct Receiver<T> {
inner: Arc<Mutex<Inner<T>>>,
pos: u64,
pos: AtomicU64,

/// Listens for a send or close event to unblock this stream.
listener: Option<EventListener>,
Expand Down Expand Up @@ -964,7 +968,7 @@ impl<T> Receiver<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r) = broadcast(1);
/// let (s, r) = broadcast(1);
/// s.broadcast(1).await.unwrap();
/// assert!(s.close());
///
Expand Down Expand Up @@ -1138,7 +1142,7 @@ impl<T> Receiver<T> {
/// let inactive = r.deactivate();
/// assert_eq!(s.try_broadcast(10), Err(TrySendError::Inactive(10)));
///
/// let mut r = inactive.activate();
/// let r = inactive.activate();
/// assert_eq!(s.broadcast(10).await, Ok(None));
/// assert_eq!(r.recv().await, Ok(10));
/// # });
Expand Down Expand Up @@ -1175,8 +1179,8 @@ impl<T: Clone> Receiver<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r1) = broadcast(1);
/// let mut r2 = r1.clone();
/// let (s, r1) = broadcast(1);
/// let r2 = r1.clone();
///
/// assert_eq!(s.broadcast(1).await, Ok(None));
/// drop(s);
Expand All @@ -1187,7 +1191,7 @@ impl<T: Clone> Receiver<T> {
/// assert_eq!(r2.recv().await, Err(RecvError::Closed));
/// # });
/// ```
pub fn recv(&mut self) -> Pin<Box<Recv<'_, T>>> {
pub fn recv(&self) -> Pin<Box<Recv<'_, T>>> {
Box::pin(self.recv_direct())
}

Expand All @@ -1203,8 +1207,8 @@ impl<T: Clone> Receiver<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r1) = broadcast(1);
/// let mut r2 = r1.clone();
/// let (s, r1) = broadcast(1);
/// let r2 = r1.clone();
///
/// assert_eq!(s.broadcast(1).await, Ok(None));
/// drop(s);
Expand All @@ -1215,7 +1219,7 @@ impl<T: Clone> Receiver<T> {
/// assert_eq!(r2.recv_direct().await, Err(RecvError::Closed));
/// # });
/// ```
pub fn recv_direct(&mut self) -> Recv<'_, T> {
pub fn recv_direct(&self) -> Recv<'_, T> {
Recv::_new(RecvInner {
receiver: self,
listener: None,
Expand All @@ -1237,10 +1241,9 @@ impl<T: Clone> Receiver<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, TryRecvError};
///
/// let (s, mut r1) = broadcast(1);
/// let mut r2 = r1.clone();
/// let (s, r1) = broadcast(1);
/// let r2 = r1.clone();
/// assert_eq!(s.broadcast(1).await, Ok(None));
///
/// assert_eq!(r1.try_recv(), Ok(1));
/// assert_eq!(r1.try_recv(), Err(TryRecvError::Empty));
/// assert_eq!(r2.try_recv(), Ok(1));
Expand All @@ -1251,11 +1254,11 @@ impl<T: Clone> Receiver<T> {
/// assert_eq!(r2.try_recv(), Err(TryRecvError::Closed));
/// # });
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
pub fn try_recv(&self) -> Result<T, TryRecvError> {
self.inner
.lock()
.unwrap()
.try_recv_at(&mut self.pos)
.try_recv_at(&self.pos)
.map(|cow| cow.unwrap_or_else(T::clone))
}

Expand Down Expand Up @@ -1284,7 +1287,7 @@ impl<T: Clone> Receiver<T> {
/// ```
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r) = broadcast(1);
/// let (s, r) = broadcast(1);
///
/// assert_eq!(s.broadcast_blocking(1), Ok(None));
/// drop(s);
Expand All @@ -1293,7 +1296,7 @@ impl<T: Clone> Receiver<T> {
/// assert_eq!(r.recv_blocking(), Err(RecvError::Closed));
/// ```
#[cfg(not(target_family = "wasm"))]
pub fn recv_blocking(&mut self) -> Result<T, RecvError> {
pub fn recv_blocking(&self) -> Result<T, RecvError> {
self.recv_direct().wait()
}

Expand All @@ -1307,7 +1310,7 @@ impl<T: Clone> Receiver<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s1, mut r) = broadcast(2);
/// let (s1, r) = broadcast(2);
///
/// assert_eq!(s1.broadcast(1).await, Ok(None));
///
Expand Down Expand Up @@ -1341,11 +1344,11 @@ impl<T: Clone> Receiver<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r1) = broadcast(2);
/// let (s, r1) = broadcast(2);
///
/// assert_eq!(s.broadcast(1).await, Ok(None));
///
/// let mut r2 = r1.new_receiver();
/// let r2 = r1.new_receiver();
///
/// assert_eq!(s.broadcast(2).await, Ok(None));
/// drop(s);
Expand All @@ -1363,7 +1366,7 @@ impl<T: Clone> Receiver<T> {
inner.receiver_count += 1;
Receiver {
inner: self.inner.clone(),
pos: inner.head_pos + inner.queue.len() as u64,
pos: AtomicU64::new(inner.head_pos + inner.queue.len() as u64),
listener: None,
}
}
Expand Down Expand Up @@ -1458,7 +1461,7 @@ impl<T> Drop for Receiver<T> {

// Remove ourself from each item's counter
loop {
match inner.try_recv_at(&mut self.pos) {
match inner.try_recv_at(&self.pos) {
Ok(_) => continue,
Err(TryRecvError::Overflowed(_)) => continue,
Err(TryRecvError::Closed) => break,
Expand All @@ -1481,12 +1484,12 @@ impl<T> Clone for Receiver<T> {
/// # futures_lite::future::block_on(async {
/// use async_broadcast::{broadcast, RecvError};
///
/// let (s, mut r1) = broadcast(1);
/// let (s, r1) = broadcast(1);
///
/// assert_eq!(s.broadcast(1).await, Ok(None));
/// drop(s);
///
/// let mut r2 = r1.clone();
/// let r2 = r1.clone();
///
/// assert_eq!(r1.recv().await, Ok(1));
/// assert_eq!(r1.recv().await, Err(RecvError::Closed));
Expand All @@ -1498,13 +1501,14 @@ impl<T> Clone for Receiver<T> {
let mut inner = self.inner.lock().unwrap();
inner.receiver_count += 1;
// increment the waiter count on all items not yet received by this object
let n = self.pos.saturating_sub(inner.head_pos) as usize;
let pos = self.pos.load(Ordering::Relaxed);
let n = pos.saturating_sub(inner.head_pos) as usize;
for (_elt, waiters) in inner.queue.iter_mut().skip(n) {
*waiters += 1;
}
Receiver {
inner: self.inner.clone(),
pos: self.pos,
pos: AtomicU64::new(pos),
listener: None,
}
}
Expand Down Expand Up @@ -1798,7 +1802,7 @@ easy_wrapper! {
pin_project! {
#[derive(Debug)]
struct RecvInner<'a, T> {
receiver: &'a mut Receiver<T>,
receiver: &'a Receiver<T>,
listener: Option<EventListener>,

// Keeping this type `!Unpin` enables future optimizations.
Expand Down Expand Up @@ -1870,7 +1874,7 @@ impl<T> InactiveReceiver<T> {
/// let inactive = r.deactivate();
/// assert_eq!(s.try_broadcast(10), Err(TrySendError::Inactive(10)));
///
/// let mut r = inactive.activate();
/// let r = inactive.activate();
/// assert_eq!(s.try_broadcast(10), Ok(None));
/// assert_eq!(r.try_recv(), Ok(10));
/// ```
Expand All @@ -1889,7 +1893,7 @@ impl<T> InactiveReceiver<T> {
/// let inactive = r.deactivate();
/// assert_eq!(s.try_broadcast(10), Err(TrySendError::Inactive(10)));
///
/// let mut r = inactive.activate_cloned();
/// let r = inactive.activate_cloned();
/// assert_eq!(s.try_broadcast(10), Ok(None));
/// assert_eq!(r.try_recv(), Ok(10));
/// ```
Expand All @@ -1905,7 +1909,7 @@ impl<T> InactiveReceiver<T> {

Receiver {
inner: self.inner.clone(),
pos: inner.head_pos + inner.queue.len() as u64,
pos: AtomicU64::new(inner.head_pos + inner.queue.len() as u64),
listener: None,
}
}
Expand Down
Loading

0 comments on commit 37d644a

Please sign in to comment.