diff --git a/src/lib.rs b/src/lib.rs
index 5e94e2f..ac0188b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -127,6 +127,14 @@ where
self.inner.load(order).is_null()
}
+ #[inline]
+ fn inner_into_raw(val: Option<P>) -> *mut () {
+ match val {
+ Some(val) => val.into_raw(),
+ None => ptr::null_mut(),
+ }
+ }
+
#[inline]
unsafe fn inner_from_raw(ptr: *mut ()) -> Option<P> {
if !ptr.is_null() {
@@ -137,6 +145,100 @@ where
}
}
+impl<P> Atom<P>
+where
+ P: IntoRawPtr + FromRawPtr + Deref,
+{
+ /// Stores `new` in the Atom if `current` has the same raw pointer
+ /// representation as the currently stored value.
+ ///
+ /// On success, the Atom's previous value is returned. On failure, `new` is
+ /// returned together with a raw pointer to the Atom's current unchanged
+ /// value, which is **not safe to dereference**, especially if the Atom is
+ /// accessed from multiple threads.
+ ///
+ /// `compare_and_swap` also takes an `Ordering` argument which describes
+ /// the memory ordering of this operation.
+ pub fn compare_and_swap(
+ &self,
+ current: Option<&P>,
+ new: Option<P>,
+ order: Ordering,
+ ) -> Result<Option<P>, (Option<P>, *mut P)> {
+ let pcurrent = Self::inner_as_ptr(current);
+ let pnew = Self::inner_into_raw(new);
+ let pprev = self.inner.compare_and_swap(pcurrent, pnew, order);
+ if pprev == pcurrent {
+ Ok(unsafe { Self::inner_from_raw(pprev) })
+ } else {
+ Err((unsafe { Self::inner_from_raw(pnew) }, pprev as *mut P))
+ }
+ }
+
+ /// Stores a value into the pointer if the current value is the same as the
+ /// `current` value.
+ ///
+ /// The return value is a result indicating whether the new value was
+ /// written and containing the previous value. On success this value is
+ /// guaranteed to be equal to `current`.
+ ///
+ /// `compare_exchange` takes two `Ordering` arguments to describe the
+ /// memory ordering of this operation. The first describes the required
+ /// ordering if the operation succeeds while the second describes the
+ /// required ordering when the operation fails. The failure ordering can't
+ /// be `Release` or `AcqRel` and must be equivalent or weaker than the
+ /// success ordering.
+ pub fn compare_exchange(
+ &self,
+ current: Option<&P>,
+ new: Option<P>,
+ success: Ordering,
+ failure: Ordering,
+ ) -> Result<Option<P>, (Option<P>, *mut P)> {
+ let pnew = Self::inner_into_raw(new);
+ self.inner
+ .compare_exchange(Self::inner_as_ptr(current), pnew, success, failure)
+ .map(|pprev| unsafe { Self::inner_from_raw(pprev) })
+ .map_err(|pprev| (unsafe { Self::inner_from_raw(pnew) }, pprev as *mut P))
+ }
+
+ /// Stores a value into the pointer if the current value is the same as the
+ /// `current` value.
+ ///
+ /// Unlike `compare_exchange`, this function is allowed to spuriously fail
+ /// even when the comparison succeeds, which can result in more efficient
+ /// code on some platforms. The return value is a result indicating whether
+ /// the new value was written and containing the previous value.
+ ///
+ /// `compare_exchange_weak` takes two `Ordering` arguments to describe the
+ /// memory ordering of this operation. The first describes the required
+ /// ordering if the operation succeeds while the second describes the
+ /// required ordering when the operation fails. The failure ordering can't
+ /// be `Release` or `AcqRel` and must be equivalent or weaker than the
+ /// success ordering.
+ pub fn compare_exchange_weak(
+ &self,
+ current: Option<&P>,
+ new: Option<P>,
+ success: Ordering,
+ failure: Ordering,
+ ) -> Result<Option<P>, (Option<P>, *mut P)> {
+ let pnew = Self::inner_into_raw(new);
+ self.inner
+ .compare_exchange_weak(Self::inner_as_ptr(current), pnew, success, failure)
+ .map(|pprev| unsafe { Self::inner_from_raw(pprev) })
+ .map_err(|pprev| (unsafe { Self::inner_from_raw(pnew) }, pprev as *mut P))
+ }
+
+ #[inline]
+ fn inner_as_ptr(val: Option<&P>) -> *mut () {
+ match val {
+ Some(val) => &**val as *const _ as *mut (),
+ None => ptr::null_mut(),
+ }
+ }
+}
+
impl<P> Drop for Atom<P>
where
P: IntoRawPtr + FromRawPtr,
diff --git a/tests/atom.rs b/tests/atom.rs
index 22c3073..17f7e08 100644
--- a/tests/atom.rs
+++ b/tests/atom.rs
@@ -15,6 +15,7 @@
extern crate atom;
use atom::*;
+use std::collections::HashSet;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::*;
@@ -45,6 +46,120 @@ fn set_if_none() {
);
}
+#[test]
+fn compare_and_swap_basics() {
+ cas_test_basics_helper(|a, cas_val, next_val| {
+ a.compare_and_swap(cas_val, next_val, Ordering::SeqCst)
+ });
+}
+
+#[test]
+fn compare_exchange_basics() {
+ cas_test_basics_helper(|a, cas_val, next_val| {
+ a.compare_exchange(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst)
+ });
+}
+
+#[test]
+fn compare_exchange_weak_basics() {
+ cas_test_basics_helper(|a, cas_val, next_val| {
+ a.compare_exchange_weak(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst)
+ });
+}
+
+#[test]
+fn compare_and_swap_threads() {
+ cas_test_threads_helper(|a, cas_val, next_val| {
+ a.compare_and_swap(cas_val, next_val, Ordering::SeqCst)
+ });
+}
+
+#[test]
+fn compare_exchange_threads() {
+ cas_test_threads_helper(|a, cas_val, next_val| {
+ a.compare_exchange(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst)
+ });
+}
+
+#[test]
+fn compare_exchange_weak_threads() {
+ cas_test_threads_helper(|a, cas_val, next_val| {
+ a.compare_exchange_weak(cas_val, next_val, Ordering::SeqCst, Ordering::SeqCst)
+ });
+}
+
+type TestCASFn = fn(&Atom<Arc<String>>, Option<&Arc<String>>, Option<Arc<String>>)
+ -> Result<Option<Arc<String>>, (Option<Arc<String>>, *mut Arc<String>)>;
+
+fn cas_test_basics_helper(cas: TestCASFn) {
+ let cur_val = Arc::new("123".to_owned());
+ let mut next_val = Arc::new("456".to_owned());
+ let other_val = Arc::new("1927447".to_owned());
+
+ let a = Atom::new(cur_val.clone());
+
+ let pcur = IntoRawPtr::into_raw(cur_val.clone());
+ let pnext = IntoRawPtr::into_raw(next_val.clone());
+
+ for attempt in vec![None, Some(&other_val), Some(&Arc::new("wow".to_owned()))] {
+ let res = cas(&a, attempt, Some(next_val.clone())).unwrap_err();
+ next_val = res.0.unwrap();
+ assert_eq!(res.1, pcur as *mut _);
+ }
+
+ let res = cas(&a, Some(&cur_val), Some(next_val.clone()));
+ assert_eq!(res, Ok(Some(cur_val)));
+
+ for attempt in vec![None, Some(&other_val), Some(&Arc::new("wow".to_owned()))] {
+ let res = cas(&a, attempt, None).unwrap_err();
+ assert_eq!(res, (None, pnext as *mut _));
+ }
+}
+
+fn cas_test_threads_helper(cas: TestCASFn) {
+ let cur_val = Arc::new("current".to_owned());
+ let next_val = Arc::new("next".to_owned());
+ let other_val = Arc::new("other".to_owned());
+
+ let a = Arc::new(Atom::new(cur_val.clone()));
+
+ let num_threads = 10;
+ let cas_thread = num_threads / 2;
+ let pprevs: Vec<Result<usize, usize>> = (0..num_threads)
+ .map(|i| {
+ let a = a.clone();
+ let cur_val = cur_val.clone();
+ let next_val = next_val.clone();
+ let other_val = other_val.clone();
+ thread::spawn(move || {
+ let cas_val = Some(if i == cas_thread {
+ &cur_val
+ } else {
+ &other_val
+ });
+ match cas(&a, cas_val, Some(next_val.clone())) {
+ Ok(prev) => {
+ let prev = prev.unwrap();
+ assert!(Arc::ptr_eq(&prev, &cur_val));
+ assert!(!Arc::ptr_eq(&prev, &next_val));
+ Ok(prev.into_raw() as usize)
+ }
+ Err((_, pprev)) => Err(pprev as usize),
+ }
+ })
+ })
+ .map(|handle| handle.join().unwrap())
+ .collect();
+ assert_eq!(pprevs.iter().filter(|pprev| pprev.is_ok()).count(), 1);
+ let uniq_pprevs: HashSet<_> = pprevs
+ .into_iter()
+ .map(|pprev| pprev.unwrap_or_else(|pprev| pprev) as *mut _)
+ .collect();
+ assert!(uniq_pprevs.contains(&cur_val.into_raw()));
+ assert!(!uniq_pprevs.contains(&other_val.into_raw()));
+ assert_eq!(a.take(Ordering::Relaxed), Some(next_val));
+}
+
#[derive(Clone)]
struct Canary(Arc);