Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[libshortfin] Add proper unit tests for device_array, storage, and mapping. #164

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions libshortfin/bindings/python/array_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ void BindArray(py::module_ &m) {
src_info.view().len);
},
DOCSTRING_STORAGE_DATA)
.def(py::self == py::self)
.def("__repr__", &storage::to_s);

// mapping
Expand All @@ -209,7 +210,7 @@ void BindArray(py::module_ &m) {
int operator()(mapping &self, Py_buffer *view, int flags) {
view->buf = self.data();
view->len = self.size();
view->readonly = self.writable();
view->readonly = !self.writable();
view->itemsize = 1;
view->format = (char *)"B"; // Byte
view->ndim = 1;
Expand Down Expand Up @@ -285,7 +286,12 @@ void BindArray(py::module_ &m) {
DOCSTRING_ARRAY_COPY_FROM)
.def("copy_to", &device_array::copy_to, py::arg("dest_array"),
DOCSTRING_ARRAY_COPY_TO)
.def("__repr__", &device_array::to_s);
.def("__repr__", &device_array::to_s)
.def("__str__", [](device_array &self) -> std::string {
auto contents = self.contents_to_s();
if (!contents) return "<<unmappable>>";
return *contents;
});
}

} // namespace shortfin::python
41 changes: 40 additions & 1 deletion libshortfin/bindings/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,13 @@ py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,

local::Worker &worker = self.init_worker();
py::object result = py::none();
py::object py_exception = py::none();
auto done_callback = [&](py::handle future) {
worker.Kill();
result = future.attr("result")();
py_exception = future.attr("exception")();
if (py_exception.is_none()) {
result = future.attr("result")();
}
};
worker.CallThreadsafe([&]() {
// Run within the worker we are about to donate to.
Expand Down Expand Up @@ -293,12 +297,47 @@ py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,
}

self.Shutdown();

if (!py_exception.is_none()) {
// We got this exception from a future/user code, which could have done
// something nefarious. So type check it.
if (PyObject_IsInstance(py_exception.ptr(), PyExc_Exception)) {
PyErr_SetObject(py_exception.type().ptr(), py_exception.ptr());
} else {
PyErr_SetObject(PyExc_RuntimeError, py_exception.ptr());
}
throw py::python_error();
}
return result;
}

} // namespace

