diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index aa7c2ae6a25..d29401be8fe 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -52,6 +52,7 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { << " cannot be used as a DLPack device."; } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { DLDevice dlDevice; dlDevice.device_type = DLDeviceTypeForDevice(device); @@ -59,6 +60,7 @@ DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) { return dlDevice; } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) { switch (type) { case xla::PrimitiveType::S8: @@ -161,6 +163,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { return &(pack.release()->tensor); } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc absl::StatusOr DeviceForDLDevice(const DLDevice& context) { switch (context.device_type) { case DLDeviceType::kDLCPU: @@ -262,6 +265,7 @@ absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { } } +// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc absl::StatusOr> StridesToLayout( absl::Span dims, absl::Span strides) { XLA_CHECK_EQ(dims.size(), strides.size()); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e3fc165eb71..d59f73b6662 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2541,19 +2541,24 @@ void InitXlaModuleBindings(py::module m) { }); // from an XLA tensor to a dlpack tensor. + // If ext_data is the result of a CUDA computation, we should synchronize + // (waits for all kernels in all streams on a CUDA device to complete) if the + // current stream is different from the ext_data's stream. 
Otherwise, we may + // risk of getting incorrect results. m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle { DLManagedTensor* dlMTensor; { NoGilSection nogil; dlMTensor = torch_xla::toDLPack(input); } - // return py::reinterpret_steal(PyCapsule_New(dlMTensor, - // "dltensor", dlPack_Capsule_Destructor)); return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor); }); - // m.def("_to_dlpack", &tensor_toDLPack, ""); // // from a dlpack tensor to an XLA tensor + // If ext_data is the result of an CUDA computation, we should synchronize + // (waits for all kernels in all streams on a CUDA device to complete) if the + // current stream is different from the ext_data's stream. Otherwise, we may + // risk of getting incorrect results. m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor { return tensor_fromDLPack(ext_data.ptr()); });