From 822d16932c2b42da9deb07fffff8c5ad1fc4d42a Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 2 May 2024 21:12:20 +0000 Subject: [PATCH 01/19] added the test and the pybinding. --- test/test_operations.py | 62 +++++++++++++++++++++++++ torch_xla/csrc/init_python_bindings.cpp | 10 ++++ torch_xla/utils/dlpack.py | 8 ++++ 3 files changed, 80 insertions(+) create mode 100644 torch_xla/utils/dlpack.py diff --git a/test/test_operations.py b/test/test_operations.py index ed8f5a88151..ea1caa389e7 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2448,6 +2448,68 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_3, 0) +class TestDLPack(test_utils.XlaTestCase): + + # TODO(xw32): need to test different data type such as pytorch/test/test_dlpack.py + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_capsule_conversion(self): + # TODO(xw32): make sure to test the storage is tested. + t1 = torch.arange(5, device=xm.xla_device()) + xm.mark_step() + got1 = from_dlpack(to_dlpack(t1)) + self.assertEqual(t1.cpu(), got1.cpu()) + + t2 = torch.arange(5, device=xm.xla_device()) + got2 = from_dlpack(to_dlpack(t2)) + self.assertEqual(t2.cpu(), got2.cpu()) + + t3 = torch.tensor(5, device=xm.xla_device()) + got3 = from_dlpack(to_dlpack(t3)) + self.assertEqual(t3.cpu(), got3.cpu()) + + # TODO(xw32): figure it out what it is testing. + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_protocol_conversion(self): + t1 = torch.arange(5, device=xm.xla_device()) + xm.mark_step() + got1 = from_dlpack(t1) + self.assertEqual(t1.cpu(), got1.cpu()) + + t2 = torch.arange(5, device=xm.xla_device()) + got2 = from_dlpack(t2) + self.assertEqual(t2.cpu(), got2.cpu()) + + t3 = torch.tensor(5, device=xm.xla_device()) + got3 = from_dlpack(t3) + self.assertEqual(t3.cpu(), got3.cpu()) + + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_cuda_to_xla_shared_storage(self): + t1 = torch.arange(5).cuda() + dlt1 = torch.utils.dlpack.to_dlpack(t1) + xla_t1 = from_dlpack(dlt1) + t1[0] = t1[0] + 20 + self.assertEqual(t1, xla_t1.cpu()) + + t2 = torch.tensor(5).cuda() + dlt2 = torch.utils.dlpack.to_dlpack(t2) + xla_t2 = from_dlpack(dlt2) + t2.fill_(6) + self.assertEqual(t2, xla_t2.cpu()) + + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_xla_to_cuda_shared_storage(self): + xla_t1 = torch.arange(5, device=xm.xla_device()) + dlt1 = to_dlpack(xla_t1) + cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) + cuda_t1[0] = cuda_t1[0] + 20 + self.assertEqual(xla_t1.cpu(), cuda_t1.cpu()) + + class SimpleModelWithDropout(torch.nn.Module): def __init__(self): diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index bfb9f2bd3f2..11d9c7206da 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2508,6 +2508,16 @@ void InitXlaModuleBindings(py::module m) { "without a data handle or an IR."; }); + // from an XLA tensor to a dlpack tensor. 
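+  // (Sketch of the eventual contract, per later commits in this series: the
+  // binding hands back a PyCapsule named "dltensor"; this first commit only
+  // registers stubs.)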
+ m.def("_to_dlpack", [](const at::Tensor& input) -> PyObject* { + return nullptr; + }); + // from a dlpack tensor to an XLA tensor + m.def("_from_dlpack", [](PyObject* ext_data) -> at::Tensor { + + }); + + // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py new file mode 100644 index 00000000000..9ae99b8f802 --- /dev/null +++ b/torch_xla/utils/dlpack.py @@ -0,0 +1,8 @@ +from typing import Any +import torch_xla + +def to_dlpack(xla_tensor: Any): + return torch_xla._XLAC._to_dlpack(xla_tensor) + +def from_dlpack(ext_tensor: Any): + return torch_xla._XLAC._from_dlpack(ext_tensor) From 70996c41bd6ec02e597ffc7320d2ce6bc183bce7 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 2 May 2024 21:56:00 +0000 Subject: [PATCH 02/19] Add toDLPack and it compiles. --- torch_xla/csrc/BUILD | 2 + torch_xla/csrc/dl_convertor.cpp | 167 ++++++++++++++++++ torch_xla/csrc/dl_convertor.h | 13 ++ torch_xla/csrc/init_python_bindings.cpp | 14 +- torch_xla/csrc/runtime/computation_client.h | 3 + .../csrc/runtime/ifrt_computation_client.cc | 4 + .../csrc/runtime/ifrt_computation_client.h | 2 + .../csrc/runtime/pjrt_computation_client.cc | 7 + .../csrc/runtime/pjrt_computation_client.h | 2 + 9 files changed, 213 insertions(+), 1 deletion(-) create mode 100644 torch_xla/csrc/dl_convertor.cpp create mode 100644 torch_xla/csrc/dl_convertor.h diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 2faf483f067..a2aadc0c633 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -42,6 +42,7 @@ ptxla_cc_library( "cross_replica_reduces.cpp", "data_ops.cpp", "debug_util.cpp", + "dl_convertor.cpp", "elementwise.cpp", "helpers.cpp", "ir_dump_util.cpp", @@ -81,6 +82,7 @@ ptxla_cc_library( "cross_replica_reduces.h", "data_ops.h", "debug_util.h", + "dl_convertor.h", "elementwise.h", "generated_file_include.h", "helpers.h", diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp new file mode 100644 index 00000000000..8a0067820a6 --- /dev/null +++ b/torch_xla/csrc/dl_convertor.cpp @@ -0,0 +1,167 @@ +#include "torch_xla/csrc/dl_convertor.h" + +#include "absl/types/span.h" + +#include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/ops/device_data.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/unwrap_data.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/status.h" + +namespace torch_xla { +namespace { + +std::shared_ptr get_data_handle(const at::Tensor& input) { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLA_CHECK(xtensor) << "The input has to be an XLA tensor."; + if (xtensor->CurrentDataHandle() != nullptr) { + return std::dynamic_pointer_cast(xtensor->CurrentDataHandle()); + } else if (xtensor->CurrentIrValue().node != nullptr) { + DeviceData* device_data = + DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (device_data != nullptr) { + return UnwrapXlaData(device_data->data()); + } + TF_VLOG(4) << "The xla tensor has IR value but does not have device data."; + } + return nullptr; +} + +struct TorchXLADLMTensor { + std::unique_ptr external_reference; + + std::vector shape; + std::vector strides; + 
+  DLManagedTensor tensor;
+};
+
+void TorchXLADLMTensorDeleter(DLManagedTensor* t) {
+  if (t) {
+    delete static_cast<TorchXLADLMTensor*>(t->manager_ctx);
+  }
+}
+
+DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) {
+  if (device.client()->platform_id() == xla::CpuId()) {
+    return DLDeviceType::kDLCPU;
+  } else if (device.client()->platform_id() == xla::CudaId()) {
+    return DLDeviceType::kDLCUDA;
+  } else if (device.client()->platform_id() == xla::RocmId()) {
+    return DLDeviceType::kDLROCM;
+  }
+  XLA_ERROR() << "Device " << device.DebugString()
+              << " cannot be used as a DLPack device.";
+}
+
+DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) {
+  DLDevice dlDevice;
+  dlDevice.device_type = DLDeviceTypeForDevice(device);
+  dlDevice.device_id = device.local_hardware_id();
+  return dlDevice;
+}
+
+DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) {
+  switch (type) {
+    case xla::PrimitiveType::S8:
+      return DLDataType{kDLInt, 8, 1};
+    case xla::PrimitiveType::S16:
+      return DLDataType{kDLInt, 16, 1};
+    case xla::PrimitiveType::S32:
+      return DLDataType{kDLInt, 32, 1};
+    case xla::PrimitiveType::S64:
+      return DLDataType{kDLInt, 64, 1};
+    case xla::PrimitiveType::U8:
+      return DLDataType{kDLUInt, 8, 1};
+    case xla::PrimitiveType::U16:
+      return DLDataType{kDLUInt, 16, 1};
+    case xla::PrimitiveType::U32:
+      return DLDataType{kDLUInt, 32, 1};
+    case xla::PrimitiveType::U64:
+      return DLDataType{kDLUInt, 64, 1};
+    case xla::PrimitiveType::F16:
+      return DLDataType{kDLFloat, 16, 1};
+    case xla::PrimitiveType::F32:
+      return DLDataType{kDLFloat, 32, 1};
+    case xla::PrimitiveType::F64:
+      return DLDataType{kDLFloat, 64, 1};
+    case xla::PrimitiveType::BF16:
+      return DLDataType{kDLBfloat, 16, 1};
+    case xla::PrimitiveType::PRED:
+      return DLDataType{kDLBool, 8, 1};
+    case xla::PrimitiveType::C64:
+      return DLDataType{kDLComplex, 64, 1};
+    case xla::PrimitiveType::C128:
+      return DLDataType{kDLComplex, 128, 1};
+    default:
+      XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type)
+                  << " has no DLPack equivalent";
+  }
+}
+
+std::vector<int64_t> StridesForShape(xla::PrimitiveType element_type,
+                                     absl::Span<const int64_t> dimensions,
+                                     const xla::Layout& layout) {
+  XLA_CHECK_EQ(dimensions.size(), layout.minor_to_major().size());
+  std::vector<int64_t> strides;
+  strides.resize(dimensions.size());
+  int64_t stride = 1;
+  for (int i : layout.minor_to_major()) {
+    strides[i] = stride;
+    stride *= dimensions[i];
+  }
+  return strides;
+}
+
+// Convert an XLA tensor to dlPack tensor.
+DLManagedTensor* toDLPack(const at::Tensor& input) {
+  std::shared_ptr<runtime::ComputationClient::Data> handle = get_data_handle(input);
+  XLA_CHECK(handle) << "Could not extract a valid data handle from the input tensor";
+
+  // std::shared_ptr<runtime::PjRtComputationClient::PjRtData> pjrt_data = std::dynamic_pointer_cast<runtime::PjRtComputationClient::PjRtData>(data);
+  // xla::PjRtBuffer* pjrt_buffer = pjrt_data->buffer.get();
+  xla::PjRtBuffer* pjrt_buffer = runtime::GetComputationClient()->GetPjRtBuffer(handle).get();
+
+  if (pjrt_buffer->IsTuple()) {
+    XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not implemented for tuple buffers.";
+  }
+  if (pjrt_buffer->has_dynamic_dimensions()) {
+    XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
+  }
+
+  auto torchXlaDLMTensor = std::make_unique<TorchXLADLMTensor>();
+  DLTensor& dt = torchXlaDLMTensor->tensor.dl_tensor;
+  {
+    auto external_ref = pjrt_buffer->AcquireExternalReference();
+    XLA_CHECK_OK(external_ref.status());
+    torchXlaDLMTensor->external_reference = std::move(external_ref.value());
+    xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture();
+    absl::Status status = future.Await();
+    XLA_CHECK_OK(status);
+  }
+  // pack->buffer_reference = nb::borrow<nb::object>(py_buffer); // xw32: should we do it?
+
+  dt.data = torchXlaDLMTensor->external_reference->OpaqueDeviceMemoryDataPointer();
+  torchXlaDLMTensor->tensor.manager_ctx = torchXlaDLMTensor.get();
+  torchXlaDLMTensor->tensor.deleter = TorchXLADLMTensorDeleter;
+  dt.device = DLDeviceForDevice(*pjrt_buffer->device());
+  dt.device.device_id = pjrt_buffer->device()->local_hardware_id();
+  dt.ndim = pjrt_buffer->dimensions().size();
+  dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type());
+
+  torchXlaDLMTensor->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end());
+  xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout());
+  torchXlaDLMTensor->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout);
+  dt.shape = reinterpret_cast<std::int64_t*>(torchXlaDLMTensor->shape.data());
+  dt.strides = reinterpret_cast<std::int64_t*>(torchXlaDLMTensor->strides.data());
+  dt.byte_offset = 0;
+
+  return &(torchXlaDLMTensor.release()->tensor);
+}
+
+}  // xw32: why do we need the extra namespace?
+}
diff --git a/torch_xla/csrc/dl_convertor.h b/torch_xla/csrc/dl_convertor.h
new file mode 100644
index 00000000000..32c443231ff
--- /dev/null
+++ b/torch_xla/csrc/dl_convertor.h
@@ -0,0 +1,13 @@
+#ifndef XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
+#define XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
+
+#include <ATen/ATen.h>
+#include <ATen/dlpack.h>
+
+namespace torch_xla {
+
+DLManagedTensor* toDLPack(const at::Tensor& src);
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 11d9c7206da..4c8def6d4c0 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1,6 +1,7 @@
 #include
 #include
 #include
+#include <ATen/dlpack.h>
 #include
 #include
 #include
@@ -35,6 +36,7 @@
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/dtype.h"
+#include "torch_xla/csrc/dl_convertor.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/ir_dump_util.h"
@@ -1098,6 +1100,15 @@ void BuildLoweringContextSubmodule(py::module* m) {
       .def("get_name_string", &PyLoweringContext::GetNameString);
 }
 
+void dlPack_Capsule_Destructor(PyObject* data) {
+  if (!PyCapsule_IsValid(data, "dltensor")) {
+    return;
+  }
+  DLManagedTensor* dlMTensor =
+      (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
+  dlMTensor->deleter(dlMTensor);
+}
+
 void InitXlaModuleBindings(py::module m) {
   m.def("_prepare_to_exit", []() { PrepareToExit(); });
   m.def("_xla_runtime_is_initialized", []() {
@@ -2510,7 +2521,8 @@ void InitXlaModuleBindings(py::module m) {
 
   // from an XLA tensor to a dlpack tensor.
m.def("_to_dlpack", [](const at::Tensor& input) -> PyObject* { - return nullptr; + DLManagedTensor* dlMTensor = torch_xla::toDLPack(input); + return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor); }); // from a dlpack tensor to an XLA tensor m.def("_from_dlpack", [](PyObject* ext_data) -> at::Tensor { diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index cc58736e8dc..d29ba0567f8 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -25,6 +25,7 @@ #include "torch_xla/csrc/runtime/types.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" +#include "xla/pjrt/pjrt_client.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal_util.h" #include "xla/types.h" @@ -304,6 +305,8 @@ class ComputationClient { virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0; + virtual std::shared_ptr GetPjRtBuffer(const DataPtr handle) = 0; + // Compiles a set of computations. virtual std::vector Compile( std::vector instances) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index 842398126d0..e059be41a08 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -402,6 +402,10 @@ std::uintptr_t IfrtComputationClient::UnsafeBufferPointer( XLA_ERROR() << __FUNCTION__ << " not implemented"; } +std::shared_ptr IfrtComputationClient::GetPjRtBuffer(const DataPtr handle) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + std::vector IfrtComputationClient::TransferFromDevice( absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 4c10be9d1ca..c5493f0fd0f 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -56,6 +56,8 @@ class IfrtComputationClient : public ComputationClient { std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; + std::shared_ptr GetPjRtBuffer(const DataPtr handle) override; + DataPtr TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 3d5fbaf1f8e..69854f58f5b 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -469,6 +469,13 @@ std::uintptr_t PjRtComputationClient::UnsafeBufferPointer( return ptr.value(); } +std::shared_ptr PjRtComputationClient::GetPjRtBuffer(const DataPtr handle) { + std::shared_ptr pjrt_data = + std::dynamic_pointer_cast(handle); + XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); + return pjrt_data->buffer; +} + std::vector PjRtComputationClient::TransferFromDevice( absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 350b1193ef7..fe04538052c 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -57,6 +57,8 @@ class PjRtComputationClient : public ComputationClient { std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; + std::shared_ptr 
GetPjRtBuffer(const DataPtr handle) override; + DataPtr TransferShardsToDevice( absl::Span> tensor_shards, std::string device, xla::Shape shape, xla::OpSharding sharding) override; From 09ff265d22aba0c39b03df4c25caa0d6e3d6dfa6 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Fri, 3 May 2024 22:56:19 +0000 Subject: [PATCH 03/19] finished fromDLPack and it compiles --- torch_xla/csrc/dl_convertor.cpp | 206 ++++++++++++++++++ torch_xla/csrc/dl_convertor.h | 1 + torch_xla/csrc/init_python_bindings.cpp | 11 +- torch_xla/csrc/runtime/computation_client.h | 11 + .../csrc/runtime/ifrt_computation_client.h | 18 ++ .../csrc/runtime/pjrt_computation_client.cc | 6 + .../csrc/runtime/pjrt_computation_client.h | 20 +- 7 files changed, 268 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index 8a0067820a6..bc63e50c295 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -1,6 +1,7 @@ #include "torch_xla/csrc/dl_convertor.h" #include "absl/types/span.h" +#include #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/aten_xla_bridge.h" @@ -10,6 +11,7 @@ #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/unwrap_data.h" +#include "torch_xla/csrc/tensor_util.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" @@ -163,5 +165,209 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { return &(torchXlaDLMTensor.release()->tensor); } +absl::StatusOr DeviceForDLDevice(const DLDevice& context) { + switch (context.device_type) { + case DLDeviceType::kDLCPU: + // if (cpu_client == nullptr) { + // return InvalidArgument( + // "DLPack tensor is on CPU, but no CPU backend was provided."); + // } + XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CpuId()); + return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id); + case DLDeviceType::kDLCUDA: + // if (gpu_client == nullptr) { // xw32 TODO: check if client_ is GPU client + // return InvalidArgument( + // "DLPack tensor is on GPU, but no GPU backend was provided."); + // } + XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CudaId()); + return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id); + // case DLDeviceType::kDLROCM: + // // if (gpu_client == nullptr) { + // // return InvalidArgument( + // // "DLPack tensor is on GPU, but no GPU backend was provided."); + // // } + // XLA_CHECK_EQ(pjrt_client->platform_id(), xla::RocmId()); + // xla::PjRtDevice* device = pjrt_client->addressable_devices()[context.device_id]; + // return device; + default: + return tsl::errors::InvalidArgument("Unknown/unsupported DLPack device type %d", + context.device_type); + } +} + +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return tsl::errors::Unimplemented("DLPack types with lanes != 1 not implemented, got %d", + type.lanes); + } + switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return xla::PrimitiveType::PRED; + default: + return tsl::errors::Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return xla::PrimitiveType::S8; + case 16: + return xla::PrimitiveType::S16; + case 32: + return xla::PrimitiveType::S32; + case 64: + return xla::PrimitiveType::S64; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack 
integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return xla::PrimitiveType::U8; + case 16: + return xla::PrimitiveType::U16; + case 32: + return xla::PrimitiveType::U32; + case 64: + return xla::PrimitiveType::U64; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat: + switch (type.bits) { + case 16: + return xla::PrimitiveType::F16; + case 32: + return xla::PrimitiveType::F32; + case 64: + return xla::PrimitiveType::F64; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case kDLBfloat: + switch (type.bits) { + case 16: + return xla::PrimitiveType::BF16; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return xla::PrimitiveType::C64; + case 128: + return xla::PrimitiveType::C128; + default: + return tsl::errors::Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return tsl::errors::Unimplemented("Unknown or invalid DLPack type code %d", type.code); + } +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + XLA_CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return tsl::errors::Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +at::Tensor fromDLPack(DLManagedTensor* dlmt) { + if (dlmt->dl_tensor.ndim < 0) { + XLA_ERROR() << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim; + } + xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); // client_ is a xla::PjRtClient. So this fromDLPack should be inside pjrt_computation_client class. + absl::Span dimensions( + const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + xla::PrimitiveType element_type = DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value(); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + const_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + minor_to_major = StridesToLayout(dimensions, strides).value(); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + // Raise an error if the resulting PjRtBuffer would have a non-default layout. 
+ // TODO(skyewm): we do this because JAX doesn't currently have good support + // for non-default layouts, and will return wrong results if a non-default + // layout is passed to a computation expecting default layouts. Remove this + // special case when non-default layouts are better supported by JAX. + absl::StatusOr default_layout_from_client = + device->client()->GetDefaultLayout(element_type, dimensions); + xla::Layout default_layout; + if (default_layout_from_client.ok()) { + default_layout = *default_layout_from_client; + } else if (absl::IsUnimplemented(default_layout_from_client.status())) { + // TODO(skyewm): consider remove the fallback path when GetDefaultLayout is + // unimplemented. + xla::Shape host_shape = xla::ShapeUtil::MakeShape(element_type, dimensions); + default_layout = xla::LayoutUtil::GetWithDefaultLayout(host_shape).layout(); + } else { + XLA_ERROR() << "default_layout_from_client.status() is not ok."; + } + // if (shape.layout() != default_layout) { + // XLA_ERROR() << "from_dlpack got array with non-default layout with minor-to-major dimensions (" << absl::StrJoin(shape.layout().minor_to_major(), ",") << "), expected (" << absl::StrJoin(default_layout.minor_to_major(), ",") << ")"; + // } + + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + xla::StatusOr> pjrt_buffer = device->client()->CreateViewOfDeviceBuffer( + static_cast(dlmt->dl_tensor.data) + + dlmt->dl_tensor.byte_offset, + shape, device, on_delete_callback); + + runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value())); + + // xw32 note: XlaDataToTensors does a fromDeviceToHost transfer.XlaDataToTensors + at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype); + return XlaDataToTensors({data}, {tensor_type})[0]; + +} + } // xw32: why do we need the extra namespace? } diff --git a/torch_xla/csrc/dl_convertor.h b/torch_xla/csrc/dl_convertor.h index 32c443231ff..07d4587146a 100644 --- a/torch_xla/csrc/dl_convertor.h +++ b/torch_xla/csrc/dl_convertor.h @@ -7,6 +7,7 @@ namespace torch_xla { DLManagedTensor* toDLPack(const at::Tensor& src); +at::Tensor fromDLPack(DLManagedTensor* src); } // namespace torch_xla diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4c8def6d4c0..91fc9c26ae3 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1109,6 +1109,15 @@ void dlPack_Capsule_Destructor(PyObject* data) { dlMTensor->deleter(dlMTensor); } +at::Tensor tensor_fromDLPack(PyObject* data) { + DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. 
You may have already constructed a tensor from it once."; + + at::Tensor tensor = torch_xla::fromDLPack(dlMTensor); + PyCapsule_SetName(data, "used_dltensor"); + return tensor; +} + void InitXlaModuleBindings(py::module m) { m.def("_prepare_to_exit", []() { PrepareToExit(); }); m.def("_xla_runtime_is_initialized", []() { @@ -2526,7 +2535,7 @@ void InitXlaModuleBindings(py::module m) { }); // from a dlpack tensor to an XLA tensor m.def("_from_dlpack", [](PyObject* ext_data) -> at::Tensor { - + return tensor_fromDLPack(ext_data); }); diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index d29ba0567f8..6eca719c896 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -29,6 +29,7 @@ #include "xla/hlo/ir/hlo_module.h" #include "xla/literal_util.h" #include "xla/types.h" +#include "xla/pjrt/pjrt_common.h" namespace torch_xla { namespace runtime { @@ -259,6 +260,10 @@ class ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) = 0; + virtual DataPtr CreateData( + std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) = 0; + // Returns data shards. We expect this to be called on PjRtShardedData to // retrieve the shards. If other data type is passed, it returns the input // wrapped inside a vector. @@ -276,6 +281,8 @@ class ComputationClient { // structure will be empty if there is no sharding, like with PjRtData. virtual std::optional GetDataSharding(DataPtr handle) = 0; + virtual std::string PjRtDeviceToString(xla::PjRtDevice* const device) const = 0; + // Transfers local tensor values to the TPU devices and fetches the handles. virtual std::vector TransferToDevice( absl::Span> tensors) = 0; @@ -347,6 +354,10 @@ class ComputationClient { virtual torch_xla::DeviceType GetDeviceType() const = 0; + virtual xla::PjRtPlatformId GetPlatformID() const = 0; + + virtual absl::StatusOr LookupAddressableDevice(int local_device_id) const = 0; + virtual size_t GetNumDevices() const = 0; virtual std::vector GetLocalDevices() const = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index c5493f0fd0f..b2185842289 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -33,6 +33,12 @@ class IfrtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; + DataPtr CreateData( + std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + }; + std::vector GetDataShards(DataPtr data) override; DataPtr GetDataShard(DataPtr data, size_t index) override; @@ -86,6 +92,14 @@ class IfrtComputationClient : public ComputationClient { absl::AsciiStrToUpper(client_->platform_name())); }; + xla::PjRtPlatformId GetPlatformID() const override { + return client_->platform_id(); + } + + absl::StatusOr LookupAddressableDevice(int local_device_id) const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + std::vector GetLocalDevices() const override; std::vector GetAllDevices() const override; @@ -123,6 +137,10 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; }; + std::string PjRtDeviceToString(xla::PjRtDevice* const device) const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + std::string 
SerializeComputation(const ComputationPtr computation) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 69854f58f5b..e6d36a08df8 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -185,6 +185,12 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder( return std::make_shared(std::move(device), std::move(shape)); } +ComputationClient::DataPtr PjRtComputationClient::CreateData( + std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) { + return std::make_shared(std::move(device), std::move(shape)); +} + std::vector PjRtComputationClient::GetDataShards( ComputationClient::DataPtr data) { tsl::profiler::TraceMe activity("PjRtComputationClient::GetDataShards", diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index fe04538052c..1b1e0aa5c47 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -32,6 +32,10 @@ class PjRtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; + DataPtr CreateData( + std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) override; + std::vector GetDataShards(DataPtr data) override; DataPtr GetDataShard(DataPtr data, size_t index) override; @@ -91,6 +95,14 @@ class PjRtComputationClient : public ComputationClient { absl::AsciiStrToUpper(client_->platform_name())); }; + xla::PjRtPlatformId GetPlatformID() const override { + return client_->platform_id(); + } + + absl::StatusOr LookupAddressableDevice(int local_device_id) const override { + return client_->LookupAddressableDevice(xla::PjRtLocalDeviceId(local_device_id)); + } + std::vector GetLocalDevices() const override; std::vector GetAllDevices() const override; @@ -128,6 +140,10 @@ class PjRtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; }; + std::string PjRtDeviceToString(xla::PjRtDevice* const device) const override; + std::vector PjRtDevicesToString( + absl::Span devices) const; + private: std::unique_ptr client_; std::unique_ptr coordinator_; @@ -143,10 +159,6 @@ class PjRtComputationClient : public ComputationClient { xla::PjRtDevice* StringToPjRtDevice(const std::string& device); - std::string PjRtDeviceToString(xla::PjRtDevice* const device) const; - std::vector PjRtDevicesToString( - absl::Span devices) const; - struct PjRtData : public Data { PjRtData(std::string device, xla::Shape device_shape) : Data(std::move(device), std::move(device_shape)) {} From 1f0ed125fecd7afc8465d13a8d8f563fb094e7ea Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Sat, 4 May 2024 00:23:12 +0000 Subject: [PATCH 04/19] fixed an unrecognized symbol issue. 
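
Patch 01's tests called to_dlpack/from_dlpack without importing them; the
tests now go through the wrapper module added there, e.g. (sketch of the
pattern used in the diff below):

    import torch_xla.utils.dlpack as xdlpack
    got1 = xdlpack.from_dlpack(xdlpack.to_dlpack(t1))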
--- test/test_operations.py | 19 ++++++++++--------- torch_xla/csrc/dl_convertor.cpp | 2 -- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index ea1caa389e7..59fdd4d6868 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -40,6 +40,7 @@ import torch_xla.distributed.spmd as xs from torch_xla import runtime as xr import torch_xla.test.test_utils as xtu +import torch_xla.utils.dlpack as xdlpack import torch_xla.utils.utils as xu import torch_xla.utils.serialization as xser import torch_xla.core.xla_model as xm @@ -2457,15 +2458,15 @@ def test_dlpack_capsule_conversion(self): # TODO(xw32): make sure to test the storage is tested. t1 = torch.arange(5, device=xm.xla_device()) xm.mark_step() - got1 = from_dlpack(to_dlpack(t1)) + got1 = xdlpack.from_dlpack(xdlpack.to_dlpack(t1)) self.assertEqual(t1.cpu(), got1.cpu()) t2 = torch.arange(5, device=xm.xla_device()) - got2 = from_dlpack(to_dlpack(t2)) + got2 = xdlpack.from_dlpack(xdlpack.to_dlpack(t2)) self.assertEqual(t2.cpu(), got2.cpu()) t3 = torch.tensor(5, device=xm.xla_device()) - got3 = from_dlpack(to_dlpack(t3)) + got3 = xdlpack.from_dlpack(xdlpack.to_dlpack(t3)) self.assertEqual(t3.cpu(), got3.cpu()) # TODO(xw32): figure it out what it is testing. @@ -2474,15 +2475,15 @@ def test_dlpack_capsule_conversion(self): def test_dlpack_protocol_conversion(self): t1 = torch.arange(5, device=xm.xla_device()) xm.mark_step() - got1 = from_dlpack(t1) + got1 = xdlpack.from_dlpack(t1) self.assertEqual(t1.cpu(), got1.cpu()) t2 = torch.arange(5, device=xm.xla_device()) - got2 = from_dlpack(t2) + got2 = xdlpack.from_dlpack(t2) self.assertEqual(t2.cpu(), got2.cpu()) t3 = torch.tensor(5, device=xm.xla_device()) - got3 = from_dlpack(t3) + got3 = xdlpack.from_dlpack(t3) self.assertEqual(t3.cpu(), got3.cpu()) @onlyIfTorchSupportsCUDA @@ -2490,13 +2491,13 @@ def test_dlpack_protocol_conversion(self): def test_dlpack_cuda_to_xla_shared_storage(self): t1 = torch.arange(5).cuda() dlt1 = torch.utils.dlpack.to_dlpack(t1) - xla_t1 = from_dlpack(dlt1) + xla_t1 = xdlpack.from_dlpack(dlt1) t1[0] = t1[0] + 20 self.assertEqual(t1, xla_t1.cpu()) t2 = torch.tensor(5).cuda() dlt2 = torch.utils.dlpack.to_dlpack(t2) - xla_t2 = from_dlpack(dlt2) + xla_t2 = xdlpack.from_dlpack(dlt2) t2.fill_(6) self.assertEqual(t2, xla_t2.cpu()) @@ -2504,7 +2505,7 @@ def test_dlpack_cuda_to_xla_shared_storage(self): @onlyIfPJRTDeviceIsCUDA def test_dlpack_xla_to_cuda_shared_storage(self): xla_t1 = torch.arange(5, device=xm.xla_device()) - dlt1 = to_dlpack(xla_t1) + dlt1 = xdlpack.to_dlpack(xla_t1) cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) cuda_t1[0] = cuda_t1[0] + 20 self.assertEqual(xla_t1.cpu(), cuda_t1.cpu()) diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index bc63e50c295..d243504027d 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -18,7 +18,6 @@ #include "xla/status.h" namespace torch_xla { -namespace { std::shared_ptr get_data_handle(const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); @@ -369,5 +368,4 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { } -} // xw32: why do we need the extra namespace? } From b3516c2061f406c88df313d3dece9936870fa484 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Mon, 6 May 2024 20:45:21 +0000 Subject: [PATCH 05/19] Change test to use .to(xla_device). 
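
Presumably the distinction that matters here (per get_data_handle in
dl_convertor.cpp): torch.arange(5, device=xm.xla_device()) stays a lazy IR
value until xm.mark_step(), while .to(xm.xla_device()) uploads a CPU tensor,
so the XLA tensor carries a device-data handle to_dlpack can use directly:

    t = torch.arange(5, device=xm.xla_device())  # lazy; needs mark_step first
    t = torch.arange(5).to(xm.xla_device())      # has device data immediately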
--- test/test_operations.py | 12 +++++---- torch_xla/csrc/dl_convertor.cpp | 8 +++++- torch_xla/csrc/init_python_bindings.cpp | 25 ++++++++++++++++--- .../csrc/runtime/pjrt_computation_client.cc | 5 ++-- torch_xla/utils/dlpack.py | 4 ++- 5 files changed, 41 insertions(+), 13 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 59fdd4d6868..34277a3ffb3 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2456,12 +2456,14 @@ class TestDLPack(test_utils.XlaTestCase): @onlyIfPJRTDeviceIsCUDA def test_dlpack_capsule_conversion(self): # TODO(xw32): make sure to test the storage is tested. - t1 = torch.arange(5, device=xm.xla_device()) - xm.mark_step() - got1 = xdlpack.from_dlpack(xdlpack.to_dlpack(t1)) + t1 = torch.arange(5).to(xm.xla_device()) + dlpt1 = xdlpack.to_dlpack(t1) + print('xw32 finished the to_dlpack') + got1 = xdlpack.from_dlpack(dlpt1) self.assertEqual(t1.cpu(), got1.cpu()) + print('xw32 finished test case 1') - t2 = torch.arange(5, device=xm.xla_device()) + t2 = torch.arange(5).to(xm.xla_device()) got2 = xdlpack.from_dlpack(xdlpack.to_dlpack(t2)) self.assertEqual(t2.cpu(), got2.cpu()) @@ -2504,7 +2506,7 @@ def test_dlpack_cuda_to_xla_shared_storage(self): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_xla_to_cuda_shared_storage(self): - xla_t1 = torch.arange(5, device=xm.xla_device()) + xla_t1 = torch.arange(5).to(xm.xla_device()) dlt1 = xdlpack.to_dlpack(xla_t1) cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) cuda_t1[0] = cuda_t1[0] + 20 diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index d243504027d..fc62acd435f 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -23,6 +23,7 @@ std::shared_ptr get_data_handle(const at::Tens XLATensorPtr xtensor = bridge::GetXlaTensor(input); XLA_CHECK(xtensor) << "The input has to be an XLA tensor."; if (xtensor->CurrentDataHandle() != nullptr) { + TF_VLOG(4) << "The xla tensor has a current data handle."; return std::dynamic_pointer_cast(xtensor->CurrentDataHandle()); } else if (xtensor->CurrentIrValue().node != nullptr) { DeviceData* device_data = @@ -32,11 +33,13 @@ std::shared_ptr get_data_handle(const at::Tens } TF_VLOG(4) << "The xla tensor has IR value but does not have device data."; } + TF_VLOG(4) << "The xla tensor either has no current data handle or has no IR value."; return nullptr; } struct TorchXLADLMTensor { std::unique_ptr external_reference; + std::shared_ptr buffer_reference; std::vector shape; std::vector strides; @@ -125,7 +128,8 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { // std::shared_ptr pjrt_data = std::dynamic_pointer_cast(data); // xla::PjRtBuffer* pjrt_buffer = pjrt_data->buffer.get(); - xla::PjRtBuffer* pjrt_buffer = runtime::GetComputationClient()->GetPjRtBuffer(handle).get(); + std::shared_ptr pjrt_buffer = runtime::GetComputationClient()->GetPjRtBuffer(handle); + XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer"; if (pjrt_buffer->IsTuple()) { XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not implemented for tuple buffers."; @@ -144,6 +148,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { absl::Status status = future.Await(); XLA_CHECK_OK(status); } + torchXlaDLMTensor->buffer_reference = pjrt_buffer; // pack->buffer_reference = nb::borrow(py_buffer); // xw32: should we do it? 
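+  // (buffer_reference keeps the PjRtBuffer alive for as long as the DLPack
+  // capsule holds the external reference to its device memory.)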
dt.data = torchXlaDLMTensor->external_reference->OpaqueDeviceMemoryDataPointer(); @@ -161,6 +166,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { dt.strides = reinterpret_cast(torchXlaDLMTensor->strides.data()); dt.byte_offset = 0; + std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; return &(torchXlaDLMTensor.release()->tensor); } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 91fc9c26ae3..55947d98d0d 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1100,15 +1100,29 @@ void BuildLoweringContextSubmodule(py::module* m) { .def("get_name_string", &PyLoweringContext::GetNameString); } +// Used in the to_dlpack. void dlPack_Capsule_Destructor(PyObject* data) { + std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; if (!PyCapsule_IsValid(data, "dltensor")) { return; } + HANDLE_TH_ERRORS DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); - dlMTensor->deleter(dlMTensor); + if (dlMTensor) { + dlMTensor->deleter(dlMTensor); + } else { + PyErr_Clear(); + } + END_HANDLE_TH_ERRORS_RET() } +// PyObject* tensor_toDLPack(const at::Tensor& input) { +// DLManagedTensor* dlMTensor = torch_xla::toDLPack(input); +// std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; +// return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor); +// } + at::Tensor tensor_fromDLPack(PyObject* data) { DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. You may have already constructed a tensor from it once."; @@ -2529,13 +2543,16 @@ void InitXlaModuleBindings(py::module m) { }); // from an XLA tensor to a dlpack tensor. - m.def("_to_dlpack", [](const at::Tensor& input) -> PyObject* { + m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle { DLManagedTensor* dlMTensor = torch_xla::toDLPack(input); + std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor); }); + // m.def("_to_dlpack", &tensor_toDLPack, ""); // + // from a dlpack tensor to an XLA tensor - m.def("_from_dlpack", [](PyObject* ext_data) -> at::Tensor { - return tensor_fromDLPack(ext_data); + m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor { + return tensor_fromDLPack(ext_data.ptr()); }); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index e6d36a08df8..60779293934 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -496,7 +496,8 @@ std::vector PjRtComputationClient::TransferFromDevice( // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. 
std::shared_ptr pjrt_data = ReplicateShardedData(handle); - XLA_CHECK(pjrt_data); + XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; + XLA_CHECK(pjrt_data->buffer) << "PjRt buffer is null in " << __FUNCTION__; xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); @@ -506,7 +507,7 @@ std::vector PjRtComputationClient::TransferFromDevice( } for (auto& future : futures) { absl::Status status = future.Await(); - XLA_CHECK_OK(status); + XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" << __FUNCTION__; } InboundDataMetric()->AddSample(total_size); diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py index 9ae99b8f802..7f46a121be7 100644 --- a/torch_xla/utils/dlpack.py +++ b/torch_xla/utils/dlpack.py @@ -2,7 +2,9 @@ import torch_xla def to_dlpack(xla_tensor: Any): - return torch_xla._XLAC._to_dlpack(xla_tensor) + dlt = torch_xla._XLAC._to_dlpack(xla_tensor) + print('xw32 torch_xla._XLAC._to_dlpack has returned.') + return dlt def from_dlpack(ext_tensor: Any): return torch_xla._XLAC._from_dlpack(ext_tensor) From f78d3200c3dff03d6b8a8bfdc06c49f989148e52 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 9 May 2024 17:03:21 +0000 Subject: [PATCH 06/19] check pjrt_buffer is not null. Also in toDlPack, avoid doing from device to host transfer. --- test/test_operations.py | 4 +++- torch_xla/csrc/dl_convertor.cpp | 7 +++++-- torch_xla/csrc/runtime/pjrt_computation_client.cc | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 34277a3ffb3..1236c626211 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2457,9 +2457,11 @@ class TestDLPack(test_utils.XlaTestCase): def test_dlpack_capsule_conversion(self): # TODO(xw32): make sure to test the storage is tested. 
t1 = torch.arange(5).to(xm.xla_device()) - dlpt1 = xdlpack.to_dlpack(t1) + dlpt1 = xdlpack.to_dlpack(t1) # dlpt1 has type PyCapsule print('xw32 finished the to_dlpack') got1 = xdlpack.from_dlpack(dlpt1) + print('xw32 finished the from_dlpack') + self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(t1),torch_xla._XLAC._unsafe_buffer_pointer(got1)) self.assertEqual(t1.cpu(), got1.cpu()) print('xw32 finished test case 1') diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index fc62acd435f..e98ce6a42ce 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -365,13 +365,16 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset, shape, device, on_delete_callback); + XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer in " << __FUNCTION__; + XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null in " << __FUNCTION__; runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value())); // xw32 note: XlaDataToTensors does a fromDeviceToHost transfer.XlaDataToTensors at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype); - return XlaDataToTensors({data}, {tensor_type})[0]; - + // return XlaDataToTensors({data}, {tensor_type})[0]; + XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type); + return bridge::AtenFromXlaTensor(xla_tensor); } } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 60779293934..1303e1b1574 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -469,6 +469,7 @@ std::uintptr_t PjRtComputationClient::UnsafeBufferPointer( std::shared_ptr pjrt_data = std::dynamic_pointer_cast(handle); XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); + XLA_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; xla::StatusOr ptr = client_->UnsafeBufferPointer(pjrt_data->buffer.get()); XLA_CHECK(ptr.ok()); @@ -497,7 +498,7 @@ std::vector PjRtComputationClient::TransferFromDevice( // is not sharded, then it is a no-op. 
std::shared_ptr pjrt_data = ReplicateShardedData(handle); XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; - XLA_CHECK(pjrt_data->buffer) << "PjRt buffer is null in " << __FUNCTION__; + XLA_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); From 3e2ff673328625e0b2af6bfe57eb2b29486e0132 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 9 May 2024 17:57:06 +0000 Subject: [PATCH 07/19] Improved the basic test --- test/test_operations.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 1236c626211..d7048f909b8 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2461,10 +2461,20 @@ def test_dlpack_capsule_conversion(self): print('xw32 finished the to_dlpack') got1 = xdlpack.from_dlpack(dlpt1) print('xw32 finished the from_dlpack') - self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(t1),torch_xla._XLAC._unsafe_buffer_pointer(got1)) + + print('t1.device=', t1.device, ', got1.device=', got1.device) + self.assertEqual(t1.device, got1.device) + print('t1.cpu()=', t1.cpu()) + print('got1.cpu()=', got1.cpu()) self.assertEqual(t1.cpu(), got1.cpu()) - print('xw32 finished test case 1') + self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt1)) + + print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(t1)=', torch_xla._XLAC._unsafe_buffer_pointer(t1)) + print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(got1)=', torch_xla._XLAC._unsafe_buffer_pointer(got1)) + self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(t1),torch_xla._XLAC._unsafe_buffer_pointer(got1)) + print('xw32 first test passed.') + # TODO(xw32): for the below test cases, test the same thing as above. May create a helper function if needed. t2 = torch.arange(5).to(xm.xla_device()) got2 = xdlpack.from_dlpack(xdlpack.to_dlpack(t2)) self.assertEqual(t2.cpu(), got2.cpu()) From 01691c07eba9f531ac2c1398d334efbbdce9434d Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 9 May 2024 20:28:40 +0000 Subject: [PATCH 08/19] add release GIL in to_dlpack. --- torch_xla/csrc/dl_convertor.cpp | 25 ++++++++----------------- torch_xla/csrc/init_python_bindings.cpp | 14 ++++++++------ torch_xla/utils/dlpack.py | 2 +- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index e98ce6a42ce..b5604a23210 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -39,7 +39,8 @@ std::shared_ptr get_data_handle(const at::Tens struct TorchXLADLMTensor { std::unique_ptr external_reference; - std::shared_ptr buffer_reference; + // std::shared_ptr buffer_reference; + at::Tensor source_tensor; std::vector shape; std::vector strides; @@ -124,10 +125,8 @@ std::vector StridesForShape(xla::PrimitiveType element_type, // Convert an XLA tensor to dlPack tensor. 
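+// (The input must already have device data, e.g. after xm.mark_step();
+// get_data_handle() returns nullptr for tensors that are still pending IR,
+// and the check below turns that into an error.)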
DLManagedTensor* toDLPack(const at::Tensor& input) { std::shared_ptr handle = get_data_handle(input); - XLA_CHECK(handle) << "Could not extract a valid data handle from the input tensor"; + XLA_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor"; - // std::shared_ptr pjrt_data = std::dynamic_pointer_cast(data); - // xla::PjRtBuffer* pjrt_buffer = pjrt_data->buffer.get(); std::shared_ptr pjrt_buffer = runtime::GetComputationClient()->GetPjRtBuffer(handle); XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer"; @@ -141,6 +140,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { auto torchXlaDLMTensor = std::make_unique(); DLTensor& dt = torchXlaDLMTensor->tensor.dl_tensor; { + // AcquireExternalReference may block auto external_ref = pjrt_buffer->AcquireExternalReference(); XLA_CHECK_OK(external_ref.status()); torchXlaDLMTensor->external_reference = std::move(external_ref.value()); @@ -148,7 +148,8 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { absl::Status status = future.Await(); XLA_CHECK_OK(status); } - torchXlaDLMTensor->buffer_reference = pjrt_buffer; + // torchXlaDLMTensor->buffer_reference = pjrt_buffer; + torchXlaDLMTensor->source_tensor = input; // pack->buffer_reference = nb::borrow(py_buffer); // xw32: should we do it? dt.data = torchXlaDLMTensor->external_reference->OpaqueDeviceMemoryDataPointer(); @@ -166,7 +167,6 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { dt.strides = reinterpret_cast(torchXlaDLMTensor->strides.data()); dt.byte_offset = 0; - std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; return &(torchXlaDLMTensor.release()->tensor); } @@ -342,17 +342,8 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { // special case when non-default layouts are better supported by JAX. absl::StatusOr default_layout_from_client = device->client()->GetDefaultLayout(element_type, dimensions); - xla::Layout default_layout; - if (default_layout_from_client.ok()) { - default_layout = *default_layout_from_client; - } else if (absl::IsUnimplemented(default_layout_from_client.status())) { - // TODO(skyewm): consider remove the fallback path when GetDefaultLayout is - // unimplemented. - xla::Shape host_shape = xla::ShapeUtil::MakeShape(element_type, dimensions); - default_layout = xla::LayoutUtil::GetWithDefaultLayout(host_shape).layout(); - } else { - XLA_ERROR() << "default_layout_from_client.status() is not ok."; - } + XLA_CHECK_OK(default_layout_from_client.status()) << "Failed to get a default layout in " << __FUNCTION__; + xla::Layout default_layout = default_layout_from_client.value(); // TODO(xw32): the check below is needed due to an limitation in ifrt. Since torch_xla uses pjrt, we may not need the check below and the var default_layout. // if (shape.layout() != default_layout) { // XLA_ERROR() << "from_dlpack got array with non-default layout with minor-to-major dimensions (" << absl::StrJoin(shape.layout().minor_to_major(), ",") << "), expected (" << absl::StrJoin(default_layout.minor_to_major(), ",") << ")"; // } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 55947d98d0d..93ed0434e70 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1102,19 +1102,18 @@ void BuildLoweringContextSubmodule(py::module* m) { // Used in the to_dlpack. 
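+// (DLPack contract: a consumer renames the capsule from "dltensor" to
+// "used_dltensor" once it takes ownership, so an already-consumed capsule is
+// deliberately left alone here.)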
void dlPack_Capsule_Destructor(PyObject* data) { - std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; if (!PyCapsule_IsValid(data, "dltensor")) { return; } - HANDLE_TH_ERRORS DLManagedTensor* dlMTensor = - (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + static_cast(PyCapsule_GetPointer(data, "dltensor")); if (dlMTensor) { dlMTensor->deleter(dlMTensor); } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. PyErr_Clear(); } - END_HANDLE_TH_ERRORS_RET() } // PyObject* tensor_toDLPack(const at::Tensor& input) { @@ -2544,8 +2543,11 @@ void InitXlaModuleBindings(py::module m) { // from an XLA tensor to a dlpack tensor. m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle { - DLManagedTensor* dlMTensor = torch_xla::toDLPack(input); - std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; + DLManagedTensor* dlMTensor; + { + NoGilSection nogil; + dlMTensor = torch_xla::toDLPack(input); + } return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor); }); // m.def("_to_dlpack", &tensor_toDLPack, ""); // diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py index 7f46a121be7..236d42aacdf 100644 --- a/torch_xla/utils/dlpack.py +++ b/torch_xla/utils/dlpack.py @@ -3,7 +3,7 @@ def to_dlpack(xla_tensor: Any): dlt = torch_xla._XLAC._to_dlpack(xla_tensor) - print('xw32 torch_xla._XLAC._to_dlpack has returned.') + print('xw32 torch_xla._XLAC._to_dlpack has returned. dlt has __dlpack_+=', hasattr(dlt, "__dlpack__"), ', dlt has __dlpack_device__=', hasattr(dlt, "__dlpack_device__")) return dlt def from_dlpack(ext_tensor: Any): From 251ae89ef21c9de4193befdb23e155ab028b103e Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 9 May 2024 21:10:05 +0000 Subject: [PATCH 09/19] fixed a bug of not passing in the pjrt_buffer --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 2 +- torch_xla/csrc/runtime/pjrt_computation_client.h | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 1303e1b1574..c7896c09c20 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -188,7 +188,7 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder( ComputationClient::DataPtr PjRtComputationClient::CreateData( std::string device, xla::Shape shape, std::shared_ptr pjrt_buffer) { - return std::make_shared(std::move(device), std::move(shape)); + return std::make_shared(std::move(device), std::move(shape), pjrt_buffer); } std::vector PjRtComputationClient::GetDataShards( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 1b1e0aa5c47..b17b8c451fd 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -181,7 +181,10 @@ class PjRtComputationClient : public ComputationClient { }; void Assign(const torch::lazy::BackendData& data) override; bool HasValue() const override { - return buffer != nullptr && !buffer->IsDeleted(); + // bool has_value = buffer != nullptr && !buffer->IsDeleted(); + // std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": buffer != nullptr=" << (buffer != nullptr) << ", buffer->IsDeleted()=" << buffer->IsDeleted() << std::endl; + 
// return has_value; + return buffer != nullptr && !buffer->IsDeleted(); // TODO(xw32): uncomment this line and remove all above lines in the method. }; bool HasSharding() const override { return false; } @@ -237,6 +240,7 @@ class PjRtComputationClient : public ComputationClient { } bool HasValue() const override { + std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": PjRtShardedData::HasValue is called." << std::endl; if (shards.empty()) { return false; } From aac102b2a7231058371091788bd3ad56d5011777 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Thu, 9 May 2024 22:36:05 +0000 Subject: [PATCH 10/19] in the middle of fixing holds_[i] == 0 (1 vs. 0) --- test/test_operations.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index d7048f909b8..be932ca6f5a 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2475,13 +2475,13 @@ def test_dlpack_capsule_conversion(self): print('xw32 first test passed.') # TODO(xw32): for the below test cases, test the same thing as above. May create a helper function if needed. - t2 = torch.arange(5).to(xm.xla_device()) - got2 = xdlpack.from_dlpack(xdlpack.to_dlpack(t2)) - self.assertEqual(t2.cpu(), got2.cpu()) + # t2 = torch.arange(5).to(xm.xla_device()) + # got2 = xdlpack.from_dlpack(xdlpack.to_dlpack(t2)) + # self.assertEqual(t2.cpu(), got2.cpu()) - t3 = torch.tensor(5, device=xm.xla_device()) - got3 = xdlpack.from_dlpack(xdlpack.to_dlpack(t3)) - self.assertEqual(t3.cpu(), got3.cpu()) + # t3 = torch.tensor(5, device=xm.xla_device()) + # got3 = xdlpack.from_dlpack(xdlpack.to_dlpack(t3)) + # self.assertEqual(t3.cpu(), got3.cpu()) # TODO(xw32): figure it out what it is testing. @onlyIfTorchSupportsCUDA From 57a936e63937a5c4c03f714d6a7797a6433e87c1 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Fri, 10 May 2024 01:21:10 +0000 Subject: [PATCH 11/19] fixed the issue: pjrt_stream_executor_client.cc:1322] Check failed: holds_[i] == 0 (1 vs. 
0)
---
 torch_xla/csrc/dl_convertor.cpp         | 17 +++++++++++++----
 torch_xla/csrc/init_python_bindings.cpp |  2 ++
 torch_xla/utils/dlpack.py               |  7 ++++---
 3 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index b5604a23210..d2cb0110bd2 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -38,17 +38,26 @@ std::shared_ptr<runtime::ComputationClient::Data> get_data_handle(const at::Tens
 }
 
 struct TorchXLADLMTensor {
+  ~TorchXLADLMTensor();
   std::unique_ptr<xla::PjRtBuffer::ExternalReference> external_reference;
-  // std::shared_ptr<xla::PjRtBuffer> buffer_reference;
-  at::Tensor source_tensor;
+  std::shared_ptr<xla::PjRtBuffer> buffer_reference;
+  // at::Tensor source_tensor;
 
   std::vector<int64_t> shape;
   std::vector<int64_t> strides;
   DLManagedTensor tensor;
 };
 
+TorchXLADLMTensor::~TorchXLADLMTensor() {
+  if (external_reference) {
+    external_reference.reset(nullptr);
+  }
+}
+
 void TorchXLADLMTensorDeleter(DLManagedTensor* t) {
+  std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
   if (t) {
+    std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl;
     delete static_cast<TorchXLADLMTensor*>(t->manager_ctx);
   }
 }
@@ -148,8 +157,8 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     absl::Status status = future.Await();
     XLA_CHECK_OK(status);
   }
-  // torchXlaDLMTensor->buffer_reference = pjrt_buffer;
-  torchXlaDLMTensor->source_tensor = input;
+  torchXlaDLMTensor->buffer_reference = pjrt_buffer;
+  // torchXlaDLMTensor->source_tensor = input;
   // pack->buffer_reference = nb::borrow(py_buffer);  // xw32: should we do it?
 
   dt.data = torchXlaDLMTensor->external_reference->OpaqueDeviceMemoryDataPointer();
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 93ed0434e70..e9e4a9ef0fe 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1128,6 +1128,7 @@ at::Tensor tensor_fromDLPack(PyObject* data) {
 
   at::Tensor tensor = torch_xla::fromDLPack(dlMTensor);
   PyCapsule_SetName(data, "used_dltensor");
+  PyCapsule_SetDestructor(data, nullptr);
   return tensor;
 }
 
@@ -2548,6 +2549,7 @@ void InitXlaModuleBindings(py::module m) {
       NoGilSection nogil;
       dlMTensor = torch_xla::toDLPack(input);
     }
+    // return py::reinterpret_steal<py::object>(PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor));
     return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
   });
   // m.def("_to_dlpack", &tensor_toDLPack, "");  //
diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py
index 236d42aacdf..67a37d5dac1 100644
--- a/torch_xla/utils/dlpack.py
+++ b/torch_xla/utils/dlpack.py
@@ -2,9 +2,10 @@
 import torch_xla
 
 def to_dlpack(xla_tensor: Any):
-  dlt = torch_xla._XLAC._to_dlpack(xla_tensor)
-  print('xw32 torch_xla._XLAC._to_dlpack has returned. dlt has __dlpack_+=', hasattr(dlt, "__dlpack__"), ', dlt has __dlpack_device__=', hasattr(dlt, "__dlpack_device__"))
-  return dlt
+  return torch_xla._XLAC._to_dlpack(xla_tensor)
+  # dlt = torch_xla._XLAC._to_dlpack(xla_tensor)
+  # print('xw32 torch_xla._XLAC._to_dlpack has returned. dlt has __dlpack_+=', hasattr(dlt, "__dlpack__"), ', dlt has __dlpack_device__=', hasattr(dlt, "__dlpack_device__"))
+  # return dlt
 
 def from_dlpack(ext_tensor: Any):
   return torch_xla._XLAC._from_dlpack(ext_tensor)

From 9ff3dd9b216b55758d94448cfebef7284890fcc8 Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Fri, 10 May 2024 18:51:50 +0000
Subject: [PATCH 12/19] improve the unit tests. All tests passed.
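
The round trip these tests exercise looks like this from user code (an
illustrative sketch, not part of the patch; it assumes the
torch_xla.utils.dlpack helpers added earlier in this series):

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.dlpack as xdlpack

    t = torch.arange(5, device=xm.xla_device())
    xm.mark_step()                     # materialize the device buffer
    capsule = xdlpack.to_dlpack(t)     # a PyCapsule named "dltensor"
    t2 = xdlpack.from_dlpack(capsule)  # consumes and renames the capsule
    assert torch.allclose(t.cpu(), t2.cpu())

Consuming the same capsule twice raises "DLTensor capsule can be consumed
only once", and both tensors share a single device buffer, which the tests
check via _unsafe_buffer_pointer.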
--- test/test_operations.py | 134 +++++++++++++++++++++++++--------------- 1 file changed, 84 insertions(+), 50 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index be932ca6f5a..afbb9ba805d 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -13,6 +13,7 @@ parser.add_argument('--verbosity', type=int, default=0) FLAGS, leftovers = parser.parse_known_args() sys.argv = [sys.argv[0]] + leftovers +from absl.testing import absltest, parameterized # Normal imports section starts here. import collections @@ -28,6 +29,8 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim +from torch.testing._internal.common_device_type import dtypes +from torch.testing._internal.common_dtype import (all_types_and_complex_and) import torch_xla import torch_xla.core.xla_builder as xb import torch_xla.core.xla_op_registry as xor @@ -2449,80 +2452,111 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_3, 0) -class TestDLPack(test_utils.XlaTestCase): +class TestDLPack(parameterized.TestCase): - # TODO(xw32): need to test different data type such as pytorch/test/test_dlpack.py - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_dlpack_capsule_conversion(self): - # TODO(xw32): make sure to test the storage is tested. - t1 = torch.arange(5).to(xm.xla_device()) - dlpt1 = xdlpack.to_dlpack(t1) # dlpt1 has type PyCapsule + def _test_dlpack_capsule_conversion_helper(self, xla_tensor): + dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule print('xw32 finished the to_dlpack') - got1 = xdlpack.from_dlpack(dlpt1) + got = xdlpack.from_dlpack(dlpt) print('xw32 finished the from_dlpack') - print('t1.device=', t1.device, ', got1.device=', got1.device) - self.assertEqual(t1.device, got1.device) - print('t1.cpu()=', t1.cpu()) - print('got1.cpu()=', got1.cpu()) - self.assertEqual(t1.cpu(), got1.cpu()) - self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt1)) - - print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(t1)=', torch_xla._XLAC._unsafe_buffer_pointer(t1)) - print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(got1)=', torch_xla._XLAC._unsafe_buffer_pointer(got1)) - self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(t1),torch_xla._XLAC._unsafe_buffer_pointer(got1)) - print('xw32 first test passed.') + print('xla_tensor.device=', xla_tensor.device, ', got.device=', got.device) + self.assertEqual(xla_tensor.device, got.device) + print('xla_tensor.cpu()=', xla_tensor.cpu()) + print('got.cpu()=', got.cpu()) + self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu())) + self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt)) - # TODO(xw32): for the below test cases, test the same thing as above. May create a helper function if needed. 
- # t2 = torch.arange(5).to(xm.xla_device()) - # got2 = xdlpack.from_dlpack(xdlpack.to_dlpack(t2)) - # self.assertEqual(t2.cpu(), got2.cpu()) + print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor)=', torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor)) + print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(got)=', torch_xla._XLAC._unsafe_buffer_pointer(got)) + self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),torch_xla._XLAC._unsafe_buffer_pointer(got)) - # t3 = torch.tensor(5, device=xm.xla_device()) - # got3 = xdlpack.from_dlpack(xdlpack.to_dlpack(t3)) - # self.assertEqual(t3.cpu(), got3.cpu()) - - # TODO(xw32): figure it out what it is testing. + # TODO(xw32): need to test different data type such as pytorch/test/test_dlpack.py @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA - def test_dlpack_protocol_conversion(self): - t1 = torch.arange(5, device=xm.xla_device()) + @parameterized.parameters(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64)) + def test_dlpack_roundtrip(self, dtype): + print('xw32 dtype=', dtype) + # "arange_cpu" not implemented for complex64 and complex128. + # xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) failed with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`, similar to other uint. + if dtype in { torch.complex128, torch.complex64, torch.uint64, torch.uint32, torch.uint16, torch.bool }: + return + xla_device = xm.xla_device() + xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) + # `mark_step` ensures xtensor->CurrentDataHandle() != nullptr xm.mark_step() - got1 = xdlpack.from_dlpack(t1) - self.assertEqual(t1.cpu(), got1.cpu()) + self._test_dlpack_capsule_conversion_helper(xla_tensor_0) + + xla_tensor_1 = torch.tensor(42, dtype=dtype).to(xla_device) + # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr + self._test_dlpack_capsule_conversion_helper(xla_tensor_1) - t2 = torch.arange(5, device=xm.xla_device()) - got2 = xdlpack.from_dlpack(t2) - self.assertEqual(t2.cpu(), got2.cpu()) + # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr + xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device) + self._test_dlpack_capsule_conversion_helper(xla_tensor_2) - t3 = torch.tensor(5, device=xm.xla_device()) - got3 = xdlpack.from_dlpack(t3) - self.assertEqual(t3.cpu(), got3.cpu()) + xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device()) + xm.mark_step() + # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr. 
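+    # wait_device_ops() blocks until every queued device computation has
+    # finished, so the buffer behind xla_tensor_3 is guaranteed to exist
+    # before it is exported through DLPack.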
+ xm.wait_device_ops() + self._test_dlpack_capsule_conversion_helper(xla_tensor_3) @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA - def test_dlpack_cuda_to_xla_shared_storage(self): - t1 = torch.arange(5).cuda() - dlt1 = torch.utils.dlpack.to_dlpack(t1) + def test_dlpack_roundtrip_bool(self): + xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device()) + self._test_dlpack_capsule_conversion_helper(xla_tensor) + + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_pytorch_cuda_to_xla(self): + t1_cuda = torch.arange(5).cuda() + dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda) xla_t1 = xdlpack.from_dlpack(dlt1) - t1[0] = t1[0] + 20 - self.assertEqual(t1, xla_t1.cpu()) + t1_cuda[0] = t1_cuda[0] + 20 + self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu())) - t2 = torch.tensor(5).cuda() - dlt2 = torch.utils.dlpack.to_dlpack(t2) + t2_cuda = torch.tensor(5).cuda() + dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda) xla_t2 = xdlpack.from_dlpack(dlt2) - t2.fill_(6) - self.assertEqual(t2, xla_t2.cpu()) + t2_cuda.fill_(6) + self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu())) @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA - def test_dlpack_xla_to_cuda_shared_storage(self): + def test_dlpack_xla_to_pytorch_cuda(self): xla_t1 = torch.arange(5).to(xm.xla_device()) dlt1 = xdlpack.to_dlpack(xla_t1) cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) cuda_t1[0] = cuda_t1[0] + 20 - self.assertEqual(xla_t1.cpu(), cuda_t1.cpu()) + self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu())) + + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_non_default_layout(self): + cuda_t = torch.arange(25, device=torch.device('cuda')).reshape(5, 5) + + t1 = cuda_t.t() + xla_t1 = xdlpack.from_dlpack(t1.__dlpack__()) + self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu())) + + t2 = cuda_t[0] + xla_t2 = xdlpack.from_dlpack(t2.__dlpack__()) + self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu())) + + t3 = cuda_t[:, 0] + self.assertRaisesRegex(RuntimeError, r"Only DLPack tensors with trivial \(compact\) striding are supported", lambda: xdlpack.from_dlpack(t3.__dlpack__())) + + t4 = cuda_t[1, :] + xla_t4 = xdlpack.from_dlpack(t4.__dlpack__()) + self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu())) + + t5 = cuda_t[1] + xla_t5 = xdlpack.from_dlpack(t5.__dlpack__()) + self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu())) + + + class SimpleModelWithDropout(torch.nn.Module): From 8dd8381e9aa17becfdec515c60c6253687cb2aed Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Fri, 10 May 2024 21:33:27 +0000 Subject: [PATCH 13/19] Clean up unused prints and comments --- test/test_operations.py | 9 --- torch_xla/csrc/dl_convertor.cpp | 66 +++++-------------- torch_xla/csrc/init_python_bindings.cpp | 6 -- .../csrc/runtime/pjrt_computation_client.h | 6 +- torch_xla/utils/dlpack.py | 3 - 5 files changed, 16 insertions(+), 74 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index afbb9ba805d..d2429a080cb 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2456,27 +2456,18 @@ class TestDLPack(parameterized.TestCase): def _test_dlpack_capsule_conversion_helper(self, xla_tensor): dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule - print('xw32 finished the to_dlpack') got = xdlpack.from_dlpack(dlpt) - print('xw32 finished the from_dlpack') - print('xla_tensor.device=', xla_tensor.device, ', got.device=', got.device) self.assertEqual(xla_tensor.device, got.device) - print('xla_tensor.cpu()=', xla_tensor.cpu()) - print('got.cpu()=', 
got.cpu()) self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu())) self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt)) - print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor)=', torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor)) - print('xw32 torch_xla._XLAC._unsafe_buffer_pointer(got)=', torch_xla._XLAC._unsafe_buffer_pointer(got)) self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),torch_xla._XLAC._unsafe_buffer_pointer(got)) - # TODO(xw32): need to test different data type such as pytorch/test/test_dlpack.py @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA @parameterized.parameters(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64)) def test_dlpack_roundtrip(self, dtype): - print('xw32 dtype=', dtype) # "arange_cpu" not implemented for complex64 and complex128. # xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) failed with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`, similar to other uint. if dtype in { torch.complex128, torch.complex64, torch.uint64, torch.uint32, torch.uint16, torch.bool }: diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index d2cb0110bd2..e1a21aed46e 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -55,9 +55,7 @@ TorchXLADLMTensor::~TorchXLADLMTensor() { } void TorchXLADLMTensorDeleter(DLManagedTensor* t) { - std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; if (t) { - std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; delete static_cast(t->manager_ctx); } } @@ -67,8 +65,6 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { return DLDeviceType::kDLCPU; } else if (device.client()->platform_id() == xla::CudaId()) { return DLDeviceType::kDLCUDA; - } else if (device.client()->platform_id() == xla::RocmId()) { - return DLDeviceType::kDLROCM; } XLA_ERROR() << "Device " << device.DebugString() << " cannot be used as a DLPack device."; } @@ -131,7 +127,7 @@ std::vector StridesForShape(xla::PrimitiveType element_type, return strides; } -// Convert an XLA tensor to dlPack tensor. +// Convert an XLA tensor to a dlPack tensor. DLManagedTensor* toDLPack(const at::Tensor& input) { std::shared_ptr handle = get_data_handle(input); XLA_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor"; @@ -146,63 +142,46 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { XLA_ERROR() << "Unimplemented. 
DynamicShape is not implemented in DLPack."; } - auto torchXlaDLMTensor = std::make_unique(); - DLTensor& dt = torchXlaDLMTensor->tensor.dl_tensor; + auto pack = std::make_unique(); + DLTensor& dt = pack->tensor.dl_tensor; { // AcquireExternalReference may block auto external_ref = pjrt_buffer->AcquireExternalReference(); XLA_CHECK_OK(external_ref.status()); - torchXlaDLMTensor->external_reference = std::move(external_ref.value()); + pack->external_reference = std::move(external_ref.value()); xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture(); absl::Status status = future.Await(); XLA_CHECK_OK(status); } - torchXlaDLMTensor->buffer_reference = pjrt_buffer; - // torchXlaDLMTensor->source_tensor = input; - // pack->buffer_reference = nb::borrow(py_buffer); // xw32: should we do it? + pack->buffer_reference = pjrt_buffer; + // pack->source_tensor = input; - dt.data = torchXlaDLMTensor->external_reference->OpaqueDeviceMemoryDataPointer(); - torchXlaDLMTensor->tensor.manager_ctx = torchXlaDLMTensor.get(); - torchXlaDLMTensor->tensor.deleter = TorchXLADLMTensorDeleter; + dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = TorchXLADLMTensorDeleter; dt.device = DLDeviceForDevice(*pjrt_buffer->device()); dt.device.device_id = pjrt_buffer->device()->local_hardware_id(); dt.ndim = pjrt_buffer->dimensions().size(); dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type()); - torchXlaDLMTensor->shape = std::vector(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end()); + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end()); xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout()); - torchXlaDLMTensor->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout); - dt.shape = reinterpret_cast(torchXlaDLMTensor->shape.data()); - dt.strides = reinterpret_cast(torchXlaDLMTensor->strides.data()); + pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = reinterpret_cast(pack->strides.data()); dt.byte_offset = 0; - return &(torchXlaDLMTensor.release()->tensor); + return &(pack.release()->tensor); } absl::StatusOr DeviceForDLDevice(const DLDevice& context) { switch (context.device_type) { case DLDeviceType::kDLCPU: - // if (cpu_client == nullptr) { - // return InvalidArgument( - // "DLPack tensor is on CPU, but no CPU backend was provided."); - // } XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CpuId()); return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id); case DLDeviceType::kDLCUDA: - // if (gpu_client == nullptr) { // xw32 TODO: check if client_ is GPU client - // return InvalidArgument( - // "DLPack tensor is on GPU, but no GPU backend was provided."); - // } XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CudaId()); return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id); - // case DLDeviceType::kDLROCM: - // // if (gpu_client == nullptr) { - // // return InvalidArgument( - // // "DLPack tensor is on GPU, but no GPU backend was provided."); - // // } - // XLA_CHECK_EQ(pjrt_client->platform_id(), xla::RocmId()); - // xla::PjRtDevice* device = pjrt_client->addressable_devices()[context.device_id]; - // return device; default: return tsl::errors::InvalidArgument("Unknown/unsupported DLPack device type %d", 
context.device_type); @@ -325,7 +304,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { if (dlmt->dl_tensor.ndim < 0) { XLA_ERROR() << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim; } - xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); // client_ is a xla::PjRtClient. So this fromDLPack should be inside pjrt_computation_client class. + xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); absl::Span dimensions( const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); xla::PrimitiveType element_type = DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value(); @@ -344,19 +323,6 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, minor_to_major); - // Raise an error if the resulting PjRtBuffer would have a non-default layout. - // TODO(skyewm): we do this because JAX doesn't currently have good support - // for non-default layouts, and will return wrong results if a non-default - // layout is passed to a computation expecting default layouts. Remove this - // special case when non-default layouts are better supported by JAX. - absl::StatusOr default_layout_from_client = - device->client()->GetDefaultLayout(element_type, dimensions); - XLA_CHECK_OK(default_layout_from_client.status()) << "Failed to get a default layout in " << __FUNCTION__; - xla::Layout default_layout = default_layout_from_client.value(); // TODO(xw32): the check below is needed due to an limitation in ifrt. Since torch_xla uses pjrt, we may not need the check below and the var default_layout. - // if (shape.layout() != default_layout) { - // XLA_ERROR() << "from_dlpack got array with non-default layout with minor-to-major dimensions (" << absl::StrJoin(shape.layout().minor_to_major(), ",") << "), expected (" << absl::StrJoin(default_layout.minor_to_major(), ",") << ")"; - // } - std::function on_delete_callback; if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; @@ -370,9 +336,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value())); - // xw32 note: XlaDataToTensors does a fromDeviceToHost transfer.XlaDataToTensors at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype); - // return XlaDataToTensors({data}, {tensor_type})[0]; XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type); return bridge::AtenFromXlaTensor(xla_tensor); } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e9e4a9ef0fe..2f4b90e42cf 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1116,12 +1116,6 @@ void dlPack_Capsule_Destructor(PyObject* data) { } } -// PyObject* tensor_toDLPack(const at::Tensor& input) { -// DLManagedTensor* dlMTensor = torch_xla::toDLPack(input); -// std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": " << std::endl; -// return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor); -// } - at::Tensor tensor_fromDLPack(PyObject* data) { DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. 
You may have already constructed a tensor from it once."; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index b17b8c451fd..1b1e0aa5c47 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -181,10 +181,7 @@ class PjRtComputationClient : public ComputationClient { }; void Assign(const torch::lazy::BackendData& data) override; bool HasValue() const override { - // bool has_value = buffer != nullptr && !buffer->IsDeleted(); - // std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": buffer != nullptr=" << (buffer != nullptr) << ", buffer->IsDeleted()=" << buffer->IsDeleted() << std::endl; - // return has_value; - return buffer != nullptr && !buffer->IsDeleted(); // TODO(xw32): uncomment this line and remove all above lines in the method. + return buffer != nullptr && !buffer->IsDeleted(); }; bool HasSharding() const override { return false; } @@ -240,7 +237,6 @@ class PjRtComputationClient : public ComputationClient { } bool HasValue() const override { - std::cout << "xw32, file=" << __FILE__ << ", line=" << __LINE__ << "function=" << __FUNCTION__ << ": PjRtShardedData::HasValue is called." << std::endl; if (shards.empty()) { return false; } diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py index 67a37d5dac1..9ae99b8f802 100644 --- a/torch_xla/utils/dlpack.py +++ b/torch_xla/utils/dlpack.py @@ -3,9 +3,6 @@ def to_dlpack(xla_tensor: Any): return torch_xla._XLAC._to_dlpack(xla_tensor) - # dlt = torch_xla._XLAC._to_dlpack(xla_tensor) - # print('xw32 torch_xla._XLAC._to_dlpack has returned. dlt has __dlpack_+=', hasattr(dlt, "__dlpack__"), ', dlt has __dlpack_device__=', hasattr(dlt, "__dlpack_device__")) - # return dlt def from_dlpack(ext_tensor: Any): return torch_xla._XLAC._from_dlpack(ext_tensor) From 60b4d59f670eb9c8be08fac53406a00afe31f348 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Fri, 10 May 2024 21:44:01 +0000 Subject: [PATCH 14/19] linter --- test/test_operations.py | 31 ++++-- torch_xla/csrc/dl_convertor.cpp | 104 +++++++++++------- torch_xla/csrc/dl_convertor.h | 2 +- torch_xla/csrc/init_python_bindings.cpp | 18 +-- torch_xla/csrc/runtime/computation_client.h | 18 +-- .../csrc/runtime/ifrt_computation_client.cc | 3 +- .../csrc/runtime/ifrt_computation_client.h | 12 +- .../csrc/runtime/pjrt_computation_client.cc | 17 ++- .../csrc/runtime/pjrt_computation_client.h | 11 +- torch_xla/utils/dlpack.py | 2 + 10 files changed, 134 insertions(+), 84 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index d2429a080cb..b55cadcbcd7 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2455,22 +2455,33 @@ def test_unsafe_buffer_pointer(self): class TestDLPack(parameterized.TestCase): def _test_dlpack_capsule_conversion_helper(self, xla_tensor): - dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule + dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule got = xdlpack.from_dlpack(dlpt) self.assertEqual(xla_tensor.device, got.device) self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu())) - self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt)) + self.assertRaisesRegex(RuntimeError, + "DLTensor capsule can be consumed only once", + lambda: xdlpack.from_dlpack(dlpt)) - 
self.assertEqual(torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),torch_xla._XLAC._unsafe_buffer_pointer(got)) + self.assertEqual( + torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor), + torch_xla._XLAC._unsafe_buffer_pointer(got)) @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA - @parameterized.parameters(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64)) + @parameterized.parameters(*all_types_and_complex_and(torch.half, + torch.bfloat16, + torch.bool, torch.uint16, + torch.uint32, + torch.uint64)) def test_dlpack_roundtrip(self, dtype): # "arange_cpu" not implemented for complex64 and complex128. # xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) failed with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`, similar to other uint. - if dtype in { torch.complex128, torch.complex64, torch.uint64, torch.uint32, torch.uint16, torch.bool }: + if dtype in { + torch.complex128, torch.complex64, torch.uint64, torch.uint32, + torch.uint16, torch.bool + }: return xla_device = xm.xla_device() xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) @@ -2496,7 +2507,7 @@ def test_dlpack_roundtrip(self, dtype): @onlyIfPJRTDeviceIsCUDA def test_dlpack_roundtrip_bool(self): xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device()) - self._test_dlpack_capsule_conversion_helper(xla_tensor) + self._test_dlpack_capsule_conversion_helper(xla_tensor) @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA @@ -2536,7 +2547,10 @@ def test_dlpack_non_default_layout(self): self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu())) t3 = cuda_t[:, 0] - self.assertRaisesRegex(RuntimeError, r"Only DLPack tensors with trivial \(compact\) striding are supported", lambda: xdlpack.from_dlpack(t3.__dlpack__())) + self.assertRaisesRegex( + RuntimeError, + r"Only DLPack tensors with trivial \(compact\) striding are supported", + lambda: xdlpack.from_dlpack(t3.__dlpack__())) t4 = cuda_t[1, :] xla_t4 = xdlpack.from_dlpack(t4.__dlpack__()) @@ -2547,9 +2561,6 @@ def test_dlpack_non_default_layout(self): self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu())) - - - class SimpleModelWithDropout(torch.nn.Module): def __init__(self): diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index e1a21aed46e..3c0371efcf6 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -1,17 +1,17 @@ #include "torch_xla/csrc/dl_convertor.h" -#include "absl/types/span.h" #include -#include "torch_xla/csrc/tensor.h" +#include "absl/types/span.h" #include "torch_xla/csrc/aten_xla_bridge.h" -#include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/ops/device_data.h" +#include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/unwrap_data.h" +#include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/unwrap_data.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" @@ -19,12 +19,14 @@ namespace torch_xla { -std::shared_ptr get_data_handle(const at::Tensor& input) { +std::shared_ptr get_data_handle( + const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); XLA_CHECK(xtensor) << "The 
input has to be an XLA tensor."; if (xtensor->CurrentDataHandle() != nullptr) { TF_VLOG(4) << "The xla tensor has a current data handle."; - return std::dynamic_pointer_cast(xtensor->CurrentDataHandle()); + return std::dynamic_pointer_cast( + xtensor->CurrentDataHandle()); } else if (xtensor->CurrentIrValue().node != nullptr) { DeviceData* device_data = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); @@ -33,7 +35,8 @@ std::shared_ptr get_data_handle(const at::Tens } TF_VLOG(4) << "The xla tensor has IR value but does not have device data."; } - TF_VLOG(4) << "The xla tensor either has no current data handle or has no IR value."; + TF_VLOG(4) + << "The xla tensor either has no current data handle or has no IR value."; return nullptr; } @@ -66,7 +69,8 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { } else if (device.client()->platform_id() == xla::CudaId()) { return DLDeviceType::kDLCUDA; } - XLA_ERROR() << "Device " << device.DebugString() << " cannot be used as a DLPack device."; + XLA_ERROR() << "Device " << device.DebugString() + << " cannot be used as a DLPack device."; } DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { @@ -109,7 +113,8 @@ DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) { case xla::PrimitiveType::C128: return DLDataType{kDLComplex, 128, 1}; default: - XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type) << " has no DLPack equivalent"; + XLA_ERROR() << "XLA type " << xla::PrimitiveType_Name(type) + << " has no DLPack equivalent"; } } @@ -129,14 +134,18 @@ std::vector StridesForShape(xla::PrimitiveType element_type, // Convert an XLA tensor to a dlPack tensor. DLManagedTensor* toDLPack(const at::Tensor& input) { - std::shared_ptr handle = get_data_handle(input); - XLA_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor"; + std::shared_ptr handle = + get_data_handle(input); + XLA_CHECK(handle != nullptr) + << "Could not extract a valid data handle from the input tensor"; - std::shared_ptr pjrt_buffer = runtime::GetComputationClient()->GetPjRtBuffer(handle); + std::shared_ptr pjrt_buffer = + runtime::GetComputationClient()->GetPjRtBuffer(handle); XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer"; if (pjrt_buffer->IsTuple()) { - XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not implemented for tuple buffers."; + XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not " + "implemented for tuple buffers."; } if (pjrt_buffer->has_dynamic_dimensions()) { XLA_ERROR() << "Unimplemented. 
DynamicShape is not implemented in DLPack."; @@ -164,9 +173,11 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { dt.ndim = pjrt_buffer->dimensions().size(); dt.dtype = PrimitiveTypeToDLDataType(pjrt_buffer->element_type()); - pack->shape = std::vector(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end()); + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), + pjrt_buffer->dimensions().end()); xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout()); - pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); dt.shape = reinterpret_cast(pack->shape.data()); dt.strides = reinterpret_cast(pack->strides.data()); dt.byte_offset = 0; @@ -177,21 +188,25 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { absl::StatusOr DeviceForDLDevice(const DLDevice& context) { switch (context.device_type) { case DLDeviceType::kDLCPU: - XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CpuId()); - return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id); + XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), + xla::CpuId()); + return runtime::GetComputationClient()->LookupAddressableDevice( + context.device_id); case DLDeviceType::kDLCUDA: - XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), xla::CudaId()); - return runtime::GetComputationClient()->LookupAddressableDevice(context.device_id); + XLA_CHECK_EQ(runtime::GetComputationClient()->GetPlatformID(), + xla::CudaId()); + return runtime::GetComputationClient()->LookupAddressableDevice( + context.device_id); default: - return tsl::errors::InvalidArgument("Unknown/unsupported DLPack device type %d", - context.device_type); + return tsl::errors::InvalidArgument( + "Unknown/unsupported DLPack device type %d", context.device_type); } } absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { if (type.lanes != 1) { - return tsl::errors::Unimplemented("DLPack types with lanes != 1 not implemented, got %d", - type.lanes); + return tsl::errors::Unimplemented( + "DLPack types with lanes != 1 not implemented, got %d", type.lanes); } switch (type.code) { case kDLBool: @@ -265,7 +280,8 @@ absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { type.bits); } default: - return tsl::errors::Unimplemented("Unknown or invalid DLPack type code %d", type.code); + return tsl::errors::Unimplemented( + "Unknown or invalid DLPack type code %d", type.code); } } @@ -302,43 +318,51 @@ absl::StatusOr> StridesToLayout( at::Tensor fromDLPack(DLManagedTensor* dlmt) { if (dlmt->dl_tensor.ndim < 0) { - XLA_ERROR() << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim; + XLA_ERROR() + << "Number of dimensions in DLManagedTensor must be nonnegative, got " + << dlmt->dl_tensor.ndim; } xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); absl::Span dimensions( const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); - xla::PrimitiveType element_type = DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value(); + xla::PrimitiveType element_type = + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype).value(); std::vector minor_to_major; if (dlmt->dl_tensor.strides && absl::c_find(dimensions, 0) == dimensions.end()) { absl::Span strides( - const_cast(dlmt->dl_tensor.strides), - dlmt->dl_tensor.ndim); + const_cast(dlmt->dl_tensor.strides), dlmt->dl_tensor.ndim); 
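     // Dense, non-overlapping strides determine a unique minor-to-major
     // permutation. StridesToLayout fails for anything else, which surfaces
     // in Python as "Only DLPack tensors with trivial (compact) striding are
     // supported" (exercised by test_dlpack_non_default_layout).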
minor_to_major = StridesToLayout(dimensions, strides).value(); } else { minor_to_major.resize(dlmt->dl_tensor.ndim); std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); } - xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, - minor_to_major); + xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + element_type, dimensions, minor_to_major); std::function on_delete_callback; if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; } - xla::StatusOr> pjrt_buffer = device->client()->CreateViewOfDeviceBuffer( - static_cast(dlmt->dl_tensor.data) + - dlmt->dl_tensor.byte_offset, - shape, device, on_delete_callback); - XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer in " << __FUNCTION__; - XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null in " << __FUNCTION__; + xla::StatusOr> pjrt_buffer = + device->client()->CreateViewOfDeviceBuffer( + static_cast(dlmt->dl_tensor.data) + + dlmt->dl_tensor.byte_offset, + shape, device, on_delete_callback); + XLA_CHECK_OK(pjrt_buffer.status()) + << "Failed to create a pjrt buffer in " << __FUNCTION__; + XLA_CHECK(pjrt_buffer.value() != nullptr) + << "pjrt buffer is null in " << __FUNCTION__; + + runtime::ComputationClient::DataPtr data = + runtime::GetComputationClient()->CreateData( + runtime::GetComputationClient()->PjRtDeviceToString(device), shape, + std::move(pjrt_buffer.value())); - runtime::ComputationClient::DataPtr data = runtime::GetComputationClient()->CreateData(runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value())); - at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype); XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type); return bridge::AtenFromXlaTensor(xla_tensor); } -} +} // namespace torch_xla diff --git a/torch_xla/csrc/dl_convertor.h b/torch_xla/csrc/dl_convertor.h index 07d4587146a..f5a54823e2e 100644 --- a/torch_xla/csrc/dl_convertor.h +++ b/torch_xla/csrc/dl_convertor.h @@ -1,8 +1,8 @@ #ifndef XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ #define XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ -#include #include +#include namespace torch_xla { diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 2f4b90e42cf..e3fc165eb71 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1,7 +1,7 @@ +#include #include #include #include -#include #include #include #include @@ -35,8 +35,8 @@ #include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" -#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/dl_convertor.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" @@ -1117,8 +1117,12 @@ void dlPack_Capsule_Destructor(PyObject* data) { } at::Tensor tensor_fromDLPack(PyObject* data) { - DLManagedTensor* dlMTensor = (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); - XLA_CHECK(dlMTensor != nullptr) << "from_dlpack received an invalid capsule. Note that a DLTensor capsule can be consumed only once. You may have already constructed a tensor from it once."; + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + XLA_CHECK(dlMTensor != nullptr) + << "from_dlpack received an invalid capsule. Note that a DLTensor " + "capsule can be consumed only once. 
You may have already constructed " + "a tensor from it once."; at::Tensor tensor = torch_xla::fromDLPack(dlMTensor); PyCapsule_SetName(data, "used_dltensor"); @@ -2543,17 +2547,17 @@ void InitXlaModuleBindings(py::module m) { NoGilSection nogil; dlMTensor = torch_xla::toDLPack(input); } - // return py::reinterpret_steal(PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor)); + // return py::reinterpret_steal(PyCapsule_New(dlMTensor, + // "dltensor", dlPack_Capsule_Destructor)); return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor); }); - // m.def("_to_dlpack", &tensor_toDLPack, ""); // + // m.def("_to_dlpack", &tensor_toDLPack, ""); // // from a dlpack tensor to an XLA tensor m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor { return tensor_fromDLPack(ext_data.ptr()); }); - // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 6eca719c896..b275ef562ee 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -25,11 +25,11 @@ #include "torch_xla/csrc/runtime/types.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" -#include "xla/pjrt/pjrt_client.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal_util.h" -#include "xla/types.h" +#include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/types.h" namespace torch_xla { namespace runtime { @@ -260,9 +260,8 @@ class ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) = 0; - virtual DataPtr CreateData( - std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) = 0; + virtual DataPtr CreateData(std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) = 0; // Returns data shards. We expect this to be called on PjRtShardedData to // retrieve the shards. If other data type is passed, it returns the input @@ -281,7 +280,8 @@ class ComputationClient { // structure will be empty if there is no sharding, like with PjRtData. virtual std::optional GetDataSharding(DataPtr handle) = 0; - virtual std::string PjRtDeviceToString(xla::PjRtDevice* const device) const = 0; + virtual std::string PjRtDeviceToString( + xla::PjRtDevice* const device) const = 0; // Transfers local tensor values to the TPU devices and fetches the handles. virtual std::vector TransferToDevice( @@ -312,7 +312,8 @@ class ComputationClient { virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0; - virtual std::shared_ptr GetPjRtBuffer(const DataPtr handle) = 0; + virtual std::shared_ptr GetPjRtBuffer( + const DataPtr handle) = 0; // Compiles a set of computations. 
virtual std::vector Compile( @@ -356,7 +357,8 @@ class ComputationClient { virtual xla::PjRtPlatformId GetPlatformID() const = 0; - virtual absl::StatusOr LookupAddressableDevice(int local_device_id) const = 0; + virtual absl::StatusOr LookupAddressableDevice( + int local_device_id) const = 0; virtual size_t GetNumDevices() const = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index e059be41a08..e2a72992d6f 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -402,7 +402,8 @@ std::uintptr_t IfrtComputationClient::UnsafeBufferPointer( XLA_ERROR() << __FUNCTION__ << " not implemented"; } -std::shared_ptr IfrtComputationClient::GetPjRtBuffer(const DataPtr handle) { +std::shared_ptr IfrtComputationClient::GetPjRtBuffer( + const DataPtr handle) { XLA_ERROR() << __FUNCTION__ << " not implemented"; } diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index b2185842289..f843a2e53e4 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -33,11 +33,10 @@ class IfrtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; - DataPtr CreateData( - std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; + DataPtr CreateData(std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + }; std::vector GetDataShards(DataPtr data) override; @@ -96,7 +95,8 @@ class IfrtComputationClient : public ComputationClient { return client_->platform_id(); } - absl::StatusOr LookupAddressableDevice(int local_device_id) const override { + absl::StatusOr LookupAddressableDevice( + int local_device_id) const override { XLA_ERROR() << __FUNCTION__ << " not implemented"; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index c7896c09c20..ba3e9baf8c8 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -188,7 +188,8 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder( ComputationClient::DataPtr PjRtComputationClient::CreateData( std::string device, xla::Shape shape, std::shared_ptr pjrt_buffer) { - return std::make_shared(std::move(device), std::move(shape), pjrt_buffer); + return std::make_shared(std::move(device), std::move(shape), + pjrt_buffer); } std::vector PjRtComputationClient::GetDataShards( @@ -469,17 +470,19 @@ std::uintptr_t PjRtComputationClient::UnsafeBufferPointer( std::shared_ptr pjrt_data = std::dynamic_pointer_cast(handle); XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); - XLA_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; + XLA_CHECK(pjrt_data->buffer != nullptr) + << "PjRt buffer is null in " << __FUNCTION__; xla::StatusOr ptr = client_->UnsafeBufferPointer(pjrt_data->buffer.get()); XLA_CHECK(ptr.ok()); return ptr.value(); } -std::shared_ptr PjRtComputationClient::GetPjRtBuffer(const DataPtr handle) { +std::shared_ptr PjRtComputationClient::GetPjRtBuffer( + const DataPtr handle) { std::shared_ptr pjrt_data = std::dynamic_pointer_cast(handle); - XLA_CHECK(pjrt_data) << "handle must be PjRtData, 
got " << handle->ToString(); + XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); return pjrt_data->buffer; } @@ -498,7 +501,8 @@ std::vector PjRtComputationClient::TransferFromDevice( // is not sharded, then it is a no-op. std::shared_ptr pjrt_data = ReplicateShardedData(handle); XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; - XLA_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; + XLA_CHECK(pjrt_data->buffer != nullptr) + << "PjRt buffer is null in " << __FUNCTION__; xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); @@ -508,7 +512,8 @@ std::vector PjRtComputationClient::TransferFromDevice( } for (auto& future : futures) { absl::Status status = future.Await(); - XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" << __FUNCTION__; + XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" + << __FUNCTION__; } InboundDataMetric()->AddSample(total_size); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 1b1e0aa5c47..aff4b781f99 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -32,9 +32,8 @@ class PjRtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; - DataPtr CreateData( - std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) override; + DataPtr CreateData(std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer) override; std::vector GetDataShards(DataPtr data) override; @@ -99,8 +98,10 @@ class PjRtComputationClient : public ComputationClient { return client_->platform_id(); } - absl::StatusOr LookupAddressableDevice(int local_device_id) const override { - return client_->LookupAddressableDevice(xla::PjRtLocalDeviceId(local_device_id)); + absl::StatusOr LookupAddressableDevice( + int local_device_id) const override { + return client_->LookupAddressableDevice( + xla::PjRtLocalDeviceId(local_device_id)); } std::vector GetLocalDevices() const override; diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py index 9ae99b8f802..9f93d532b27 100644 --- a/torch_xla/utils/dlpack.py +++ b/torch_xla/utils/dlpack.py @@ -1,8 +1,10 @@ from typing import Any import torch_xla + def to_dlpack(xla_tensor: Any): return torch_xla._XLAC._to_dlpack(xla_tensor) + def from_dlpack(ext_tensor: Any): return torch_xla._XLAC._from_dlpack(ext_tensor) From 93dba65821d142b9bd68c95eba12c32e6b313e6f Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Mon, 13 May 2024 20:27:10 +0000 Subject: [PATCH 15/19] fix comment --- test/test_operations.py | 10 ++--- torch_xla/csrc/dl_convertor.cpp | 37 ++++--------------- .../csrc/runtime/pjrt_computation_client.cc | 11 +++++- torch_xla/csrc/tensor_util.cpp | 22 +++++++++++ torch_xla/csrc/tensor_util.h | 3 ++ 5 files changed, 46 insertions(+), 37 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index b55cadcbcd7..0e2c603204b 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2456,17 +2456,17 @@ class TestDLPack(parameterized.TestCase): def _test_dlpack_capsule_conversion_helper(self, xla_tensor): dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule - got = xdlpack.from_dlpack(dlpt) + xla_tensor2 = xdlpack.from_dlpack(dlpt) - self.assertEqual(xla_tensor.device, got.device) - 
self.assertTrue(torch.allclose(xla_tensor.cpu(), got.cpu())) + self.assertEqual(xla_tensor.device, xla_tensor2.device) + self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu())) self.assertRaisesRegex(RuntimeError, "DLTensor capsule can be consumed only once", lambda: xdlpack.from_dlpack(dlpt)) self.assertEqual( torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor), - torch_xla._XLAC._unsafe_buffer_pointer(got)) + torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2)) @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA @@ -2499,8 +2499,6 @@ def test_dlpack_roundtrip(self, dtype): xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device()) xm.mark_step() - # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr. - xm.wait_device_ops() self._test_dlpack_capsule_conversion_helper(xla_tensor_3) @onlyIfTorchSupportsCUDA diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index 3c0371efcf6..5790b7c2112 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -19,47 +19,25 @@ namespace torch_xla { -std::shared_ptr get_data_handle( - const at::Tensor& input) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - XLA_CHECK(xtensor) << "The input has to be an XLA tensor."; - if (xtensor->CurrentDataHandle() != nullptr) { - TF_VLOG(4) << "The xla tensor has a current data handle."; - return std::dynamic_pointer_cast( - xtensor->CurrentDataHandle()); - } else if (xtensor->CurrentIrValue().node != nullptr) { - DeviceData* device_data = - DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (device_data != nullptr) { - return UnwrapXlaData(device_data->data()); - } - TF_VLOG(4) << "The xla tensor has IR value but does not have device data."; - } - TF_VLOG(4) - << "The xla tensor either has no current data handle or has no IR value."; - return nullptr; -} - -struct TorchXLADLMTensor { - ~TorchXLADLMTensor(); +struct DLPackTensor { + ~DLPackTensor(); std::unique_ptr external_reference; std::shared_ptr buffer_reference; - // at::Tensor source_tensor; std::vector shape; std::vector strides; DLManagedTensor tensor; }; -TorchXLADLMTensor::~TorchXLADLMTensor() { +DLPackTensor::~DLPackTensor() { if (external_reference) { external_reference.reset(nullptr); } } -void TorchXLADLMTensorDeleter(DLManagedTensor* t) { +void DLPackTensorDeleter(DLManagedTensor* t) { if (t) { - delete static_cast(t->manager_ctx); + delete static_cast(t->manager_ctx); } } @@ -151,7 +129,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { XLA_ERROR() << "Unimplemented. 
DynamicShape is not implemented in DLPack."; } - auto pack = std::make_unique(); + auto pack = std::make_unique(); DLTensor& dt = pack->tensor.dl_tensor; { // AcquireExternalReference may block @@ -163,11 +141,10 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { XLA_CHECK_OK(status); } pack->buffer_reference = pjrt_buffer; - // pack->source_tensor = input; dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); pack->tensor.manager_ctx = pack.get(); - pack->tensor.deleter = TorchXLADLMTensorDeleter; + pack->tensor.deleter = DLPackTensorDeleter; dt.device = DLDeviceForDevice(*pjrt_buffer->device()); dt.device.device_id = pjrt_buffer->device()->local_hardware_id(); dt.ndim = pjrt_buffer->dimensions().size(); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index ba3e9baf8c8..d46f2712166 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -482,8 +482,17 @@ std::shared_ptr PjRtComputationClient::GetPjRtBuffer( const DataPtr handle) { std::shared_ptr pjrt_data = std::dynamic_pointer_cast(handle); + XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); - return pjrt_data->buffer; + std::shared_ptr pjrt_buffer = pjrt_data->buffer; + if (pjrt_buffer != nullptr) { + return pjrt_buffer; + } else { + TF_VLOG(3) << "The pjrt buffer is null so we need to wait for device ops " + "to finish."; + WaitDeviceOps({}); + return std::dynamic_pointer_cast(handle)->buffer; + } } std::vector PjRtComputationClient::TransferFromDevice( diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index dd13bd63d1b..870f6945973 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -17,6 +17,7 @@ #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" +#include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" @@ -931,4 +932,25 @@ xla::PrimitiveType GetShapeDimensionType( return xla::PrimitiveType::S32; } +std::shared_ptr get_data_handle( + const at::Tensor& input) { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLA_CHECK(xtensor) << "The input has to be an XLA tensor."; + if (xtensor->CurrentDataHandle() != nullptr) { + TF_VLOG(4) << "The xla tensor has a current data handle."; + return std::dynamic_pointer_cast( + xtensor->CurrentDataHandle()); + } else if (xtensor->CurrentIrValue().node != nullptr) { + DeviceData* device_data = + DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (device_data != nullptr) { + return UnwrapXlaData(device_data->data()); + } + TF_VLOG(4) << "The xla tensor has IR value but does not have device data."; + } + TF_VLOG(4) + << "The xla tensor either has no current data handle or has no IR value."; + return nullptr; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index 7d726c00b50..0804d3e9f78 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -212,6 +212,9 @@ inline std::vector xla_expand_outplace(at::TensorList to_expand) { } } +std::shared_ptr get_data_handle( + const at::Tensor& input); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_TENSOR_UTIL_H_ From 867dd53f22d5fe8dfafa30679e0793cddbaea23a Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Tue, 14 May 2024 20:15:36 
From 867dd53f22d5fe8dfafa30679e0793cddbaea23a Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Tue, 14 May 2024 20:15:36 +0000
Subject: [PATCH 16/19] fix comments
---
 test/test_operations.py | 14 +++++++++++
 torch_xla/csrc/dl_convertor.cpp | 23 ++++++++-----------
 torch_xla/csrc/runtime/computation_client.h | 3 ---
 .../csrc/runtime/ifrt_computation_client.h | 5 ----
 .../csrc/runtime/pjrt_computation_client.h | 4 ++--
 torch_xla/csrc/tensor_util.cpp | 1 -
 6 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index 0e2c603204b..eb2ccb1332d 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2513,12 +2513,16 @@ def test_dlpack_pytorch_cuda_to_xla(self):
 t1_cuda = torch.arange(5).cuda()
 dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda)
 xla_t1 = xdlpack.from_dlpack(dlt1)
+ self.assertEqual(xla_t1.device.type, 'xla')
+ self.assertEqual(xla_t1.device.index, t1_cuda.device.index)
 t1_cuda[0] = t1_cuda[0] + 20
 self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu()))
 t2_cuda = torch.tensor(5).cuda()
 dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda)
 xla_t2 = xdlpack.from_dlpack(dlt2)
+ self.assertEqual(xla_t2.device.type, 'xla')
+ self.assertEqual(xla_t2.device.index, t2_cuda.device.index)
 t2_cuda.fill_(6)
 self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))
@@ -2528,6 +2532,8 @@ def test_dlpack_xla_to_pytorch_cuda(self):
 xla_t1 = torch.arange(5).to(xm.xla_device())
 dlt1 = xdlpack.to_dlpack(xla_t1)
 cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1)
+ self.assertEqual(cuda_t1.device.type, 'cuda')
+ self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
 cuda_t1[0] = cuda_t1[0] + 20
 self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
@@ -2538,10 +2544,14 @@ def test_dlpack_non_default_layout(self):
 t1 = cuda_t.t()
 xla_t1 = xdlpack.from_dlpack(t1.__dlpack__())
+ self.assertEqual(xla_t1.device.type, 'xla')
+ self.assertEqual(xla_t1.device.index, 0)
 self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu()))
 t2 = cuda_t[0]
 xla_t2 = xdlpack.from_dlpack(t2.__dlpack__())
+ self.assertEqual(xla_t2.device.type, 'xla')
+ self.assertEqual(xla_t2.device.index, 0)
 self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu()))
 t3 = cuda_t[:, 0]
@@ -2552,10 +2562,14 @@ def test_dlpack_non_default_layout(self):
 t4 = cuda_t[1, :]
 xla_t4 = xdlpack.from_dlpack(t4.__dlpack__())
+ self.assertEqual(xla_t4.device.type, 'xla')
+ self.assertEqual(xla_t4.device.index, 0)
 self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu()))
 t5 = cuda_t[1]
 xla_t5 = xdlpack.from_dlpack(t5.__dlpack__())
+ self.assertEqual(xla_t5.device.type, 'xla')
+ self.assertEqual(xla_t5.device.index, 0)
 self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu()))
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index 5790b7c2112..3742f5af397 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -6,6 +6,7 @@
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
@@ -121,13 +122,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
 runtime::GetComputationClient()->GetPjRtBuffer(handle);
 XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
- if (pjrt_buffer->IsTuple()) {
- XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not "
- "implemented for tuple buffers.";
- }
- if (pjrt_buffer->has_dynamic_dimensions()) {
- XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
- }
+ XLA_CHECK(!pjrt_buffer->IsTuple()) << "Unimplemented. BufferToDLPackManagedTensor is not " "implemented for tuple buffers.";
+ XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions()) << "Unimplemented. DynamicShape is not implemented in DLPack.";
 auto pack = std::make_unique();
 DLTensor& dt = pack->tensor.dl_tensor;
@@ -180,6 +177,7 @@ absl::StatusOr DeviceForDLDevice(const DLDevice& context) {
 }
 }
+// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
 absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) {
 if (type.lanes != 1) {
 return tsl::errors::Unimplemented(
@@ -294,11 +292,8 @@ absl::StatusOr> StridesToLayout(
 }
 at::Tensor fromDLPack(DLManagedTensor* dlmt) {
- if (dlmt->dl_tensor.ndim < 0) {
- XLA_ERROR()
- << "Number of dimensions in DLManagedTensor must be nonnegative, got "
+ XLA_CHECK(dlmt->dl_tensor.ndim >= 0) << "Number of dimensions in DLManagedTensor must be nonnegative, got "
 << dlmt->dl_tensor.ndim;
- }
 xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
 absl::Span dimensions(
 const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
@@ -328,12 +323,12 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
 dlmt->dl_tensor.byte_offset, shape, device, on_delete_callback);
 XLA_CHECK_OK(pjrt_buffer.status())
- << "Failed to create a pjrt buffer in " << __FUNCTION__;
+ << "Failed to create a pjrt buffer.";
 XLA_CHECK(pjrt_buffer.value() != nullptr)
- << "pjrt buffer is null in " << __FUNCTION__;
+ << "pjrt buffer is null.";
 runtime::ComputationClient::DataPtr data =
- runtime::GetComputationClient()->CreateData(
+ runtime::PjRtComputationClient::CreateData(
 runtime::GetComputationClient()->PjRtDeviceToString(device), shape,
 std::move(pjrt_buffer.value()));
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index b275ef562ee..0368962d411 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -260,9 +260,6 @@ class ComputationClient {
 std::string device, xla::Shape shape,
 std::optional sharding = std::nullopt) = 0;
- virtual DataPtr CreateData(std::string device, xla::Shape shape,
- std::shared_ptr pjrt_buffer) = 0;
-
 // Returns data shards. We expect this to be called on PjRtShardedData to
 // retrieve the shards. If other data type is passed, it returns the input
 // wrapped inside a vector.
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index f843a2e53e4..ca40c8fb02c 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -33,11 +33,6 @@ class IfrtComputationClient : public ComputationClient {
 std::string device, xla::Shape shape,
 std::optional sharding = std::nullopt) override;
- DataPtr CreateData(std::string device, xla::Shape shape,
- std::shared_ptr pjrt_buffer) override {
- XLA_ERROR() << __FUNCTION__ << " not implemented";
- };
-
 std::vector GetDataShards(DataPtr data) override;
 DataPtr GetDataShard(DataPtr data, size_t index) override;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index aff4b781f99..8e36bff5e99 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -32,8 +32,8 @@ class PjRtComputationClient : public ComputationClient {
 std::string device, xla::Shape shape,
 std::optional sharding = std::nullopt) override;
- DataPtr CreateData(std::string device, xla::Shape shape,
- std::shared_ptr pjrt_buffer) override;
+ static DataPtr CreateData(std::string device, xla::Shape shape,
+ std::shared_ptr pjrt_buffer);
 std::vector GetDataShards(DataPtr data) override;
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 870f6945973..8822b6de7c4 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -935,7 +935,6 @@ xla::PrimitiveType GetShapeDimensionType(
 std::shared_ptr get_data_handle(
 const at::Tensor& input) {
 XLATensorPtr xtensor = bridge::GetXlaTensor(input);
- XLA_CHECK(xtensor) << "The input has to be an XLA tensor.";
 if (xtensor->CurrentDataHandle() != nullptr) {
 TF_VLOG(4) << "The xla tensor has a current data handle.";
 return std::dynamic_pointer_cast(
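
The conversion helper exercised throughout these tests pins down two invariants: a DLPack capsule is single-use, and a same-device round trip aliases the underlying PJRT buffer instead of copying it. A short sketch of both, under the same `xdlpack` assumption as above:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack

t = torch.arange(5, device=xm.xla_device())
xm.mark_step()
cap = xdlpack.to_dlpack(t)
t2 = xdlpack.from_dlpack(cap)  # consumes the capsule
# A second xdlpack.from_dlpack(cap) would raise:
#   RuntimeError: DLTensor capsule can be consumed only once
assert (torch_xla._XLAC._unsafe_buffer_pointer(t) ==
        torch_xla._XLAC._unsafe_buffer_pointer(t2))  # shared device buffer
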
From 345945eb3618245db5b977bcd56cc13865745c27 Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Thu, 16 May 2024 19:59:53 +0000
Subject: [PATCH 17/19] fix comment
---
 test/test_operations.py | 47 ++++++++++++++++++++++++++---------------
 1 file changed, 30 insertions(+), 17 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index eb2ccb1332d..60e42c7a7a2 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -30,7 +30,10 @@
 import torch.nn.functional as F
 import torch.optim as optim
 from torch.testing._internal.common_device_type import dtypes
-from torch.testing._internal.common_dtype import (all_types_and_complex_and)
+from torch.testing._internal.common_dtype import (
+ all_types_and_complex_and,
+ all_types_and,
+ )
 import torch_xla
 import torch_xla.core.xla_builder as xb
 import torch_xla.core.xla_op_registry as xor
@@ -2468,6 +2471,21 @@ def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
 torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
 torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2))
+ @onlyIfTorchSupportsCUDA
+ @onlyIfPJRTDeviceIsCUDA
+ @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16))
+ def test_dlpack_roundtrip_tensor(self, dtype):
+ xla_device = xm.xla_device()
+ # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
+ # xla_tensor_2 uses XLANativeFunctions::_to_copy
+ xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device)
+ self._test_dlpack_capsule_conversion_helper(xla_tensor_2)
+
+ # xla_tensor_3 uses arange_out IR node.
+ xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
+ xm.mark_step()
+ self._test_dlpack_capsule_conversion_helper(xla_tensor_3)
+
 @onlyIfTorchSupportsCUDA
 @onlyIfPJRTDeviceIsCUDA
 @parameterized.parameters(*all_types_and_complex_and(torch.half,
@@ -2475,14 +2493,7 @@ def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
 torch.bool, torch.uint16, torch.uint32,
 torch.uint64))
- def test_dlpack_roundtrip(self, dtype):
- # "arange_cpu" not implemented for complex64 and complex128.
- # xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) failed with `RuntimeError: false INTERNAL ASSERT FAILED at "/ansible/pytorch/torch/csrc/lazy/core/hash.h":139, please report a bug to PyTorch. Unsupported scalar type:UInt64`, similar to other uint.
- if dtype in {
- torch.complex128, torch.complex64, torch.uint64, torch.uint32,
- torch.uint16, torch.bool
- }:
- return
+ def test_dlpack_roundtrip_scalar(self, dtype):
 xla_device = xm.xla_device()
 xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
 # `mark_step` ensures xtensor->CurrentDataHandle() != nullptr
@@ -2493,14 +2504,6 @@ def test_dlpack_roundtrip(self, dtype):
 xla_device = xm.xla_device()
 # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
 self._test_dlpack_capsule_conversion_helper(xla_tensor_1)
- # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
- xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device)
- self._test_dlpack_capsule_conversion_helper(xla_tensor_2)
-
- xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
- xm.mark_step()
- self._test_dlpack_capsule_conversion_helper(xla_tensor_3)
-
 @onlyIfTorchSupportsCUDA
 @onlyIfPJRTDeviceIsCUDA
 def test_dlpack_roundtrip_bool(self):
@@ -2526,6 +2529,16 @@ def test_dlpack_pytorch_cuda_to_xla(self):
 t2_cuda.fill_(6)
 self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))
+ cuda1 = torch.device('cuda:1')
+ t3_cuda = torch.tensor(5, device=cuda1)
+ dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
+ xla_t3 = xdlpack.from_dlpack(dlt3)
+ self.assertEqual(xla_t3.device.type, 'xla')
+ self.assertEqual(xla_t3.device.index, t3_cuda.device.index, msg='both values should be 1; xla_t3.device should be xla:1.')
+ t3_cuda.fill_(6)
+ self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))
+
+
 @onlyIfTorchSupportsCUDA
 @onlyIfPJRTDeviceIsCUDA
 def test_dlpack_xla_to_pytorch_cuda(self):
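
The cuda:1 case added in this patch pins down the device-index mapping: a tensor imported from cuda:N lands on xla:N and shares storage with its producer, so in-place CUDA writes stay visible through the XLA view. Sketch (assuming a machine with at least two CUDA devices, as the test does):

import torch
import torch_xla.utils.dlpack as xdlpack

t_cuda = torch.tensor(5, device='cuda:1')
xla_t = xdlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t_cuda))
assert xla_t.device.type == 'xla'
assert xla_t.device.index == t_cuda.device.index  # both 1
t_cuda.fill_(6)  # in-place write on the CUDA side...
assert torch.allclose(xla_t.cpu(), t_cuda.cpu())  # ...is visible via XLA
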
From 2e8ff4ae45a5c54b72465391c10f5be3bcc66ebd Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Thu, 16 May 2024 20:52:52 +0000
Subject: [PATCH 18/19] linter
---
 test/test_operations.py | 12 ++++++-----
 torch_xla/csrc/dl_convertor.cpp | 21 ++++++++++---------
 .../csrc/runtime/pjrt_computation_client.h | 2 +-
 3 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index 60e42c7a7a2..9be9e6c90cb 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -31,9 +31,9 @@
 import torch.optim as optim
 from torch.testing._internal.common_device_type import dtypes
 from torch.testing._internal.common_dtype import (
- all_types_and_complex_and,
- all_types_and,
- )
+ all_types_and_complex_and,
+ all_types_and,
+)
 import torch_xla
 import torch_xla.core.xla_builder as xb
 import torch_xla.core.xla_op_registry as xor
@@ -2534,11 +2534,13 @@ def test_dlpack_pytorch_cuda_to_xla(self):
 dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
 xla_t3 = xdlpack.from_dlpack(dlt3)
 self.assertEqual(xla_t3.device.type, 'xla')
- self.assertEqual(xla_t3.device.index, t3_cuda.device.index, msg='both values should be 1; xla_t3.device should be xla:1.')
+ self.assertEqual(
+ xla_t3.device.index,
+ t3_cuda.device.index,
+ msg='both values should be 1; xla_t3.device should be xla:1.')
 t3_cuda.fill_(6)
 self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))
-
 @onlyIfTorchSupportsCUDA
 @onlyIfPJRTDeviceIsCUDA
 def test_dlpack_xla_to_pytorch_cuda(self):
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index 3742f5af397..aa7c2ae6a25 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -6,8 +6,8 @@
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
-#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/tensor.h"
@@ -122,9 +122,11 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
 runtime::GetComputationClient()->GetPjRtBuffer(handle);
 XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
- XLA_CHECK(!pjrt_buffer->IsTuple()) << "Unimplemented. BufferToDLPackManagedTensor is not " "implemented for tuple buffers.";
- XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions()) << "Unimplemented. DynamicShape is not implemented in DLPack.";
+ XLA_CHECK(!pjrt_buffer->IsTuple())
+ << "Unimplemented. BufferToDLPackManagedTensor is not "
+ "implemented for tuple buffers.";
+ XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions())
+ << "Unimplemented. DynamicShape is not implemented in DLPack.";
 auto pack = std::make_unique();
 DLTensor& dt = pack->tensor.dl_tensor;
@@ -292,8 +294,9 @@ absl::StatusOr> StridesToLayout(
 }
 at::Tensor fromDLPack(DLManagedTensor* dlmt) {
- XLA_CHECK(dlmt->dl_tensor.ndim >= 0) << "Number of dimensions in DLManagedTensor must be nonnegative, got "
- << dlmt->dl_tensor.ndim;
+ XLA_CHECK(dlmt->dl_tensor.ndim >= 0)
+ << "Number of dimensions in DLManagedTensor must be nonnegative, got "
+ << dlmt->dl_tensor.ndim;
 xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
 absl::Span dimensions(
 const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
@@ -322,10 +325,8 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
 static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset,
 shape, device, on_delete_callback);
- XLA_CHECK_OK(pjrt_buffer.status())
- << "Failed to create a pjrt buffer.";
- XLA_CHECK(pjrt_buffer.value() != nullptr)
- << "pjrt buffer is null.";
+ XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer.";
+ XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";
 runtime::ComputationClient::DataPtr data =
 runtime::PjRtComputationClient::CreateData(
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 8e36bff5e99..7b54e68ff3f 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -33,7 +33,7 @@ class PjRtComputationClient : public ComputationClient {
 std::optional sharding = std::nullopt) override;
 static DataPtr CreateData(std::string device, xla::Shape shape,
- std::shared_ptr pjrt_buffer);
+ std::shared_ptr pjrt_buffer);
 std::vector GetDataShards(DataPtr data) override;
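
fromDLPack derives an XLA layout from the incoming DLPack strides (StridesToLayout above), so non-contiguous CUDA views can be imported without a contiguous copy. A minimal sketch mirroring test_dlpack_non_default_layout:

import torch
import torch_xla.utils.dlpack as xdlpack

cuda_t = torch.arange(25, device='cuda').reshape(5, 5)
# transpose, row slices and column slices all carry non-default strides
for view in (cuda_t.t(), cuda_t[0], cuda_t[:, 0], cuda_t[1, :]):
  xla_view = xdlpack.from_dlpack(view.__dlpack__())
  assert torch.allclose(view.cpu(), xla_view.cpu())
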
DynamicShape is not implemented in DLPack."; auto pack = std::make_unique(); DLTensor& dt = pack->tensor.dl_tensor; @@ -292,8 +294,9 @@ absl::StatusOr> StridesToLayout( } at::Tensor fromDLPack(DLManagedTensor* dlmt) { - XLA_CHECK(dlmt->dl_tensor.ndim >= 0) << "Number of dimensions in DLManagedTensor must be nonnegative, got " - << dlmt->dl_tensor.ndim; + XLA_CHECK(dlmt->dl_tensor.ndim >= 0) + << "Number of dimensions in DLManagedTensor must be nonnegative, got " + << dlmt->dl_tensor.ndim; xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); absl::Span dimensions( const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); @@ -322,10 +325,8 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset, shape, device, on_delete_callback); - XLA_CHECK_OK(pjrt_buffer.status()) - << "Failed to create a pjrt buffer."; - XLA_CHECK(pjrt_buffer.value() != nullptr) - << "pjrt buffer is null."; + XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer."; + XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null."; runtime::ComputationClient::DataPtr data = runtime::PjRtComputationClient::CreateData( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 8e36bff5e99..7b54e68ff3f 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -33,7 +33,7 @@ class PjRtComputationClient : public ComputationClient { std::optional sharding = std::nullopt) override; static DataPtr CreateData(std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer); + std::shared_ptr pjrt_buffer); std::vector GetDataShards(DataPtr data) override; From 98cc7398a90d2400b6be86ce62d96c53beaabc48 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Tue, 21 May 2024 16:35:57 +0000 Subject: [PATCH 19/19] fix comments and run linter --- torch_xla/csrc/dl_convertor.cpp | 4 ++++ torch_xla/csrc/init_python_bindings.cpp | 11 ++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index aa7c2ae6a25..d29401be8fe 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -52,6 +52,7 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { << " cannot be used as a DLPack device."; } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { DLDevice dlDevice; dlDevice.device_type = DLDeviceTypeForDevice(device); @@ -59,6 +60,7 @@ DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { return dlDevice; } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) { switch (type) { case xla::PrimitiveType::S8: @@ -161,6 +163,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { return &(pack.release()->tensor); } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc absl::StatusOr DeviceForDLDevice(const DLDevice& context) { switch (context.device_type) { case DLDeviceType::kDLCPU: @@ -262,6 +265,7 @@ absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { } } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc absl::StatusOr> StridesToLayout( absl::Span dims, absl::Span strides) { XLA_CHECK_EQ(dims.size(), strides.size()); diff --git a/torch_xla/csrc/init_python_bindings.cpp 
index e3fc165eb71..d59f73b6662 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2541,19 +2541,24 @@ void InitXlaModuleBindings(py::module m) {
 });
 // from an XLA tensor to a dlpack tensor.
+ // If the input tensor is the result of a CUDA computation, we should
+ // synchronize (wait for all kernels in all streams on the CUDA device to
+ // complete) if the current stream is different from the input's stream.
+ // Otherwise, we may get incorrect results.
 m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle {
 DLManagedTensor* dlMTensor;
 {
 NoGilSection nogil;
 dlMTensor = torch_xla::toDLPack(input);
 }
- // return py::reinterpret_steal(PyCapsule_New(dlMTensor,
- // "dltensor", dlPack_Capsule_Destructor));
 return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
 });
- // m.def("_to_dlpack", &tensor_toDLPack, "");
 //
 // from a dlpack tensor to an XLA tensor
+ // If ext_data is the result of a CUDA computation, we should synchronize
+ // (wait for all kernels in all streams on the CUDA device to complete) if
+ // the current stream is different from ext_data's stream. Otherwise, we
+ // may get incorrect results.
 m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor {
 return tensor_fromDLPack(ext_data.ptr());
 });
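
The comments added in this last patch leave stream synchronization to the caller: `_from_dlpack` does not wait on the producer's stream. A hedged sketch of the caller-side discipline they describe, for a tensor produced on a non-default CUDA stream:

import torch
import torch_xla.utils.dlpack as xdlpack

s = torch.cuda.Stream()
with torch.cuda.stream(s):
  t = torch.arange(5, device='cuda') * 2  # produced on stream s
# Make the consuming stream wait (or call torch.cuda.synchronize()) before
# importing; otherwise the XLA side may read stale device memory.
torch.cuda.current_stream().wait_stream(s)
xla_t = xdlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
assert torch.allclose(xla_t.cpu(), t.cpu())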