Skip to content

Commit

Permalink
[libshortfin] Implement invocation. (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident authored Sep 3, 2024
1 parent 89cc4c5 commit 08c69aa
Show file tree
Hide file tree
Showing 31 changed files with 1,505 additions and 412 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_linux_x64-libshortfin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ jobs:
- name: Install Python packages
# TODO: Switch to `pip install -r requirements.txt -e libshortfin/`.
run: |
pip install nanobind
pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt
pip freeze
- name: Build libshortfin (full)
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci_linux_x64_asan-libshortfin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ jobs:
run: |
eval "$(pyenv init -)"
pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt
pip freeze
- name: Save Python dependencies cache
if: steps.cache-python-deps-restore.outputs.cache-hit != 'true'
Expand Down
126 changes: 84 additions & 42 deletions libshortfin/bindings/python/array_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@ using namespace shortfin::array;
namespace shortfin::python {

namespace {
static const char DOCSTRING_ARRAY_COPY_FROM[] =
R"(Copy contents from a source array to this array.
Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
)";

static const char DOCSTRING_ARRAY_COPY_TO[] =
R"(Copy contents this array to a destination array.
Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
)";

static const char DOCSTRING_ARRAY_FILL[] = R"(Fill an array with a value.
Equivalent to `array.storage.fill(pattern)`.
)";

static const char DOCSTRING_STORAGE_DATA[] = R"(Access raw binary contents.
Accessing `foo = storage.data` is equivalent to `storage.data.map(read=True)`.
Expand All @@ -28,6 +45,23 @@ As with `map`, this will only work on buffers that are host visible, which
includes all host buffers and device buffers created with the necessary access.
)";

static const char DOCSTRING_STORAGE_COPY_FROM[] =
R"(Copy contents from a source storage to this array.
This operation executes asynchronously and the effect will only be visible
once the execution scope has been synced to the point of mutation.
)";

static const char DOCSTRING_STORAGE_FILL[] = R"(Fill a storage with a value.
Takes as argument any value that can be interpreted as a buffer with the Python
buffer protocol of size 1, 2, or 4 bytes. The storage will be filled uniformly
with the pattern.
This operation executes asynchronously and the effect will only be visible
once the execution scope has been synced to the point of mutation.
)";

