Replacing pybind11 with nanobind (#83)
* Replacing pybind11 with nanobind

* removing extra unused namespace

* build: nanobind doesn't enable LTO by default, no need to globally disable

* Including @lgarrison's suggestion

---------

Co-authored-by: Lehman Garrison <[email protected]>
Authored by dfm and lgarrison on Apr 22, 2024 (1 parent: 93fb062; commit: 816b4d5)
Showing 8 changed files with 139 additions and 159 deletions.
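At a glance, the migration below is mostly mechanical renaming (the headers, namespace alias, module macro, and `py::` object types all have direct `nb::` equivalents), plus one structural change to how the opts classes are constructed, visible in the two binding files. Here is a condensed sketch of the rename layer; this is an illustrative module, not one of the project's real files, and `rename_sketch` is a hypothetical module name:

```cpp
// Condensed sketch of the mechanical pybind11 -> nanobind renames applied in
// the diff below; "rename_sketch" is a hypothetical module name.
#include <nanobind/nanobind.h>  // was: #include <pybind11/pybind11.h>

namespace nb = nanobind;  // was: namespace py = pybind11;

NB_MODULE(rename_sketch, m) {  // was: PYBIND11_MODULE(rename_sketch, m)
  m.attr("ANSWER") = nb::int_(42);  // py::int_ -> nb::int_
  m.def("registrations", []() {     // py::dict -> nb::dict, used identically
    nb::dict d;
    d["example"] = nb::int_(1);
    return d;
  });
}
```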
CMakeLists.txt (15 additions, 11 deletions)
@@ -1,18 +1,16 @@
-cmake_minimum_required(VERSION 3.15...3.26)
+cmake_minimum_required(VERSION 3.15...3.27)
project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX)
message(STATUS "Using CMake version: " ${CMAKE_VERSION})

# for cuda-gdb and verbose PTXAS output
# set(CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} "-g -G -Xptxas -v")

-# Workaround for LTO applied incorrectly to CUDA fatbin
-# https://github.com/pybind/pybind11/issues/4825
-set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)
-
# Enable OpenMP if requested and available
option(JAX_FINUFFT_USE_OPENMP "Enable OpenMP" ON)
+
if(JAX_FINUFFT_USE_OPENMP)
find_package(OpenMP)
+
if(OpenMP_CXX_FOUND)
message(STATUS "jax_finufft: OpenMP found")
set(FINUFFT_USE_OPENMP ON)
@@ -26,6 +24,7 @@ endif()

# Enable CUDA if requested and available
option(JAX_FINUFFT_USE_CUDA "Enable CUDA build" OFF)
+
if(JAX_FINUFFT_USE_CUDA)
include(CheckLanguage)
check_language(CUDA)
@@ -48,16 +47,21 @@ endif()
# Add the FINUFFT project using the vendored version
add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/vendor/finufft")

-# Find pybind11
-set(PYBIND11_NEWPYTHON ON)
-find_package(pybind11 CONFIG REQUIRED)
+# Find Python and nanobind
+find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
+find_package(nanobind CONFIG REQUIRED)

+if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
+  set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
+  set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
+endif()
+
# Build the CPU XLA bindings
-pybind11_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc)
+nanobind_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc)
target_link_libraries(jax_finufft_cpu PRIVATE finufft_static)
install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .)

-if (FINUFFT_USE_OPENMP)
+if(FINUFFT_USE_OPENMP)
target_compile_definitions(jax_finufft_cpu PRIVATE FINUFFT_USE_OPENMP)
endif()

@@ -75,7 +79,7 @@ if(FINUFFT_USE_CUDA)
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include/cufinufft/contrib/cuda_samples
)
-pybind11_add_module(jax_finufft_gpu
+nanobind_add_module(jax_finufft_gpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
${CMAKE_CURRENT_LIST_DIR}/lib/cufinufft_wrapper.cc
${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
lib/jax_finufft_cpu.cc (37 additions, 48 deletions)
@@ -1,14 +1,14 @@
// This file defines the Python interface to the XLA custom call implemented on the CPU.
-// It is exposed as a standard pybind11 module defining "capsule" objects containing our
+// It is exposed as a standard nanobind module defining "capsule" objects containing our
// method. For simplicity, we export a separate capsule for each supported dtype.

#include "jax_finufft_cpu.h"

#include "pybind11_kernel_helpers.h"
#include "nanobind_kernel_helpers.h"

using namespace jax_finufft;
using namespace jax_finufft::cpu;
-namespace py = pybind11;
+namespace nb = nanobind;

namespace {

@@ -68,41 +68,14 @@ void nufft2(void *out, void **in) {
}

template <typename T>
-py::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
+nb::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
int64_t n_k_1, int64_t n_k_2, int64_t n_k_3, finufft_opts opts) {
return pack_descriptor(
descriptor<T>{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}, opts});
}

-template <typename T>
-finufft_opts *build_opts(bool modeord, bool chkbnds, int debug, int spread_debug, bool showwarn,
-                         int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth,
-                         bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize,
-                         int spread_nthr_atomic, int spread_max_sp_size) {
-  finufft_opts *opts = new finufft_opts;
-  default_opts<T>(opts);
-
-  opts->modeord = int(modeord);
-  opts->chkbnds = int(chkbnds);
-  opts->debug = debug;
-  opts->spread_debug = spread_debug;
-  opts->showwarn = int(showwarn);
-  opts->nthreads = nthreads;
-  opts->fftw = fftw;
-  opts->spread_sort = spread_sort;
-  opts->spread_kerevalmeth = int(spread_kerevalmeth);
-  opts->spread_kerpad = int(spread_kerpad);
-  opts->upsampfac = upsampfac;
-  opts->spread_thread = int(spread_thread);
-  opts->maxbatchsize = maxbatchsize;
-  opts->spread_nthr_atomic = spread_nthr_atomic;
-  opts->spread_max_sp_size = spread_max_sp_size;
-
-  return opts;
-}
-
-pybind11::dict Registrations() {
-  pybind11::dict dict;
+nb::dict Registrations() {
+  nb::dict dict;

dict["nufft1d1f"] = encapsulate_function(nufft1<1, float>);
dict["nufft1d2f"] = encapsulate_function(nufft2<1, float>);
@@ -121,7 +94,7 @@ pybind11::dict Registrations() {
return dict;
}

-PYBIND11_MODULE(jax_finufft_cpu, m) {
+NB_MODULE(jax_finufft_cpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);
@@ -134,20 +107,36 @@ PYBIND11_MODULE(jax_finufft_cpu, m) {
#endif
});

m.attr("FFTW_ESTIMATE") = py::int_(FFTW_ESTIMATE);
m.attr("FFTW_MEASURE") = py::int_(FFTW_MEASURE);
m.attr("FFTW_PATIENT") = py::int_(FFTW_PATIENT);
m.attr("FFTW_EXHAUSTIVE") = py::int_(FFTW_EXHAUSTIVE);
m.attr("FFTW_WISDOM_ONLY") = py::int_(FFTW_WISDOM_ONLY);

py::class_<finufft_opts> opts(m, "FinufftOpts");
opts.def(py::init(&build_opts<double>), py::arg("modeord") = false, py::arg("chkbnds") = true,
py::arg("debug") = 0, py::arg("spread_debug") = 0, py::arg("showwarn") = false,
py::arg("nthreads") = 0, py::arg("fftw") = int(FFTW_ESTIMATE),
py::arg("spread_sort") = 2, py::arg("spread_kerevalmeth") = true,
py::arg("spread_kerpad") = true, py::arg("upsampfac") = 0.0,
py::arg("spread_thread") = 0, py::arg("maxbatchsize") = 0,
py::arg("spread_nthr_atomic") = -1, py::arg("spread_max_sp_size") = 0);
m.attr("FFTW_ESTIMATE") = nb::int_(FFTW_ESTIMATE);
m.attr("FFTW_MEASURE") = nb::int_(FFTW_MEASURE);
m.attr("FFTW_PATIENT") = nb::int_(FFTW_PATIENT);
m.attr("FFTW_EXHAUSTIVE") = nb::int_(FFTW_EXHAUSTIVE);
m.attr("FFTW_WISDOM_ONLY") = nb::int_(FFTW_WISDOM_ONLY);

nb::class_<finufft_opts> opts(m, "FinufftOpts");
opts.def("__init__",
[](finufft_opts *self, bool modeord, bool chkbnds, int debug, int spread_debug,
bool showwarn, int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth,
bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize,
int spread_nthr_atomic, int spread_max_sp_size) {
new (self) finufft_opts;
default_opts<double>(self);
self->modeord = int(modeord);
self->chkbnds = int(chkbnds);
self->debug = debug;
self->spread_debug = spread_debug;
self->showwarn = int(showwarn);
self->nthreads = nthreads;
self->fftw = fftw;
self->spread_sort = spread_sort;
self->spread_kerevalmeth = int(spread_kerevalmeth);
self->spread_kerpad = int(spread_kerpad);
self->upsampfac = upsampfac;
self->spread_thread = int(spread_thread);
self->maxbatchsize = maxbatchsize;
self->spread_nthr_atomic = spread_nthr_atomic;
self->spread_max_sp_size = spread_max_sp_size;
});
}

} // namespace
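The non-mechanical change in this file is the constructor. nanobind has no direct analogue of `py::init(&build_opts<double>)` taking a factory that returns a raw pointer, so the heap-allocating `build_opts` helper becomes an `__init__` lambda that placement-constructs `finufft_opts` in storage nanobind supplies. A minimal sketch of the same pattern follows, with a hypothetical `Options` struct standing in for `finufft_opts`; the `nb::arg` keyword defaults are an assumption for illustration:

```cpp
// Minimal sketch of the pybind11-factory -> nanobind-placement-new migration;
// Options is a hypothetical stand-in for finufft_opts.
#include <nanobind/nanobind.h>

#include <new>  // placement new

namespace nb = nanobind;

struct Options {
  int nthreads;
  double upsampfac;
};

NB_MODULE(opts_sketch, m) {
  nb::class_<Options>(m, "Options")
      // pybind11 style was roughly .def(py::init(&build_opts), ...), where
      // build_opts did `new Options` and returned the pointer. nanobind
      // instead hands the lambda uninitialized storage to construct in place:
      .def(
          "__init__",
          [](Options *self, int nthreads, double upsampfac) {
            new (self) Options;  // construct in nanobind-provided storage
            self->nthreads = nthreads;
            self->upsampfac = upsampfac;
          },
          nb::arg("nthreads") = 0, nb::arg("upsampfac") = 2.0);
}
```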
lib/jax_finufft_gpu.cc (29 additions, 39 deletions)
@@ -1,51 +1,26 @@
// This file defines the Python interface to the XLA custom call implemented on the CPU.
-// It is exposed as a standard pybind11 module defining "capsule" objects containing our
+// It is exposed as a standard nanobind module defining "capsule" objects containing our
// method. For simplicity, we export a separate capsule for each supported dtype.

#include "cufinufft_wrapper.h"
#include "kernels.h"
#include "pybind11_kernel_helpers.h"
#include "nanobind_kernel_helpers.h"

using namespace jax_finufft;
using namespace jax_finufft::gpu;
-namespace py = pybind11;
+namespace nb = nanobind;

namespace {

template <typename T>
-py::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
+nb::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
int64_t n_k_1, int64_t n_k_2, int64_t n_k_3, cufinufft_opts opts) {
return pack_descriptor(
descriptor<T>{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}, opts});
}

-template <typename T>
-cufinufft_opts *build_opts(double upsampfac, int gpu_method, bool gpu_sort, int gpu_binsizex,
-                           int gpu_binsizey, int gpu_binsizez, int gpu_obinsizex,
-                           int gpu_obinsizey, int gpu_obinsizez, int gpu_maxsubprobsize,
-                           bool gpu_kerevalmeth, int gpu_spreadinterponly, int gpu_maxbatchsize) {
-  cufinufft_opts *opts = new cufinufft_opts;
-  default_opts<T>(opts);
-
-  opts->upsampfac = upsampfac;
-  opts->gpu_method = gpu_method;
-  opts->gpu_sort = int(gpu_sort);
-  opts->gpu_binsizex = gpu_binsizex;
-  opts->gpu_binsizey = gpu_binsizey;
-  opts->gpu_binsizez = gpu_binsizez;
-  opts->gpu_obinsizex = gpu_obinsizex;
-  opts->gpu_obinsizey = gpu_obinsizey;
-  opts->gpu_obinsizez = gpu_obinsizez;
-  opts->gpu_maxsubprobsize = gpu_maxsubprobsize;
-  opts->gpu_kerevalmeth = gpu_kerevalmeth;
-  opts->gpu_spreadinterponly = gpu_spreadinterponly;
-  opts->gpu_maxbatchsize = gpu_maxbatchsize;
-
-  return opts;
-}
-
-pybind11::dict Registrations() {
-  pybind11::dict dict;
+nb::dict Registrations() {
+  nb::dict dict;

// TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
// dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
@@ -65,18 +40,33 @@ pybind11::dict Registrations() {
return dict;
}

-PYBIND11_MODULE(jax_finufft_gpu, m) {
+NB_MODULE(jax_finufft_gpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);

-  py::class_<cufinufft_opts> opts(m, "CufinufftOpts");
-  opts.def(py::init(&build_opts<double>), py::arg("upsampfac") = 2.0, py::arg("gpu_method") = 0,
-           py::arg("gpu_sort") = true, py::arg("gpu_binsizex") = -1, py::arg("gpu_binsizey") = -1,
-           py::arg("gpu_binsizez") = -1, py::arg("gpu_obinsizex") = -1,
-           py::arg("gpu_obinsizey") = -1, py::arg("gpu_obinsizez") = -1,
-           py::arg("gpu_maxsubprobsize") = 1024, py::arg("gpu_kerevalmeth") = true,
-           py::arg("gpu_spreadinterponly") = 0, py::arg("gpu_maxbatchsize") = 0);
+  nb::class_<cufinufft_opts> opts(m, "CufinufftOpts");
+  opts.def("__init__", [](cufinufft_opts *self, double upsampfac, int gpu_method, bool gpu_sort,
+                          int gpu_binsizex, int gpu_binsizey, int gpu_binsizez, int gpu_obinsizex,
+                          int gpu_obinsizey, int gpu_obinsizez, int gpu_maxsubprobsize,
+                          bool gpu_kerevalmeth, int gpu_spreadinterponly, int gpu_maxbatchsize) {
+    new (self) cufinufft_opts;
+    default_opts<double>(self);
+
+    self->upsampfac = upsampfac;
+    self->gpu_method = gpu_method;
+    self->gpu_sort = int(gpu_sort);
+    self->gpu_binsizex = gpu_binsizex;
+    self->gpu_binsizey = gpu_binsizey;
+    self->gpu_binsizez = gpu_binsizez;
+    self->gpu_obinsizex = gpu_obinsizex;
+    self->gpu_obinsizey = gpu_obinsizey;
+    self->gpu_obinsizez = gpu_obinsizez;
+    self->gpu_maxsubprobsize = gpu_maxsubprobsize;
+    self->gpu_kerevalmeth = gpu_kerevalmeth;
+    self->gpu_spreadinterponly = gpu_spreadinterponly;
+    self->gpu_maxbatchsize = gpu_maxbatchsize;
+  });
}

} // namespace
lib/kernel_helpers.h (0 additions, 5 deletions)
@@ -28,11 +28,6 @@ bit_cast(const From& src) noexcept {
return dst;
}

-template <typename T>
-std::string pack_descriptor_as_string(const T& descriptor) {
-  return std::string(bit_cast<const char*>(&descriptor), sizeof(T));
-}

template <typename T>
const T* unpack_descriptor(const char* opaque, std::size_t opaque_len) {
if (opaque_len != sizeof(T)) {
lib/nanobind_kernel_helpers.h (28 additions, 0 deletions)
@@ -0,0 +1,28 @@
+// This header extends kernel_helpers.h with the nanobind specific interface to
+// serializing descriptors. It also adds a nanobind function for wrapping our
+// custom calls in a Python capsule. This is separate from kernel_helpers so that
+// the CUDA code itself doesn't include nanobind. I don't think that this is
+// strictly necessary, but they do it in jaxlib, so let's do it here too.
+
+#ifndef _JAX_FINUFFT_NANOBIND_KERNEL_HELPERS_H_
+#define _JAX_FINUFFT_NANOBIND_KERNEL_HELPERS_H_
+
+#include <nanobind/nanobind.h>
+
+#include "kernel_helpers.h"
+
+namespace jax_finufft {
+
+template <typename T>
+nanobind::bytes pack_descriptor(const T& descriptor) {
+  return nanobind::bytes(bit_cast<const char*>(&descriptor), sizeof(T));
+}
+
+template <typename T>
+nanobind::capsule encapsulate_function(T* fn) {
+  return nanobind::capsule(bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
+}
+
+}  // namespace jax_finufft
+
+#endif
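For context, here is a sketch of how these two helpers are meant to be used together, modeled on the `Registrations()` / `build_descriptor` pattern in the binding files above. The `Descriptor` struct, kernel, and module name are hypothetical stand-ins:

```cpp
#include <cstdint>

#include <nanobind/nanobind.h>

#include "nanobind_kernel_helpers.h"  // pack_descriptor, encapsulate_function

namespace nb = nanobind;

namespace {

// Hypothetical descriptor; the real ones live in the project's headers.
struct Descriptor {
  double eps;
  std::int64_t n;
};

// Hypothetical custom call with the CPU signature used in this project.
void dummy_kernel(void * /* out */, void ** /* in */) {}

}  // namespace

NB_MODULE(helpers_sketch, m) {
  // Python receives the descriptor as opaque bytes and threads it through
  // XLA back to the kernel, which recovers it with unpack_descriptor.
  m.def("build_descriptor", [](double eps, std::int64_t n) {
    return jax_finufft::pack_descriptor(Descriptor{eps, n});
  });

  // The capsule name "xla._CUSTOM_CALL_TARGET" is what XLA's custom-call
  // machinery expects when it looks up the target.
  m.def("registrations", []() {
    nb::dict d;
    d["dummy_kernel"] = jax_finufft::encapsulate_function(dummy_kernel);
    return d;
  });
}
```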
lib/pybind11_kernel_helpers.h (0 additions, 28 deletions)

This file was deleted.

pyproject.toml (1 addition, 1 deletion)
@@ -1,5 +1,5 @@
[build-system]
requires = ["pybind11>=2.6", "scikit-build-core>=0.5"]
requires = ["nanobind", "scikit-build-core>=0.5"]
build-backend = "scikit_build_core.build"

[project]