Skip to content

Commit

Permalink
refactor PyErr state to reduce blast radius of threading challenge
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Oct 25, 2024
1 parent b3bb667 commit 9f344ac
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 121 deletions.
190 changes: 135 additions & 55 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,112 @@
use std::cell::UnsafeCell;

use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
types::{PyTraceback, PyType},
Bound, IntoPy, Py, PyAny, PyObject, PyTypeInfo, Python,
Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
};

/// Error state wrapper that can be normalized in place through a shared reference.
///
/// The actual state lives in `PyErrStateInner`; it is wrapped in `UnsafeCell<Option<..>>`
/// so that `make_normalized(&self, ..)` can replace it, and so that it can be set to
/// `None` while a normalization is in progress.
pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
//
// The state is temporarily removed from the PyErr during normalization, to avoid
// concurrent modifications.
inner: UnsafeCell<Option<PyErrStateInner>>,
}

// SAFETY: `UnsafeCell` makes `PyErrState` `!Sync` by default. The inner value is only
// accessed through methods that require the GIL to be held (they take a `Python` token),
// which is what these impls rely on.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}

impl PyErrState {
    /// Creates an error state from an already-boxed lazy constructor.
    ///
    /// The closure is stored as-is and only invoked later, when the state is
    /// normalized or restored.
    pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
        Self::from_inner(PyErrStateInner::Lazy(f))
    }

    /// Creates a lazy error state from an exception type and unconverted arguments.
    ///
    /// `args.arguments(py)` is deferred until the stored closure is called with a
    /// `Python` token.
    pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
        Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
            PyErrStateLazyFnOutput {
                ptype,
                pvalue: args.arguments(py),
            }
        })))
    }

    /// Creates an error state from a raw `(type, value, traceback)` triple.
    ///
    /// Only available when `Py_3_12` is not set.
    #[cfg(not(Py_3_12))]
    pub(crate) fn ffi_tuple(
        ptype: PyObject,
        pvalue: Option<PyObject>,
        ptraceback: Option<PyObject>,
    ) -> Self {
        Self::from_inner(PyErrStateInner::FfiTuple {
            ptype,
            pvalue,
            ptraceback,
        })
    }

    /// Creates an error state that is already normalized.
    pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
        Self::from_inner(PyErrStateInner::Normalized(normalized))
    }

    /// Consumes the state and delegates to `PyErrStateInner::restore`.
    ///
    /// # Panics
    ///
    /// Panics if the inner state is `None`, which is only the case while a
    /// normalization is in progress.
    pub fn restore(self, py: Python<'_>) {
        self.inner
            .into_inner()
            .expect("PyErr state should never be invalid outside of normalization")
            .restore(py)
    }

    /// Wraps `inner` in the `UnsafeCell<Option<..>>` storage used by this type.
    fn from_inner(inner: PyErrStateInner) -> Self {
        Self {
            inner: UnsafeCell::new(Some(inner)),
        }
    }

    /// Returns the normalized form of this state, normalizing it first if necessary.
    ///
    /// The already-normalized case is handled inline; everything else goes through
    /// the `#[cold]` `make_normalized` path.
    #[inline]
    pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
        // SAFETY: `self.inner` will never be written again once normalized, so a shared
        // reference into an already-normalized state remains valid.
        if let Some(PyErrStateInner::Normalized(n)) = unsafe { &*self.inner.get() } {
            return n;
        }

        self.make_normalized(py)
    }

    /// Normalizes the state in place and returns a reference to the result.
    ///
    /// # Panics
    ///
    /// Panics if called while another normalization of the same state is already in
    /// progress (the state is `None` for the duration of the normalization).
    #[cold]
    fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
        // This process is safe because:
        // - Access is guaranteed not to be concurrent thanks to `Python` GIL token
        // - Write happens only once, and then never will change again.
        // - State is set to None during the normalization process, so that a second
        //   concurrent normalization attempt will panic before changing anything.

        // FIXME: this needs to be rewritten to deal with free-threaded Python
        // see https://github.com/PyO3/pyo3/issues/4584

        // SAFETY: see the list above; the temporary `&mut` created by the deref ends
        // with this statement.
        let state = unsafe {
            (*self.inner.get())
                .take()
                .expect("Cannot normalize a PyErr while already normalizing it.")
        };

        // Normalize *before* re-borrowing the cell mutably, so that no `&mut` into
        // `self.inner` is live while normalization (which may re-enter this `PyErr`) runs.
        let normalized = state.normalize(py);

        // SAFETY: same reasoning as above; this is the single write that moves the state
        // into its final, never-changing `Normalized` form.
        unsafe {
            let self_state = &mut *self.inner.get();
            *self_state = Some(PyErrStateInner::Normalized(normalized));
            match self_state {
                Some(PyErrStateInner::Normalized(n)) => n,
                _ => unreachable!(),
            }
        }
    }
}

