diff --git a/src/lib.rs b/src/lib.rs index 5e94e2f..ac0188b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,6 +127,14 @@ where self.inner.load(order).is_null() } + #[inline] + fn inner_into_raw(val: Option

) -> *mut () { + match val { + Some(val) => val.into_raw(), + None => ptr::null_mut(), + } + } + #[inline] unsafe fn inner_from_raw(ptr: *mut ()) -> Option

{ if !ptr.is_null() { @@ -137,6 +145,100 @@ where } } +impl Atom

+where + P: IntoRawPtr + FromRawPtr + Deref, +{ + /// Stores `new` in the Atom if `current` has the same raw pointer + /// representation as the currently stored value. + /// + /// On success, the Atom's previous value is returned. On failure, `new` is + /// returned together with a raw pointer to the Atom's current unchanged + /// value, which is **not safe to dereference**, especially if the Atom is + /// accessed from multiple threads. + /// + /// `compare_and_swap` also takes an `Ordering` argument which describes + /// the memory ordering of this operation. + pub fn compare_and_swap( + &self, + current: Option<&P>, + new: Option

, + order: Ordering, + ) -> Result, (Option

, *mut P)> { + let pcurrent = Self::inner_as_ptr(current); + let pnew = Self::inner_into_raw(new); + let pprev = self.inner.compare_and_swap(pcurrent, pnew, order); + if pprev == pcurrent { + Ok(unsafe { Self::inner_from_raw(pprev) }) + } else { + Err((unsafe { Self::inner_from_raw(pnew) }, pprev as *mut P)) + } + } + + /// Stores a value into the pointer if the current value is the same as the + /// `current` value. + /// + /// The return value is a result indicating whether the new value was + /// written and containing the previous value. On success this value is + /// guaranteed to be equal to `current`. + /// + /// `compare_exchange` takes two `Ordering` arguments to describe the + /// memory ordering of this operation. The first describes the required + /// ordering if the operation succeeds while the second describes the + /// required ordering when the operation fails. The failure ordering can't + /// be `Release` or `AcqRel` and must be equivalent or weaker than the + /// success ordering. + pub fn compare_exchange( + &self, + current: Option<&P>, + new: Option

, + success: Ordering, + failure: Ordering, + ) -> Result, (Option

, *mut P)> { + let pnew = Self::inner_into_raw(new); + self.inner + .compare_exchange(Self::inner_as_ptr(current), pnew, success, failure) + .map(|pprev| unsafe { Self::inner_from_raw(pprev) }) + .map_err(|pprev| (unsafe { Self::inner_from_raw(pnew) }, pprev as *mut P)) + } + + /// Stores a value into the pointer if the current value is the same as the + /// `current` value. + /// + /// Unlike `compare_exchange`, this function is allowed to spuriously fail + /// even when the comparison succeeds, which can result in more efficient + /// code on some platforms. The return value is a result indicating whether + /// the new value was written and containing the previous value. + /// + /// `compare_exchange_weak` takes two `Ordering` arguments to describe the + /// memory ordering of this operation. The first describes the required + /// ordering if the operation succeeds while the second describes the + /// required ordering when the operation fails. The failure ordering can't + /// be `Release` or `AcqRel` and must be equivalent or weaker than the + /// success ordering. + pub fn compare_exchange_weak( + &self, + current: Option<&P>, + new: Option

, + success: Ordering, + failure: Ordering, + ) -> Result, (Option

, *mut P)> { + let pnew = Self::inner_into_raw(new); + self.inner + .compare_exchange_weak(Self::inner_as_ptr(current), pnew, success, failure) + .map(|pprev| unsafe { Self::inner_from_raw(pprev) }) + .map_err(|pprev| (unsafe { Self::inner_from_raw(pnew) }, pprev as *mut P)) + } + + #[inline] + fn inner_as_ptr(val: Option<&P>) -> *mut () { + match val { + Some(val) => &**val as *const _ as *mut (), + None => ptr::null_mut(), + } + } +} + impl

Drop for Atom

