Skip to content

Commit

Permalink
future: add proper synchronization when awaiting the future
Browse files Browse the repository at this point in the history
Previously, there was a scenario when using `current-thread` tokio runtime,
that caused the deadlock.

The scenario described:
```
- we make use of a `current-thread` tokio runtime
  (`Builder::new_current_thread().enable_all().build()`)
- we create a future which sleeps, let's say, 2s and then sets the value
- we call `cass_future_wait_timed` on this future, with a timeout 1s.
  This consumes the handle, starts the future, times out after 1s (there is still a task to be polled until completion - the future from above which sets the value)
- now we call `cass_future_wait` on the same future.
  This waits on cond variable (blocking operation).
  Now current thread is blocked, it is waiting (on a cond variable) for a value to be set.
  But there is noone to poll the future that sets the value,
  since we are using a `current-thread` runtime - deadlock...
```

This commit fixes that, preventing such scenario from happening.

Co-authored-by: Karol Baryła <[email protected]>
  • Loading branch information
muzarski and Lorak-mmk committed Aug 7, 2024
1 parent fbc198d commit 25440dc
Showing 1 changed file with 91 additions and 36 deletions.
127 changes: 91 additions & 36 deletions scylla-rust-wrapper/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::query_result::CassResult;
use crate::types::*;
use crate::uuid::CassUuid;
use crate::RUNTIME;
use futures::future;
use scylla::prepared_statement::PreparedStatement;
use std::future::Future;
use std::mem;
Expand Down Expand Up @@ -62,8 +63,13 @@ pub struct CassFuture {
/// An error that can appear during `cass_future_wait_timed`.
enum FutureError {
TimeoutError,
InvalidDuration,
}

/// The timeout appeared when we tried to await `JoinHandle`.
/// This errors contains the original handle, so it can be awaited later again.
struct JoinHandleTimeout(JoinHandle<()>);

impl CassFuture {
pub fn make_raw(
fut: impl Future<Output = CassFutureResult> + Send + 'static,
Expand Down Expand Up @@ -116,20 +122,31 @@ impl CassFuture {

fn with_waited_state<T>(&self, f: impl FnOnce(&mut CassFutureState) -> T) -> T {
let mut guard = self.state.lock().unwrap();
let handle = guard.join_handle.take();
if let Some(handle) = handle {
mem::drop(guard);
// unwrap: JoinError appears only when future either panic'ed or canceled.
RUNTIME.block_on(handle).unwrap();
guard = self.state.lock().unwrap();
} else {
guard = self
.wait_for_value
.wait_while(guard, |state| state.value.is_none())
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
loop {
let handle = guard.join_handle.take();
if let Some(handle) = handle {
mem::drop(guard);
// unwrap: JoinError appears only when future either panic'ed or canceled.
RUNTIME.block_on(handle).unwrap();
guard = self.state.lock().unwrap();
} else {
guard = self
.wait_for_value
.wait_while(guard, |state| {
state.value.is_none() && state.join_handle.is_none()
})
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
if guard.join_handle.is_some() {
// join_handle was none, and now it isn't - some other thread must
// have timed out and returned the handle. We need to take over
// the work of completing the feature. To do that, we go into
// another iteration so that we land in the branch with block_on.
continue;
}
}
return f(&mut guard);
}
f(&mut guard)
}

fn with_waited_result_timed<T>(
Expand All @@ -146,31 +163,69 @@ impl CassFuture {
timeout_duration: Duration,
) -> Result<T, FutureError> {
let mut guard = self.state.lock().unwrap();
let handle = guard.join_handle.take();
if let Some(handle) = handle {
mem::drop(guard);
// Need to wrap it with async{} block, so the timeout is lazily executed inside the runtime.
// See mention about panics: https://docs.rs/tokio/latest/tokio/time/fn.timeout.html.
let timed = async { tokio::time::timeout(timeout_duration, handle).await };
// unwrap: JoinError appears only when future either panic'ed or canceled.
RUNTIME
.block_on(timed)
.map_err(|_| FutureError::TimeoutError)?
.unwrap();
guard = self.state.lock().unwrap();
} else {
let (guard_result, timeout_result) = self
.wait_for_value
.wait_timeout_while(guard, timeout_duration, |state| state.value.is_none())
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
if timeout_result.timed_out() {
return Err(FutureError::TimeoutError);
let deadline = tokio::time::Instant::now()
.checked_add(timeout_duration)
.ok_or(FutureError::InvalidDuration)?;

loop {
let handle = guard.join_handle.take();
if let Some(handle) = handle {
mem::drop(guard);
// Need to wrap it with async{} block, so the timeout is lazily executed inside the runtime.
// See mention about panics: https://docs.rs/tokio/latest/tokio/time/fn.timeout.html.
let timed = async {
let sleep_future = tokio::time::sleep_until(deadline);
tokio::pin!(sleep_future);
let value = future::select(handle, sleep_future).await;
match value {
future::Either::Left((result, _)) => Ok(result),
future::Either::Right((_, handle)) => Err(JoinHandleTimeout(handle)),
}
};
match RUNTIME.block_on(timed) {
Err(JoinHandleTimeout(returned_handle)) => {
// We timed out. so we can't finish waiting for the future.
// The problem is that if current thread executor is used,
// then no one will run this future - other threads will
// go into the branch with condvar and wait there.
// To fix that:
// - Return the join handle, so that next thread can take it
// - Signal one thread, so that if all other consumers are
// already waiting on condvar, one of them wakes up and
// picks up the work.
guard = self.state.lock().unwrap();
guard.join_handle = Some(returned_handle);
self.wait_for_value.notify_one();
return Err(FutureError::TimeoutError);
}
// unwrap: JoinError appears only when future either panic'ed or canceled.
Ok(result) => result.unwrap(),
};
guard = self.state.lock().unwrap();
} else {
let remaining_timeout = deadline.duration_since(tokio::time::Instant::now());
let (guard_result, timeout_result) = self
.wait_for_value
.wait_timeout_while(guard, remaining_timeout, |state| {
state.value.is_none() && state.join_handle.is_none()
})
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
if timeout_result.timed_out() {
return Err(FutureError::TimeoutError);
}
guard = guard_result;
if guard.join_handle.is_some() {
// join_handle was none, and now it isn't - some other thread must
// have timed out and returned the handle. We need to take over
// the work of completing the feature. To do that, we go into
// another iteration so that we land in the branch with block_on.
continue;
}
}
guard = guard_result;
}

Ok(f(&mut guard))
return Ok(f(&mut guard));
}
}

pub fn set_callback(&self, cb: CassFutureCallback, data: *mut c_void) -> CassError {
Expand Down

0 comments on commit 25440dc

Please sign in to comment.