Replace pybind11 with nanobind in frontend (#1173)
**Context:**

The Catalyst frontend contains a small Python extension module,
[frontend/catalyst/utils/wrapper.cpp](https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/utils/wrapper.cpp),
an importable module written in C++ that wraps the entry point of compiled
programs. The Python-C++ bindings were originally implemented using pybind11.
This PR is part of a larger effort to replace all pybind11 code with nanobind.

**Description of the Change:**

This change replaces all the pybind11 code in the frontend with the
equivalent nanobind objects and operations.

It was also necessary to modify the frontend build system to build the
`wrapper` module using CMake for compatibility with nanobind, rather
than in `setup.py` with the `intree_extensions` setuptools utility
included with pybind11.

**Benefits:**

See Epic [68472](https://app.shortcut.com/xanaduai/epic/68472) for a
list of nanobind's benefits.

-----


[[sc-72837](https://app.shortcut.com/xanaduai/story/72837/replace-pybind11-with-nanobind-in-the-frontend)]

---------

Co-authored-by: Lee James O'Riordan <[email protected]>
joeycarter and mlxd authored Nov 7, 2024
1 parent 24cd67d commit 143564a
Showing 12 changed files with 257 additions and 81 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/check-catalyst.yaml
@@ -404,7 +404,7 @@ jobs:
- name: Install Deps
run: |
sudo apt-get update
-sudo apt-get install -y python3 python3-pip libomp-dev libasan6 make
+sudo apt-get install -y python3 python3-pip libomp-dev libasan6 make ninja-build
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
python3 -m pip install -r requirements.txt
# cuda-quantum is added manually here.
@@ -481,7 +481,7 @@ jobs:
- name: Install Deps
run: |
sudo apt-get update
-sudo apt-get install -y python3 python3-pip libomp-dev libasan6 make
+sudo apt-get install -y python3 python3-pip libomp-dev libasan6 make ninja-build
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
python3 -m pip install -r requirements.txt
make frontend
@@ -536,7 +536,7 @@ jobs:
- name: Install Deps
run: |
sudo apt-get update
-sudo apt-get install -y python3 python3-pip libomp-dev libasan6 make
+sudo apt-get install -y python3 python3-pip libomp-dev libasan6 make ninja-build
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
python3 -m pip install -r requirements.txt
make frontend
3 changes: 3 additions & 0 deletions .readthedocs.yaml
@@ -22,6 +22,9 @@ build:
python: "3.10"
apt_packages:
- graphviz
- cmake
- ninja-build
- clang

# Optionally set the version of Python and requirements required to build your docs
python:
7 changes: 6 additions & 1 deletion Makefile
@@ -69,6 +69,7 @@ help:
@echo " test to run the Catalyst test suites"
@echo " docs to build the documentation for Catalyst"
@echo " clean to uninstall Catalyst and delete all temporary and cache files"
@echo " clean-frontend to clean build files of Catalyst Frontend"
@echo " clean-mlir to clean build files of MLIR and custom Catalyst dialects"
@echo " clean-runtime to clean build files of Catalyst Runtime"
@echo " clean-oqc to clean build files of OQC Runtime"
@@ -201,12 +202,16 @@ clean:
rm -rf dist __pycache__
rm -rf .coverage coverage_html_report

-clean-all: clean-mlir clean-runtime clean-oqc
+clean-all: clean-frontend clean-mlir clean-runtime clean-oqc
@echo "uninstall catalyst and delete all temporary, cache, and build files"
$(PYTHON) -m pip uninstall -y pennylane-catalyst
rm -rf dist __pycache__
rm -rf .coverage coverage_html_report/

.PHONY: clean-frontend
clean-frontend:
find frontend/catalyst -name "*.so" -exec rm -v {} +

.PHONY: clean-mlir clean-dialects clean-llvm clean-mhlo clean-enzyme
clean-mlir:
$(MAKE) -C mlir clean
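The new `clean-frontend` target removes compiled extension modules with `find … -exec rm`. A minimal Python sketch of the same operation, run here against a throwaway temporary directory rather than the real `frontend/catalyst` tree:

```python
import tempfile
from pathlib import Path

# Stand-in for frontend/catalyst; a temporary directory keeps the sketch
# safe to run anywhere.
root = Path(tempfile.mkdtemp())
(root / "wrapper.so").touch()
(root / "wrapper.py").touch()

# Equivalent of: find <root> -name "*.so" -exec rm -v {} +
for so in root.rglob("*.so"):
    print(f"removed '{so}'")
    so.unlink()

print(sorted(p.name for p in root.iterdir()))  # ['wrapper.py']
```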
12 changes: 11 additions & 1 deletion doc/releases/changelog-dev.md
@@ -6,6 +6,14 @@

<h3>Improvements 🛠</h3>

* Replace pybind11 with nanobind for C++/Python bindings in the frontend.
[(#1173)](https://github.com/PennyLaneAI/catalyst/pull/1173)

Nanobind has been developed as a natural successor to the pybind11 library and offers a number of
[advantages](https://nanobind.readthedocs.io/en/latest/why.html#major-additions), in particular,
its ability to target Python's [stable ABI interface](https://docs.python.org/3/c-api/stable.html)
starting with Python 3.12.

<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>
@@ -16,4 +24,6 @@

<h3>Contributors ✍️</h3>

-This release contains contributions from (in alphabetical order):
+This release contains contributions from (in alphabetical order):
+
+Joey Carter
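As the changelog entry notes, stable-ABI extensions carry the `abi3` tag in their filename, which CPython recognises as a loadable extension suffix. A quick way to inspect the suffixes the running interpreter accepts (the names and ordering are platform-specific):

```python
import importlib.machinery

# Every filename ending the import system treats as an extension module.
# A stable-ABI build of the wrapper would ship as e.g. wrapper.abi3.so on
# Linux, matching an ".abi3" entry in this list (platform-dependent).
for suffix in importlib.machinery.EXTENSION_SUFFIXES:
    print(suffix)
```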
5 changes: 5 additions & 0 deletions frontend/CMakeLists.txt
@@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.26)

project(catalyst_frontend LANGUAGES CXX)

add_subdirectory(catalyst)
1 change: 1 addition & 0 deletions frontend/catalyst/CMakeLists.txt
@@ -0,0 +1 @@
add_subdirectory(utils)
46 changes: 46 additions & 0 deletions frontend/catalyst/utils/CMakeLists.txt
@@ -0,0 +1,46 @@
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Python 3
REQUIRED COMPONENTS Interpreter Development.Module
OPTIONAL_COMPONENTS Development.SABIModule)

# nanobind suggests including these lines to configure CMake to perform an optimized release build
# by default unless another build type is specified. Without this addition, binding code may run
# slowly and produce large binaries.
# See https://nanobind.readthedocs.io/en/latest/building.html#preliminaries
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

# Detect the installed nanobind package and import it into CMake
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_VARIABLE nanobind_ROOT OUTPUT_STRIP_TRAILING_WHITESPACE)

find_package(nanobind CONFIG REQUIRED)

# Get the NumPy include directory
execute_process(
COMMAND "${Python_EXECUTABLE}" -c "import numpy; print(numpy.get_include())"
OUTPUT_VARIABLE NUMPY_INCLUDE_DIR
OUTPUT_STRIP_TRAILING_WHITESPACE
)

# Source file list for `wrapper` module
set(WRAPPER_SRC_FILES
wrapper.cpp
)

# Create the Python `wrapper` module
# Target the stable ABI for Python 3.12+, which reduces the number of binary wheels that must be
# built (`STABLE_ABI` does nothing on older Python versions).
nanobind_add_module(wrapper STABLE_ABI ${WRAPPER_SRC_FILES})

# Add the NumPy include directory to the library's include paths
target_include_directories(wrapper PRIVATE ${NUMPY_INCLUDE_DIR})

# Use suffix ".so" rather than ".abi3.so" for library file using Stable ABI
# This is necessary for compatibility with setuptools build extensions
set_target_properties(wrapper PROPERTIES SUFFIX ".so")
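The NumPy include lookup in the CMake file above shells out to the active interpreter via `execute_process`; the one-liner it runs can be checked directly, assuming NumPy is installed in the active environment:

```python
import os
import numpy

# This is the expression the CMakeLists.txt executes to populate
# NUMPY_INCLUDE_DIR; the directory must contain the C headers, including
# numpy/ndarrayobject.h used by wrapper.cpp.
include_dir = numpy.get_include()
print(include_dir)
print(os.path.isfile(os.path.join(include_dir, "numpy", "ndarrayobject.h")))
```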
89 changes: 49 additions & 40 deletions frontend/catalyst/utils/wrapper.cpp
@@ -13,16 +13,15 @@
// limitations under the License.

#include <csignal>
-#include <iostream>
-#include <pybind11/pybind11.h>
+#include <nanobind/nanobind.h>

// TODO: Periodically check and increment version.
// https://endoflife.date/numpy
#define NPY_NO_DEPRECATED_API NPY_1_24_API_VERSION

-#include "numpy/arrayobject.h"
+#include "numpy/ndarrayobject.h"

-namespace py = pybind11;
+namespace nb = nanobind;

struct memref_beginning_t {
char *allocated;
@@ -50,7 +49,7 @@ size_t *to_sizes(char *base, size_t rank)
size_t aligned = sizeof(void *);
size_t offset = sizeof(size_t);
size_t bytes_offset = allocated + aligned + offset;
-return (size_t *)(base + bytes_offset);
+return reinterpret_cast<size_t *>(base + bytes_offset);
}

size_t *to_strides(char *base, size_t rank)
@@ -64,7 +63,7 @@ size_t *to_strides(char *base, size_t rank)
size_t offset = sizeof(size_t);
size_t sizes = rank * sizeof(size_t);
size_t bytes_offset = allocated + aligned + offset + sizes;
-return (size_t *)(base + bytes_offset);
+return reinterpret_cast<size_t *>(base + bytes_offset);
}

void free_wrap(PyObject *capsule)
@@ -76,7 +75,7 @@ void free_wrap(PyObject *capsule)
const npy_intp *npy_get_dimensions(char *memref, size_t rank)
{
size_t *sizes = to_sizes(memref, rank);
-const npy_intp *dims = (npy_intp *)sizes;
+const npy_intp *dims = reinterpret_cast<npy_intp *>(sizes);
return dims;
}

@@ -87,41 +86,42 @@ const npy_intp *npy_get_strides(char *memref, size_t element_size, size_t rank)
// memref strides are in terms of elements.
// numpy strides are in terms of bytes.
// Therefore multiply by element size.
-strides[idx] *= (size_t)element_size;
+strides[idx] *= element_size;
}

-npy_intp *npy_strides = (npy_intp *)strides;
+npy_intp *npy_strides = reinterpret_cast<npy_intp *>(strides);
return npy_strides;
}

-py::list move_returns(void *memref_array_ptr, py::object result_desc, py::object transfer,
-py::dict numpy_arrays)
+nb::list move_returns(void *memref_array_ptr, nb::object result_desc, nb::object transfer,
+nb::dict numpy_arrays)
{
-py::list returns;
+nb::list returns;
if (result_desc.is_none()) {
return returns;
}

-auto ctypes = py::module::import("ctypes");
+auto ctypes = nb::module_::import_("ctypes");
using f_ptr_t = bool (*)(void *);
-f_ptr_t f_transfer_ptr = *((f_ptr_t *)ctypes.attr("addressof")(transfer).cast<size_t>());
+f_ptr_t f_transfer_ptr = *((f_ptr_t *)nb::cast<size_t>(ctypes.attr("addressof")(transfer)));

/* Data from the result description */
auto ranks = result_desc.attr("_ranks_");
auto etypes = result_desc.attr("_etypes_");
auto sizes = result_desc.attr("_sizes_");

-size_t memref_len = ranks.attr("__len__")().cast<size_t>();
+size_t memref_len = nb::cast<size_t>(ranks.attr("__len__")());
size_t offset = 0;

-char *memref_array_bytes = (char *)(memref_array_ptr);
+char *memref_array_bytes = reinterpret_cast<char *>(memref_array_ptr);

for (size_t idx = 0; idx < memref_len; idx++) {
-unsigned int rank_i = ranks.attr("__getitem__")(idx).cast<unsigned int>();
+unsigned int rank_i = nb::cast<unsigned int>(ranks.attr("__getitem__")(idx));
char *memref_i_beginning = memref_array_bytes + offset;
offset += memref_size_based_on_rank(rank_i);

-struct memref_beginning_t *memref = (struct memref_beginning_t *)memref_i_beginning;
+struct memref_beginning_t *memref =
+reinterpret_cast<struct memref_beginning_t *>(memref_i_beginning);
bool is_in_rt_heap = f_transfer_ptr(memref->allocated);

if (!is_in_rt_heap) {
@@ -133,15 +133,16 @@ py::list move_returns(void *memref_array_ptr, py::object result_desc, py::object
// The first case is guaranteed by the use of the flag --cp-global-memref
//
// Use the numpy_arrays dictionary which sets up the following map:
-// integer (memory address) -> py::object (numpy array)
-auto array_object = numpy_arrays.attr("__getitem__")((size_t)memref->allocated);
+// integer (memory address) -> nb::object (numpy array)
+auto array_object =
+numpy_arrays.attr("__getitem__")(reinterpret_cast<size_t>(memref->allocated));
returns.append(array_object);
continue;
}

const npy_intp *dims = npy_get_dimensions(memref_i_beginning, rank_i);

-size_t element_size = sizes.attr("__getitem__")(idx).cast<size_t>();
+size_t element_size = nb::cast<size_t>(sizes.attr("__getitem__")(idx));
const npy_intp *strides = npy_get_strides(memref_i_beginning, element_size, rank_i);

auto etype_i = etypes.attr("__getitem__")(idx);
Expand All @@ -157,79 +158,87 @@ py::list move_returns(void *memref_array_ptr, py::object result_desc, py::object
throw std::runtime_error("PyArray_NewFromDescr failed.");
}

-PyObject *capsule =
-PyCapsule_New(memref->allocated, NULL, (PyCapsule_Destructor)&free_wrap);
+PyObject *capsule = PyCapsule_New(memref->allocated, NULL,
+reinterpret_cast<PyCapsule_Destructor>(&free_wrap));
if (!capsule) {
throw std::runtime_error("PyCapsule_New failed.");
}

-int retval = PyArray_SetBaseObject((PyArrayObject *)new_array, capsule);
+int retval = PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(new_array), capsule);
bool success = 0 == retval;
if (!success) {
throw std::runtime_error("PyArray_SetBaseObject failed.");
}

-returns.append(new_array);
+returns.append(nb::borrow(new_array)); // nb::borrow increments ref count by 1

// Now we insert the array into the dictionary.
// This dictionary is a map of the type:
-// integer (memory address) -> py::object (numpy array)
+// integer (memory address) -> nb::object (numpy array)
//
// Upon first entry into this function, it holds the numpy.arrays
// sent as an input to the generated function.
// Upon following entries it is extended with the numpy.arrays
// which are the output of the generated function.
-PyObject *pyLong = PyLong_FromLong((size_t)memref->allocated);
+PyObject *pyLong = PyLong_FromLong(reinterpret_cast<size_t>(memref->allocated));
if (!pyLong) {
throw std::runtime_error("PyLong_FromLong failed.");
}

-numpy_arrays[pyLong] = new_array;
+numpy_arrays[pyLong] = nb::borrow(new_array); // nb::borrow increments ref count by 1

// Decrement reference counts.
// The final ref count of `new_array` should be 2: one for the `returns` list and one for
// the `numpy_arrays` dict.
Py_DECREF(pyLong);
Py_DECREF(new_array);
}
return returns;
}

-py::list wrap(py::object func, py::tuple py_args, py::object result_desc, py::object transfer,
-py::dict numpy_arrays)
+nb::list wrap(nb::object func, nb::tuple py_args, nb::object result_desc, nb::object transfer,
+nb::dict numpy_arrays)
{
// Install signal handler to catch user interrupts (e.g. CTRL-C).
signal(SIGINT, [](int code) { throw std::runtime_error("KeyboardInterrupt (SIGINT)"); });

-py::list returns;
+nb::list returns;

-size_t length = py_args.attr("__len__")().cast<size_t>();
+size_t length = nb::cast<size_t>(py_args.attr("__len__")());
if (length != 2) {
throw std::invalid_argument("Invalid number of arguments.");
}

-auto ctypes = py::module::import("ctypes");
+auto ctypes = nb::module_::import_("ctypes");
using f_ptr_t = void (*)(void *, void *);
-f_ptr_t f_ptr = *reinterpret_cast<f_ptr_t *>(ctypes.attr("addressof")(func).cast<size_t>());
+f_ptr_t f_ptr = *reinterpret_cast<f_ptr_t *>(nb::cast<size_t>(ctypes.attr("addressof")(func)));

auto value0 = py_args.attr("__getitem__")(0);
-void *value0_ptr = *reinterpret_cast<void **>(ctypes.attr("addressof")(value0).cast<size_t>());
+void *value0_ptr =
+*reinterpret_cast<void **>(nb::cast<size_t>(ctypes.attr("addressof")(value0)));
auto value1 = py_args.attr("__getitem__")(1);
-void *value1_ptr = *reinterpret_cast<void **>(ctypes.attr("addressof")(value1).cast<size_t>());
+void *value1_ptr =
+*reinterpret_cast<void **>(nb::cast<size_t>(ctypes.attr("addressof")(value1)));

{
-py::gil_scoped_release lock;
+nb::gil_scoped_release lock;
f_ptr(value0_ptr, value1_ptr);
}
returns = move_returns(value0_ptr, result_desc, transfer, numpy_arrays);

return returns;
}

-PYBIND11_MODULE(wrapper, m)
+NB_MODULE(wrapper, m)
{
m.doc() = "wrapper module";
-m.def("wrap", &wrap, "A wrapper function.");
+// We have to annotate all the arguments to `wrap` to allow `result_desc` to be None
+// See https://nanobind.readthedocs.io/en/latest/functions.html#none-arguments
+m.def("wrap", &wrap, "A wrapper function.", nb::arg("func"), nb::arg("py_args"),
+nb::arg("result_desc").none(), nb::arg("transfer"), nb::arg("numpy_arrays"));
int retval = _import_array();
bool success = retval >= 0;
if (!success) {
-throw pybind11::import_error("Couldn't import numpy array C-API.");
+throw nb::import_error("Could not import numpy array C-API.");
}
}
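Both `wrap` and `move_returns` above recover raw function pointers from ctypes objects via `ctypes.addressof`, now through `nb::cast<size_t>` instead of pybind11's `.cast<size_t>()`. The pointer trick itself is plain ctypes and can be sketched in pure Python, with a toy callback standing in for the compiled entry point:

```python
import ctypes

# A C-compatible function type, standing in for the compiled entry point.
CFUNC = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int)

def square(x):
    return x * x

f = CFUNC(square)  # keep a reference alive while the pointer is in use

# ctypes.addressof(f) is the address of the ctypes object's storage, which
# holds the raw function pointer; the C++ side dereferences it with
# *reinterpret_cast<f_ptr_t *>(nb::cast<size_t>(...)).
raw = ctypes.c_void_p.from_address(ctypes.addressof(f)).value

# Rebuilding a callable from the raw address round-trips the pointer.
g = CFUNC(raw)
print(g(7))  # 49
```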
6 changes: 3 additions & 3 deletions frontend/test/pytest/test_callback.py
@@ -233,7 +233,7 @@ def identity(arg) -> int:
def cir(x):
return identity(x)

-with pytest.raises(TypeError, match="Callback identity expected type"):
+with pytest.raises(RuntimeError, match="TypeError: Callback identity expected type"):
cir(arg)


@@ -334,7 +334,7 @@ def cir(x):
captured = capsys.readouterr()
assert captured.out.strip() == ""

-with pytest.raises(ValueError, match="debug.callback is expected to return None"):
+with pytest.raises(RuntimeError, match="ValueError: debug.callback is expected to return None"):
cir(0)


@@ -953,7 +953,7 @@ def result(x):
return jnp.sum(some_func(jnp.sin(x)))

x = 0.435
-with pytest.raises(TypeError, match="Callback some_func expected type"):
+with pytest.raises(RuntimeError, match="TypeError: Callback some_func expected type"):
result(x)


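The test updates above reflect that exceptions raised inside code crossing the nanobind layer resurface as `RuntimeError`, with the original exception's type name folded into the message. A toy model of that behavior (illustrative only, not nanobind's actual implementation):

```python
import re

def reraise_like_nanobind(callback, *args):
    # Hypothetical stand-in for the binding layer: wrap any escaping
    # exception in a RuntimeError prefixed with the original type name.
    try:
        return callback(*args)
    except Exception as exc:
        raise RuntimeError(f"{type(exc).__name__}: {exc}") from exc

def identity(arg):
    raise TypeError("Callback identity expected type int")

try:
    reraise_like_nanobind(identity, 1.0)
except RuntimeError as err:
    # The old match pattern still appears, but only inside a RuntimeError
    # message prefixed with "TypeError:", hence the updated pytest.raises.
    assert re.search(r"TypeError: Callback identity expected type", str(err))
    print("matched")
```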