static const char DOCSTRING_STORAGE_MAP[] =
R"(Create a mapping of the buffer contents in host memory.
Expand Down Expand Up @@ -72,58 +106,47 @@ void BindArray(py::module_ &m) {
.def(py::self == py::self)
.def("__repr__", &DType::name);

m.attr("opaque8") = DType::opaque8();
m.attr("opaque16") = DType::opaque16();
m.attr("opaque32") = DType::opaque32();
m.attr("opaque64") = DType::opaque64();
m.attr("bool8") = DType::bool8();
m.attr("int4") = DType::int4();
m.attr("sint4") = DType::sint4();
m.attr("uint4") = DType::uint4();
m.attr("int8") = DType::int8();
m.attr("sint8") = DType::sint8();
m.attr("uint8") = DType::uint8();
m.attr("int16") = DType::int16();
m.attr("sint16") = DType::sint16();
m.attr("uint16") = DType::uint16();
m.attr("int32") = DType::int32();
m.attr("sint32") = DType::sint32();
m.attr("uint32") = DType::uint32();
m.attr("int64") = DType::int64();
m.attr("sint64") = DType::sint64();
m.attr("uint64") = DType::uint64();
m.attr("float16") = DType::float16();
m.attr("float32") = DType::float32();
m.attr("float64") = DType::float64();
m.attr("bfloat16") = DType::bfloat16();
m.attr("complex64") = DType::complex64();
m.attr("complex128") = DType::complex128();
#define SHORTFIN_DTYPE_HANDLE(et, ident) m.attr(#ident) = DType::ident();
#include "shortfin/array/dtypes.inl"
#undef SHORTFIN_DTYPE_HANDLE

// storage
py::class_<storage>(m, "storage")
.def("__sfinv_marshal__",
[](device_array *self, py::capsule inv_capsule, int barrier) {
auto *inv =
static_cast<local::ProgramInvocation *>(inv_capsule.data());
static_cast<local::ProgramInvocationMarshalable *>(self)
->AddAsInvocationArgument(
inv, static_cast<local::ProgramResourceBarrier>(barrier));
})
.def_static(
"allocate_host",
[](local::ScopedDevice &device, iree_device_size_t allocation_size) {
return storage::AllocateHost(device, allocation_size);
return storage::allocate_host(device, allocation_size);
},
py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
.def_static(
"allocate_device",
[](local::ScopedDevice &device, iree_device_size_t allocation_size) {
return storage::AllocateDevice(device, allocation_size);
return storage::allocate_device(device, allocation_size);
},
py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
.def("fill",
[](storage &self, py::handle buffer) {
Py_buffer py_view;
int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
throw py::python_error();
}
PyBufferReleaser py_view_releaser(py_view);
self.Fill(py_view.buf, py_view.len);
})
.def("copy_from", [](storage &self, storage &src) { self.CopyFrom(src); })
.def(
"fill",
[](storage &self, py::handle buffer) {
Py_buffer py_view;
int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
throw py::python_error();
}
PyBufferReleaser py_view_releaser(py_view);
self.fill(py_view.buf, py_view.len);
},
py::arg("pattern"), DOCSTRING_STORAGE_FILL)
.def(
"copy_from", [](storage &self, storage &src) { self.copy_from(src); },
py::arg("source_storage"), DOCSTRING_STORAGE_COPY_FROM)
.def(
"map",
[](storage &self, bool read, bool write, bool discard) {
Expand All @@ -137,7 +160,7 @@ void BindArray(py::module_ &m) {
}
mapping *cpp_mapping = nullptr;
py::object py_mapping = CreateMappingObject(&cpp_mapping);
self.MapExplicit(
self.map_explicit(
*cpp_mapping,
static_cast<iree_hal_memory_access_bits_t>(access));
return py_mapping;
Expand All @@ -154,12 +177,12 @@ void BindArray(py::module_ &m) {
[](storage &self) {
mapping *cpp_mapping = nullptr;
py::object py_mapping = CreateMappingObject(&cpp_mapping);
*cpp_mapping = self.MapRead();
*cpp_mapping = self.map_read();
return py_mapping;
},
[](storage &self, py::handle buffer_obj) {
PyBufferRequest src_info(buffer_obj, PyBUF_SIMPLE);
auto dest_data = self.MapWriteDiscard();
auto dest_data = self.map_write_discard();
if (src_info.view().len > dest_data.size()) {
throw std::invalid_argument(
fmt::format("Cannot write {} bytes into buffer of {} bytes",
Expand Down Expand Up @@ -219,6 +242,14 @@ void BindArray(py::module_ &m) {
py_type, /*keep_alive=*/device.scope(),
device_array::for_device(device, shape, dtype));
})
.def("__sfinv_marshal__",
[](device_array *self, py::capsule inv_capsule, int barrier) {
auto *inv =
static_cast<local::ProgramInvocation *>(inv_capsule.data());
static_cast<local::ProgramInvocationMarshalable *>(self)
->AddAsInvocationArgument(
inv, static_cast<local::ProgramResourceBarrier>(barrier));
})
.def_static("for_device",
[](local::ScopedDevice &device, std::span<const size_t> shape,
DType dtype) {
Expand All @@ -243,6 +274,17 @@ void BindArray(py::module_ &m) {
py::rv_policy::reference_internal)
.def_prop_ro("storage", &device_array::storage,
py::rv_policy::reference_internal)

.def(
"fill",
[](py::handle_t<device_array> self, py::handle buffer) {
self.attr("storage").attr("fill")(buffer);
},
py::arg("pattern"), DOCSTRING_ARRAY_FILL)
.def("copy_from", &device_array::copy_from, py::arg("source_array"),
DOCSTRING_ARRAY_COPY_FROM)
.def("copy_to", &device_array::copy_to, py::arg("dest_array"),
DOCSTRING_ARRAY_COPY_TO)
.def("__repr__", &device_array::to_s);
}

Expand Down
Loading

0 comments on commit 08c69aa

Please sign in to comment.