fix comments and run linter
vanbasten23 committed May 21, 2024
1 parent 2e8ff4a commit 98cc739
Showing 2 changed files with 12 additions and 3 deletions.
4 changes: 4 additions & 0 deletions torch_xla/csrc/dl_convertor.cpp
@@ -52,13 +52,15 @@ DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) {
<< " cannot be used as a DLPack device.";
}

// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
DLDevice DLDeviceForDevice(const xla::PjRtDevice& device) {
DLDevice dlDevice;
dlDevice.device_type = DLDeviceTypeForDevice(device);
dlDevice.device_id = device.local_hardware_id();
return dlDevice;
}

// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
DLDataType PrimitiveTypeToDLDataType(xla::PrimitiveType type) {
switch (type) {
case xla::PrimitiveType::S8:
@@ -161,6 +163,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
return &(pack.release()->tensor);
}

// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
switch (context.device_type) {
case DLDeviceType::kDLCPU:
@@ -262,6 +265,7 @@ absl::StatusOr<xla::PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
}
}

// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
absl::StatusOr<std::vector<int64_t>> StridesToLayout(
absl::Span<int64_t const> dims, absl::Span<int64_t const> strides) {
XLA_CHECK_EQ(dims.size(), strides.size());
11 changes: 8 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2541,19 +2541,24 @@ void InitXlaModuleBindings(py::module m) {
});

// From an XLA tensor to a DLPack tensor.
// If ext_data is the result of a CUDA computation, we should synchronize
// (i.e., wait for all kernels in all streams on the CUDA device to complete)
// if the current stream is different from ext_data's stream. Otherwise, we
// risk getting incorrect results.
m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle {
DLManagedTensor* dlMTensor;
{
NoGilSection nogil;
dlMTensor = torch_xla::toDLPack(input);
}
// return py::reinterpret_steal<py::object>(PyCapsule_New(dlMTensor,
// "dltensor", dlPack_Capsule_Destructor));
return PyCapsule_New(dlMTensor, "dltensor", dlPack_Capsule_Destructor);
});
// m.def("_to_dlpack", &tensor_toDLPack, ""); //

// From a DLPack tensor to an XLA tensor.
// If ext_data is the result of a CUDA computation, we should synchronize
// (i.e., wait for all kernels in all streams on the CUDA device to complete)
// if the current stream is different from ext_data's stream. Otherwise, we
// risk getting incorrect results.
m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor {
return tensor_fromDLPack(ext_data.ptr());
});
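
The synchronization requirement spelled out in the two comments above can be illustrated from the Python side. The sketch below is not part of this commit: it assumes a CUDA-enabled torch/torch_xla build running with PJRT_DEVICE=CUDA, calls the raw torch_xla._XLAC._to_dlpack/_from_dlpack bindings registered here, and uses xm.mark_step() to materialize the lazy XLA tensor before export; a higher-level wrapper, if available, would normally be used instead.

# Minimal usage sketch (assumptions noted above; not from this commit).
import torch
import torch.utils.dlpack
import torch_xla
import torch_xla.core.xla_model as xm

# CUDA tensor -> XLA tensor via DLPack.
cuda_t = torch.arange(8, dtype=torch.float32, device="cuda") * 2.0
# Per the comment above: wait for the kernels that produced cuda_t to finish
# before the capsule is consumed, possibly on a different stream.
torch.cuda.synchronize()
capsule = torch.utils.dlpack.to_dlpack(cuda_t)
xla_t = torch_xla._XLAC._from_dlpack(capsule)

# XLA tensor -> CUDA tensor via DLPack.
y = xla_t + 1.0
xm.mark_step()  # materialize the lazy tensor so a device buffer backs it
out_capsule = torch_xla._XLAC._to_dlpack(y)
back_on_cuda = torch.utils.dlpack.from_dlpack(out_capsule)

A device-wide torch.cuda.synchronize() is the conservative reading of the comment; waiting only on the producing stream would also satisfy it.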