NB_MODULE(lib, m) {
py::register_exception_translator(
[](const std::exception_ptr &p, void * /*unused*/) {
try {
std::rethrow_exception(p);
} catch (shortfin::iree::error &e) {
PyObject *exc_type;
switch (e.code()) {
case IREE_STATUS_INVALID_ARGUMENT:
case IREE_STATUS_OUT_OF_RANGE:
exc_type = PyExc_ValueError;
break;
case IREE_STATUS_FAILED_PRECONDITION:
exc_type = PyExc_AssertionError;
break;
case IREE_STATUS_UNIMPLEMENTED:
exc_type = PyExc_NotImplementedError;
break;
default:
exc_type = PyExc_RuntimeError;
}
PyErr_SetString(PyExc_ValueError, e.what());
}
});

py::class_<iree::vm_opaque_ref>(m, "_OpaqueVmRef");
auto local_m = m.def_submodule("local");
BindLocal(local_m);
Expand Down
11 changes: 11 additions & 0 deletions libshortfin/src/shortfin/array/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ template class InlinedDims<iree_hal_dim_t>;
// device_array
// -------------------------------------------------------------------------- //

device_array::device_array(class storage storage,
std::span<const Dims::value_type> shape, DType dtype)
: base_array(shape, dtype), storage_(std::move(storage)) {
auto needed_size = this->dtype().compute_dense_nd_size(this->shape());
if (storage_.byte_length() < needed_size) {
throw std::invalid_argument(
fmt::format("Array storage requires at least {} bytes but has only {}",
needed_size, storage_.byte_length()));
}
}

const mapping device_array::data() const { return storage_.map_read(); }

mapping device_array::data() { return storage_.map_read(); }
Expand Down
3 changes: 1 addition & 2 deletions libshortfin/src/shortfin/array/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ class SHORTFIN_API device_array
public local::ProgramInvocationMarshalable {
public:
device_array(class storage storage, std::span<const Dims::value_type> shape,
DType dtype)
: base_array(shape, dtype), storage_(std::move(storage)) {}
DType dtype);

class storage &storage() { return storage_; }
local::ScopedDevice &device() { return storage_.device(); }
Expand Down
4 changes: 4 additions & 0 deletions libshortfin/src/shortfin/array/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ class SHORTFIN_API storage : public local::ProgramInvocationMarshalable {

std::string to_s() const;

bool operator==(const storage &other) const {
return other.buffer_.get() == buffer_.get();
}

// Access raw buffer. This must not be retained apart from the storage for
// any length of time that may extend its lifetime (as the storage keeps
// underlying device references alive as needed).
Expand Down
4 changes: 3 additions & 1 deletion libshortfin/src/shortfin/support/iree_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ void SHORTFIN_API LogLiveRefs() {
} // namespace detail

error::error(std::string message, iree_status_t failing_status)
: message_(std::move(message)), failing_status_(failing_status) {
: code_(iree_status_code(failing_status)),
message_(std::move(message)),
failing_status_(failing_status) {
message_.append(": ");
}
error::error(iree_status_t failing_status) : failing_status_(failing_status) {}
Expand Down
3 changes: 3 additions & 0 deletions libshortfin/src/shortfin/support/iree_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,11 @@ class SHORTFIN_API error : public std::exception {
return message_.c_str();
};

iree_status_code_t code() const { return code_; }

private:
void AppendStatus() const noexcept;
iree_status_code_t code_;
mutable std::string message_;
mutable iree_status_t failing_status_;
mutable bool status_appended_ = false;
Expand Down
185 changes: 185 additions & 0 deletions libshortfin/tests/api/array_storage_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest

import shortfin as sf
import shortfin.array as sfnp


@pytest.fixture
def lsys():
sc = sf.host.CPUSystemBuilder()
lsys = sc.create_system()
yield lsys
lsys.shutdown()


@pytest.fixture
def scope(lsys):
return lsys.create_scope()


@pytest.fixture
def device(scope):
return scope.device(0)


def test_allocate_host(device):
s = sfnp.storage.allocate_host(device, 32)
assert len(bytes(s.data)) == 32


def test_allocate_device(device):
s = sfnp.storage.allocate_device(device, 64)
assert len(bytes(s.data)) == 64


def test_fill1(lsys, device):
async def main():
s = sfnp.storage.allocate_host(device, 8)
s.fill(b"0")
await device
assert bytes(s.data) == b"00000000"

lsys.run(main())


def test_fill2(lsys, device):
async def main():
s = sfnp.storage.allocate_host(device, 8)
s.fill(b"01")
await device
assert bytes(s.data) == b"01010101"

lsys.run(main())


def test_fill4(lsys, device):
async def main():
s = sfnp.storage.allocate_host(device, 8)
s.fill(b"0123")
await device
assert bytes(s.data) == b"01230123"

lsys.run(main())


def test_fill_error(device):
s = sfnp.storage.allocate_host(device, 8)
with pytest.raises(RuntimeError):
s.fill(b"")
with pytest.raises(RuntimeError):
s.fill(b"012")
with pytest.raises(RuntimeError):
s.fill(b"01234")
with pytest.raises(RuntimeError):
s.fill(b"01234567")


@pytest.mark.parametrize(
"pattern,size",
[
(b"", 8),
(b"012", 8),
(b"01234", 8),
(b"01234567", 8),
],
)
def test_fill_error(lsys, device, pattern, size):
async def main():
src = sfnp.storage.allocate_host(device, size)
src.fill(pattern)

with pytest.raises(
ValueError, match="fill value length is not one of the supported values"
):
lsys.run(main())


def test_map_read(lsys, device):
async def main():
src = sfnp.storage.allocate_host(device, 8)
src.fill(b"0123")
await device
with src.map(read=True) as m:
assert m.valid
assert bytes(m) == b"01230123"

lsys.run(main())


def test_map_read_not_writable(lsys, device):
async def main():
src = sfnp.storage.allocate_host(device, 8)
src.fill(b"0123")
await device
with src.map(read=True) as m:
mv = memoryview(m)
assert mv.readonly
mv[0] = ord(b"9")

with pytest.raises(TypeError, match="cannot modify"):
lsys.run(main())


def test_map_write(lsys, device):
async def main():
src = sfnp.storage.allocate_host(device, 8)
src.fill(b"0123")
await device
with src.map(read=True, write=True) as m:
mv = memoryview(m)
assert not mv.readonly
mv[0] = ord(b"9")
assert bytes(src.data) == b"91230123"

lsys.run(main())


def test_map_discard(lsys, device):
async def main():
src = sfnp.storage.allocate_host(device, 8)
src.fill(b"0123")
await device
with src.map(write=True, discard=True) as m:
mv = memoryview(m)
assert not mv.readonly
for i in range(8):
mv[i] = ord(b"9") - i
assert bytes(src.data) == b"98765432"

lsys.run(main())


def test_data_write(lsys, device):
async def main():
src = sfnp.storage.allocate_host(device, 8)
src.data = b"98765432"
assert bytes(src.data) == b"98765432"

lsys.run(main())


def test_mapping_explicit_close(lsys, device):
async def main():
src = sfnp.storage.allocate_host(device, 8)
m = src.map(write=True, discard=True)
assert m.valid
m.close()
assert not m.valid

lsys.run(main())


def test_mapping_context_manager(lsys, device):
async def main():
src = sfnp.storage.allocate_host(device, 8)
with src.map(write=True, discard=True) as m:
assert m.valid
assert not m.valid

lsys.run(main())
Loading
Loading