From c04b922a5f2ea68af462e253a23be8ec1a1ab87e Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Sat, 11 Nov 2023 18:41:04 +0100 Subject: [PATCH] Remove wakers for cancelled tasks When an async function does: ``` async fn f(&mut self) { tokio::select! { v = self.subscriber().next => { /* do something with v */ } _ = std::future::ready() => {}, } } ``` then the future returned by `self.subscriber().next` is cancelled, but the observed object still referenced the waker, preventing the future (consequently, the function's closure) from being dropped even though it won't be scheduled again. This change is twofold: 1. `ObservableState` is now handed `Weak` references, so it does not keep futures alive, and a strong reference is kept by whichever object is held by the future awaiting it (`Subscriber` or `Next`) 2. `ObservableState` garbage-collects weak references from time to time, so its own vector of wakers does not grow unbounded --- eyeball/src/state.rs | 72 ++++++++++++++++++++++++++++++++++----- eyeball/src/subscriber.rs | 42 +++++++++++++++++------ 2 files changed, 96 insertions(+), 18 deletions(-) diff --git a/eyeball/src/state.rs b/eyeball/src/state.rs index ae77a4c..886980b 100644 --- a/eyeball/src/state.rs +++ b/eyeball/src/state.rs @@ -2,8 +2,8 @@ use std::{ hash::{Hash, Hasher}, mem, sync::{ - atomic::{AtomicU64, Ordering}, - RwLock, + atomic::{AtomicU64, AtomicUsize, Ordering}, + RwLock, Weak, }, task::Waker, }; @@ -27,12 +27,29 @@ pub struct ObservableState { /// locked for reading. This way, it is guaranteed that between a subscriber /// reading the value and adding a waker because the value hasn't changed /// yet, no updates to the value could have happened. - wakers: RwLock>, + /// + /// It contains weak references to wakers, so it does not keep references to + /// [`Subscriber`](crate::Subscriber) or [`Next`](crate::subscriber::Next) + /// that would otherwise be dropped and won't be awaited again (eg. as part + /// of a future being cancelled).
+ wakers: RwLock>>, + + /// Whenever wakers.len() reaches this size, iterate through it and remove + /// dangling weak references. + /// This is updated in order to only cleanup every time the list of wakers + /// doubled in size since the previous cleanup, allowing a O(1) amortized + /// time complexity. + next_wakers_cleanup_at_len: AtomicUsize, } impl ObservableState { pub(crate) fn new(value: T) -> Self { - Self { value, version: AtomicU64::new(1), wakers: Default::default() } + Self { + value, + version: AtomicU64::new(1), + wakers: Default::default(), + next_wakers_cleanup_at_len: AtomicUsize::new(64), // Arbitrary constant + } } /// Get a reference to the inner value. @@ -45,8 +62,34 @@ impl ObservableState { self.version.load(Ordering::Acquire) } - pub(crate) fn add_waker(&self, waker: Waker) { - self.wakers.write().unwrap().push(waker); + pub(crate) fn add_waker(&self, waker: Weak) { + // TODO: clean up dangling Weak references in the vector if there are too many + let mut wakers = self.wakers.write().unwrap(); + wakers.push(waker); + if wakers.len() >= self.next_wakers_cleanup_at_len.load(Ordering::Relaxed) { + // Remove dangling Weak references from the vector to free any + // cancelled future that awaited on a `Subscriber` of this + // observable. 
+ let mut new_wakers = Vec::with_capacity(wakers.len()); + for waker in wakers.iter() { + if waker.strong_count() > 0 { + new_wakers.push(waker.clone()); + } + } + if new_wakers.len() == wakers.len() { + #[cfg(feature = "tracing")] + tracing::debug!("No dangling wakers among set of {}", wakers.len()); + } else { + #[cfg(feature = "tracing")] + tracing::debug!( + "Removed {} dangling wakers from a set of {}", + wakers.len() - new_wakers.len(), + wakers.len() + ); + std::mem::swap(&mut *wakers, &mut new_wakers); + } + self.next_wakers_cleanup_at_len.store(wakers.len() * 2, Ordering::Relaxed); + } } pub(crate) fn set(&mut self, value: T) -> T { @@ -111,7 +154,7 @@ fn hash(value: &T) -> u64 { fn wake(wakers: I) where - I: IntoIterator, + I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let iter = wakers.into_iter(); @@ -124,7 +167,20 @@ where tracing::debug!("No wakers"); } } + let mut num_alive_wakers = 0; for waker in iter { - waker.wake(); + if let Some(waker) = waker.upgrade() { + num_alive_wakers += 1; + waker.wake_by_ref(); + } + } + + #[cfg(feature = "tracing")] + { + tracing::debug!("Woke up {num_alive_wakers} waiting subscribers"); + } + #[cfg(not(feature = "tracing"))] + { + let _ = num_alive_wakers; // For Clippy } } diff --git a/eyeball/src/subscriber.rs b/eyeball/src/subscriber.rs index ae1bde1..1513be0 100644 --- a/eyeball/src/subscriber.rs +++ b/eyeball/src/subscriber.rs @@ -7,7 +7,8 @@ use std::{ fmt, future::{poll_fn, Future}, pin::Pin, - task::{Context, Poll}, + sync::{Arc, Weak}, + task::{Context, Poll, Waker}, }; use futures_core::Stream; @@ -22,11 +23,14 @@ pub(crate) mod async_lock; pub struct Subscriber { state: L::SubscriberState, observed_version: u64, + /// Prevent wakers from being dropped from `ObservableState` until this + /// `Subscriber` is dropped + wakers: Vec>, } impl Subscriber { pub(crate) fn new(state: readlock::SharedReadLock>, version: u64) -> Self { - Self { state, observed_version: version } + Self { state, observed_version: 
version, wakers: Vec::new() } } /// Wait for an update and get a clone of the updated value. @@ -87,7 +91,12 @@ impl Subscriber { #[must_use] pub async fn next_ref(&mut self) -> Option> { // Unclear how to implement this as a named future. - poll_fn(|cx| self.poll_next_ref(cx).map(|opt| opt.map(|_| {}))).await?; + let mut waker = None; + poll_fn(|cx| { + waker = Some(Arc::new(cx.waker().clone())); + self.poll_next_ref(Arc::downgrade(waker.as_ref().unwrap())).map(|opt| opt.map(|_| {})) + }) + .await?; Some(self.next_ref_now()) } @@ -120,7 +129,7 @@ impl Subscriber { ObservableReadGuard::new(self.state.lock()) } - fn poll_next_ref(&mut self, cx: &Context<'_>) -> Poll>> { + fn poll_next_ref(&mut self, waker: Weak) -> Poll>> { let state = self.state.lock(); let version = state.version(); if version == 0 { @@ -129,7 +138,7 @@ impl Subscriber { self.observed_version = version; Poll::Ready(Some(ObservableReadGuard::new(state))) } else { - state.add_waker(cx.waker().clone()); + state.add_waker(waker); Poll::Pending } } @@ -160,7 +169,7 @@ impl Subscriber { where L::SubscriberState: Clone, { - Self { state: self.state.clone(), observed_version: 0 } + Self { state: self.state.clone(), observed_version: 0, wakers: Vec::new() } } } @@ -178,7 +187,11 @@ where L::SubscriberState: Clone, { fn clone(&self) -> Self { - Self { state: self.state.clone(), observed_version: self.observed_version } + Self { + state: self.state.clone(), + observed_version: self.observed_version, + wakers: Vec::new(), + } } } @@ -198,7 +211,10 @@ impl Stream for Subscriber { type Item = T; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_next_ref(cx).map(opt_guard_to_owned) + let waker = Arc::new(cx.waker().clone()); + let poll = self.poll_next_ref(Arc::downgrade(&waker)).map(opt_guard_to_owned); + self.wakers.push(waker); + poll } } @@ -207,11 +223,14 @@ impl Stream for Subscriber { #[allow(missing_debug_implementations)] pub struct Next<'a, T, L: Lock = SyncLock> { 
subscriber: &'a mut Subscriber, + /// Prevent wakers from being dropped from `ObservableState` until this + /// `Next` is dropped + wakers: Vec>, } impl<'a, T> Next<'a, T> { fn new(subscriber: &'a mut Subscriber) -> Self { - Self { subscriber } + Self { subscriber, wakers: Vec::new() } } } @@ -219,7 +238,10 @@ impl Future for Next<'_, T> { type Output = Option; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.subscriber.poll_next_ref(cx).map(opt_guard_to_owned) + let waker = Arc::new(cx.waker().clone()); + let poll = self.subscriber.poll_next_ref(Arc::downgrade(&waker)).map(opt_guard_to_owned); + self.wakers.push(waker); + poll } }