Commit
linter
vanbasten23 committed May 16, 2024
1 parent 70ebbba commit 8375645
Showing 3 changed files with 19 additions and 16 deletions.
12 changes: 7 additions & 5 deletions test/test_operations.py
@@ -31,9 +31,9 @@
import torch.optim as optim
from torch.testing._internal.common_device_type import dtypes
from torch.testing._internal.common_dtype import (
- all_types_and_complex_and,
- all_types_and,
- )
+ all_types_and_complex_and,
+ all_types_and,
+ )
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
@@ -2527,11 +2527,13 @@ def test_dlpack_pytorch_cuda_to_xla(self):
dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
xla_t3 = xdlpack.from_dlpack(dlt3)
self.assertEqual(xla_t3.device.type, 'xla')
- self.assertEqual(xla_t3.device.index, t3_cuda.device.index, msg='both value should 1. xla_t3.device should be xla:1.')
+ self.assertEqual(
+     xla_t3.device.index,
+     t3_cuda.device.index,
+     msg='both value should 1. xla_t3.device should be xla:1.')
t3_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))


@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_xla_to_pytorch_cuda(self):
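For reference, here is a minimal Python sketch of the CUDA-to-XLA round trip that `test_dlpack_pytorch_cuda_to_xla` above exercises. It is a sketch only: the import path used for the test's `xdlpack` alias is an assumption, and it assumes PJRT is configured with the CUDA device.

```python
# Sketch only: the `xdlpack` import path is assumed; PJRT must be using the CUDA device.
import torch
import torch.utils.dlpack
import torch_xla.utils.dlpack as xdlpack  # assumed path for the test's `xdlpack` alias

t_cuda = torch.arange(5, dtype=torch.float32, device='cuda')  # source tensor on CUDA
capsule = torch.utils.dlpack.to_dlpack(t_cuda)                # export as a DLPack capsule
t_xla = xdlpack.from_dlpack(capsule)                          # import it on the XLA side

assert t_xla.device.type == 'xla'                 # as the test asserts
t_cuda.fill_(6)                                   # mutate the original CUDA tensor...
assert torch.allclose(t_xla.cpu(), t_cuda.cpu())  # ...the XLA tensor aliases the same device
                                                  # memory, which the test's fill_(6) check verifies
```
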
21 changes: 11 additions & 10 deletions torch_xla/csrc/dl_convertor.cpp
@@ -6,8 +6,8 @@
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/tensor.h"
@@ -122,9 +122,11 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
runtime::GetComputationClient()->GetPjRtBuffer(handle);
XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";

- XLA_CHECK(!pjrt_buffer->IsTuple()) << "Unimplemented. BufferToDLPackManagedTensor is not "
-     "implemented for tuple buffers.";
- XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions()) << "Unimplemented. DynamicShape is not implemented in DLPack.";
+ XLA_CHECK(!pjrt_buffer->IsTuple())
+     << "Unimplemented. BufferToDLPackManagedTensor is not "
+        "implemented for tuple buffers.";
+ XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions())
+     << "Unimplemented. DynamicShape is not implemented in DLPack.";

auto pack = std::make_unique<DLPackTensor>();
DLTensor& dt = pack->tensor.dl_tensor;
@@ -292,8 +294,8 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(
}

at::Tensor fromDLPack(DLManagedTensor* dlmt) {
- XLA_CHECK(dlmt->dl_tensor.ndim >= 0) << "Number of dimensions in DLManagedTensor must be nonnegative, got "
-     << dlmt->dl_tensor.ndim;
+ XLA_CHECK(dlmt->dl_tensor.ndim >= 0)
+     << "Number of dimensions in DLManagedTensor must be nonnegative, got "
+     << dlmt->dl_tensor.ndim;
xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
absl::Span<int64_t const> dimensions(
const_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
@@ -322,10 +325,8 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
static_cast<char*>(dlmt->dl_tensor.data) +
dlmt->dl_tensor.byte_offset,
shape, device, on_delete_callback);
- XLA_CHECK_OK(pjrt_buffer.status())
-     << "Failed to create a pjrt buffer.";
- XLA_CHECK(pjrt_buffer.value() != nullptr)
-     << "pjrt buffer is null.";
+ XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer.";
+ XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";

runtime::ComputationClient::DataPtr data =
runtime::PjRtComputationClient::CreateData(
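The dl_convertor.cpp changes above are formatting-only, but they sit in the toDLPack()/fromDLPack() conversion path. Below is a hedged sketch of the opposite direction, XLA to PyTorch CUDA, mirroring the `test_dlpack_xla_to_pytorch_cuda` test referenced in the test diff; the `xdlpack.to_dlpack` wrapper name and its import path are assumptions, as is the `mark_step()` used to materialize the device buffer.

```python
# Sketch only: the `xdlpack.to_dlpack` wrapper name/import path are assumed;
# it is taken to be the Python-side counterpart of toDLPack() in dl_convertor.cpp.
import torch
import torch.utils.dlpack
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack  # assumed import path

t_xla = torch.arange(5, dtype=torch.float32, device=xm.xla_device())
xm.mark_step()  # materialize the lazy tensor so a PjRtBuffer exists (may be required)

capsule = xdlpack.to_dlpack(t_xla)                # export via the C++ toDLPack() path
t_cuda = torch.utils.dlpack.from_dlpack(capsule)  # import as a regular PyTorch CUDA tensor

assert t_cuda.device.type == 'cuda'
```
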
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -33,7 +33,7 @@ class PjRtComputationClient : public ComputationClient {
std::optional<xla::OpSharding> sharding = std::nullopt) override;

static DataPtr CreateData(std::string device, xla::Shape shape,
- std::shared_ptr<xla::PjRtBuffer> pjrt_buffer);
+ std::shared_ptr<xla::PjRtBuffer> pjrt_buffer);

std::vector<DataPtr> GetDataShards(DataPtr data) override;

