From b0198dffc90dbe328ae219e50d40043a18de85ca Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Tue, 14 May 2024 20:15:36 +0000 Subject: [PATCH] fix comments --- torch_xla/csrc/dl_convertor.cpp | 18 ++++++------------ torch_xla/csrc/runtime/computation_client.h | 3 --- .../csrc/runtime/ifrt_computation_client.h | 5 ----- .../csrc/runtime/pjrt_computation_client.h | 4 ++-- 4 files changed, 8 insertions(+), 22 deletions(-) diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index 5790b7c2112d..9e44cd8f5b27 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -6,6 +6,7 @@ #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/tf_logging.h" @@ -121,13 +122,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { runtime::GetComputationClient()->GetPjRtBuffer(handle); XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer"; - if (pjrt_buffer->IsTuple()) { - XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not " - "implemented for tuple buffers."; - } - if (pjrt_buffer->has_dynamic_dimensions()) { - XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack."; - } + XLA_CHECK(!pjrt_buffer->IsTuple()) << "Unimplemented. BufferToDLPackManagedTensor is not " + "implemented for tuple buffers."; + XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions()) << "Unimplemented. DynamicShape is not implemented in DLPack."; auto pack = std::make_unique(); DLTensor& dt = pack->tensor.dl_tensor; @@ -294,11 +291,8 @@ absl::StatusOr> StridesToLayout( } at::Tensor fromDLPack(DLManagedTensor* dlmt) { - if (dlmt->dl_tensor.ndim < 0) { - XLA_ERROR() - << "Number of dimensions in DLManagedTensor must be nonnegative, got " + XLA_CHECK(dlmt->dl_tensor.ndim >= 0) << "Number of dimensions in DLManagedTensor must be nonnegative, got " << dlmt->dl_tensor.ndim; - } xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value(); absl::Span dimensions( const_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); @@ -333,7 +327,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { << "pjrt buffer is null in " << __FUNCTION__; runtime::ComputationClient::DataPtr data = - runtime::GetComputationClient()->CreateData( + runtime::PjRtComputationClient::CreateData( runtime::GetComputationClient()->PjRtDeviceToString(device), shape, std::move(pjrt_buffer.value())); diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index b275ef562ee2..0368962d4110 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -260,9 +260,6 @@ class ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) = 0; - virtual DataPtr CreateData(std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) = 0; - // Returns data shards. We expect this to be called on PjRtShardedData to // retrieve the shards. If other data type is passed, it returns the input // wrapped inside a vector. diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index f843a2e53e47..ca40c8fb02c3 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -33,11 +33,6 @@ class IfrtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; - DataPtr CreateData(std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; - std::vector GetDataShards(DataPtr data) override; DataPtr GetDataShard(DataPtr data, size_t index) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index aff4b781f99e..8e36bff5e99f 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -32,8 +32,8 @@ class PjRtComputationClient : public ComputationClient { std::string device, xla::Shape shape, std::optional sharding = std::nullopt) override; - DataPtr CreateData(std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) override; + static DataPtr CreateData(std::string device, xla::Shape shape, + std::shared_ptr pjrt_buffer); std::vector GetDataShards(DataPtr data) override;