pub(crate) struct PyErrStateNormalized {
#[cfg(not(Py_3_12))]
ptype: Py<PyType>,
Expand All @@ -14,6 +116,24 @@ pub(crate) struct PyErrStateNormalized {
}

impl PyErrStateNormalized {
/// Builds a normalized error state from an already-instantiated exception object.
///
/// On builds without `Py_3_12` the exception's type and traceback are captured next to
/// the value; otherwise only the value is stored.
pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
// Only the `not(Py_3_12)` path below calls `get_type()`, so the trait import is gated
// the same way.
#[cfg(not(Py_3_12))]
use crate::types::any::PyAnyMethods;

Self {
#[cfg(not(Py_3_12))]
ptype: pvalue.get_type().into(),
// The pointer returned by `PyException_GetTraceback` is treated as owned and possibly
// null, hence `from_owned_ptr_or_opt`.
#[cfg(not(Py_3_12))]
ptraceback: unsafe {
Py::from_owned_ptr_or_opt(
pvalue.py(),
ffi::PyException_GetTraceback(pvalue.as_ptr()),
)
},
pvalue: pvalue.into(),
}
}

#[cfg(not(Py_3_12))]
pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
self.ptype.bind(py).clone()
Expand Down Expand Up @@ -85,7 +205,7 @@ pub(crate) struct PyErrStateLazyFnOutput {
/// Signature of the deferred constructor held by a lazy error state.
///
/// The closure is called at most once (`FnOnce`) with a `Python` token and returns a
/// `PyErrStateLazyFnOutput`. It must be `Send + Sync`.
pub(crate) type PyErrStateLazyFn =
dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;

pub(crate) enum PyErrState {
enum PyErrStateInner {
Lazy(Box<PyErrStateLazyFn>),
#[cfg(not(Py_3_12))]
FfiTuple {
Expand All @@ -96,66 +216,26 @@ pub(crate) enum PyErrState {
Normalized(PyErrStateNormalized),
}

/// Helper conversion trait that allows custom arguments to be used for lazy exception
/// construction.
///
/// Implementors must be `Send + Sync`.
pub trait PyErrArguments: Send + Sync {
/// Consumes `self` and produces the Python object used as the exception's arguments.
fn arguments(self, py: Python<'_>) -> PyObject;
}

// Blanket impl: any `Send + Sync` value convertible via `IntoPy<PyObject>` can be used
// as exception arguments.
impl<T> PyErrArguments for T
where
T: IntoPy<PyObject> + Send + Sync,
{
// Delegates directly to the `IntoPy` conversion.
fn arguments(self, py: Python<'_>) -> PyObject {
self.into_py(py)
}
}

impl PyErrState {
pub(crate) fn lazy(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
PyErrState::Lazy(Box::new(move |py| PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
}))
}

pub(crate) fn normalized(pvalue: Bound<'_, PyBaseException>) -> Self {
#[cfg(not(Py_3_12))]
use crate::types::any::PyAnyMethods;

Self::Normalized(PyErrStateNormalized {
#[cfg(not(Py_3_12))]
ptype: pvalue.get_type().into(),
#[cfg(not(Py_3_12))]
ptraceback: unsafe {
Py::from_owned_ptr_or_opt(
pvalue.py(),
ffi::PyException_GetTraceback(pvalue.as_ptr()),
)
},
pvalue: pvalue.into(),
})
}

pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
impl PyErrStateInner {
pub fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
match self {
#[cfg(not(Py_3_12))]
PyErrState::Lazy(lazy) => {
PyErrStateInner::Lazy(lazy) => {
let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
unsafe {
PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
}
}
#[cfg(Py_3_12)]
PyErrState::Lazy(lazy) => {
PyErrStateInner::Lazy(lazy) => {
// To keep the implementation simple, just write the exception into the interpreter,
// which will cause it to be normalized
raise_lazy(py, lazy);
PyErrStateNormalized::take(py)
.expect("exception missing after writing to the interpreter")
}
#[cfg(not(Py_3_12))]
PyErrState::FfiTuple {
PyErrStateInner::FfiTuple {
ptype,
pvalue,
ptraceback,
Expand All @@ -168,15 +248,15 @@ impl PyErrState {
PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
}
}
PyErrState::Normalized(normalized) => normalized,
PyErrStateInner::Normalized(normalized) => normalized,
}
}

#[cfg(not(Py_3_12))]
pub(crate) fn restore(self, py: Python<'_>) {
pub fn restore(self, py: Python<'_>) {
let (ptype, pvalue, ptraceback) = match self {
PyErrState::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
PyErrState::FfiTuple {
PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
PyErrStateInner::FfiTuple {
ptype,
pvalue,
ptraceback,
Expand All @@ -185,7 +265,7 @@ impl PyErrState {
pvalue.map_or(std::ptr::null_mut(), Py::into_ptr),
ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
),
PyErrState::Normalized(PyErrStateNormalized {
PyErrStateInner::Normalized(PyErrStateNormalized {
ptype,
pvalue,
ptraceback,
Expand All @@ -199,10 +279,10 @@ impl PyErrState {
}

#[cfg(Py_3_12)]
pub(crate) fn restore(self, py: Python<'_>) {
pub fn restore(self, py: Python<'_>) {
match self {
PyErrState::Lazy(lazy) => raise_lazy(py, lazy),
PyErrState::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
ffi::PyErr_SetRaisedException(pvalue.into_ptr())
},
}
Expand Down
Loading

0 comments on commit 9f344ac

Please sign in to comment.