where P: IntoRawPtr + FromRawPtr, diff --git a/tests/atom.rs b/tests/atom.rs index 22c3073..17f7e08 100644 --- a/tests/atom.rs +++ b/tests/atom.rs @@ -15,6 +15,7 @@ extern crate atom; use atom::*; +use std::collections::HashSet; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::sync::*; @@ -45,6 +46,120 @@ fn set_if_none() { ); } +#[test] +fn compare_and_swap_basics() { + cas_test_basics_helper(|a, cas_val, next_val| { + a.compare_and_swap(cas_val, next_val, Ordering::SeqCst) + }); +} + +#[test] +fn compare_exchange_basics() { + cas_test_basics_helper(|a, cas_val, next_val| { + a.compare_exchange(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst) + }); +} + +#[test] +fn compare_exchange_weak_basics() { + cas_test_basics_helper(|a, cas_val, next_val| { + a.compare_exchange_weak(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst) + }); +} + +#[test] +fn compare_and_swap_threads() { + cas_test_threads_helper(|a, cas_val, next_val| { + a.compare_and_swap(cas_val, next_val, Ordering::SeqCst) + }); +} + +#[test] +fn compare_exchange_threads() { + cas_test_threads_helper(|a, cas_val, next_val| { + a.compare_exchange(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst) + }); +} + +#[test] +fn compare_exchange_weak_threads() { + cas_test_threads_helper(|a, cas_val, next_val| { + a.compare_exchange_weak(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst) + }); +} + +type TestCASFn = fn(&Atom>, Option<&Arc>, Option>) + -> Result>, (Option>, *mut Arc)>; + +fn cas_test_basics_helper(cas: TestCASFn) { + let cur_val = Arc::new("123".to_owned()); + let mut next_val = Arc::new("456".to_owned()); + let other_val = Arc::new("1927447".to_owned()); + + let a = Atom::new(cur_val.clone()); + + let pcur = IntoRawPtr::into_raw(cur_val.clone()); + let pnext = IntoRawPtr::into_raw(next_val.clone()); + + for attempt in vec![None, Some(&other_val), Some(&Arc::new("wow".to_owned()))] { + let res = cas(&a, attempt, Some(next_val.clone())).unwrap_err(); + next_val = res.0.unwrap(); + assert_eq!(res.1, pcur as *mut _); + } + + let res = cas(&a, Some(&cur_val), Some(next_val.clone())); + assert_eq!(res, Ok(Some(cur_val))); + + for attempt in vec![None, Some(&other_val), Some(&Arc::new("wow".to_owned()))] { + let res = cas(&a, attempt, None).unwrap_err(); + assert_eq!(res, (None, pnext as *mut _)); + } +} + +fn cas_test_threads_helper(cas: TestCASFn) { + let cur_val = Arc::new("current".to_owned()); + let next_val = Arc::new("next".to_owned()); + let other_val = Arc::new("other".to_owned()); + + let a = Arc::new(Atom::new(cur_val.clone())); + + let num_threads = 10; + let cas_thread = num_threads / 2; + let pprevs: Vec> = (0..num_threads) + .map(|i| { + let a = a.clone(); + let cur_val = cur_val.clone(); + let next_val = next_val.clone(); + let other_val = other_val.clone(); + thread::spawn(move || { + let cas_val = Some(if i == cas_thread { + &cur_val + } else { + &other_val + }); + match cas(&a, cas_val, Some(next_val.clone())) { + Ok(prev) => { + let prev = prev.unwrap(); + assert!(Arc::ptr_eq(&prev, &cur_val)); + assert!(!Arc::ptr_eq(&prev, &next_val)); + Ok(prev.into_raw() as usize) + } + Err((_, pprev)) => Err(pprev as usize), + } + }) + }) + .map(|handle| handle.join().unwrap()) + .collect(); + assert_eq!(pprevs.iter().filter(|pprev| pprev.is_ok()).count(), 1); + let uniq_pprevs: HashSet<_> = pprevs + .into_iter() + .map(|pprev| pprev.unwrap_or_else(|pprev| pprev) as *mut _) + .collect(); + assert!(uniq_pprevs.contains(&cur_val.into_raw())); + assert!(!uniq_pprevs.contains(&other_val.into_raw())); + assert_eq!(a.take(Ordering::Relaxed), Some(next_val)); +} + #[derive(Clone)] struct Canary(Arc);