fix comments
vanbasten23 committed May 14, 2024
1 parent 1f1eeeb commit b0198df
Showing 4 changed files with 8 additions and 22 deletions.
18 changes: 6 additions & 12 deletions torch_xla/csrc/dl_convertor.cpp
@@ -6,6 +6,7 @@
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
@@ -121,13 +122,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
       runtime::GetComputationClient()->GetPjRtBuffer(handle);
   XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
 
-  if (pjrt_buffer->IsTuple()) {
-    XLA_ERROR() << "Unimplemented. BufferToDLPackManagedTensor is not "
-                   "implemented for tuple buffers.";
-  }
-  if (pjrt_buffer->has_dynamic_dimensions()) {
-    XLA_ERROR() << "Unimplemented. DynamicShape is not implemented in DLPack.";
-  }
+  XLA_CHECK(!pjrt_buffer->IsTuple()) << "Unimplemented. BufferToDLPackManagedTensor is not "
+                                        "implemented for tuple buffers.";
+  XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions()) << "Unimplemented. DynamicShape is not implemented in DLPack.";
 
   auto pack = std::make_unique<DLPackTensor>();
   DLTensor& dt = pack->tensor.dl_tensor;
@@ -294,11 +291,8 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(
 }
 
 at::Tensor fromDLPack(DLManagedTensor* dlmt) {
-  if (dlmt->dl_tensor.ndim < 0) {
-    XLA_ERROR()
-        << "Number of dimensions in DLManagedTensor must be nonnegative, got "
+  XLA_CHECK(dlmt->dl_tensor.ndim >= 0) << "Number of dimensions in DLManagedTensor must be nonnegative, got "
       << dlmt->dl_tensor.ndim;
-  }
   xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
   absl::Span<int64_t const> dimensions(
       const_cast<int64_t*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
@@ -333,7 +327,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
       << "pjrt buffer is null in " << __FUNCTION__;
 
   runtime::ComputationClient::DataPtr data =
-      runtime::GetComputationClient()->CreateData(
+      runtime::PjRtComputationClient::CreateData(
           runtime::GetComputationClient()->PjRtDeviceToString(device), shape,
           std::move(pjrt_buffer.value()));
 
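For context, a minimal sketch of how the two converters in this file might be exercised end to end. The toDLPack/fromDLPack signatures come from the hunks above; the torch_xla namespace, the header path, and the helper name are assumptions rather than part of this commit.

#include "torch_xla/csrc/dl_convertor.h"  // assumed header for the converters

// Round-trip an XLA tensor through DLPack (sketch, not from this commit).
at::Tensor RoundTripThroughDLPack(const at::Tensor& input) {
  // toDLPack now fails fast via XLA_CHECK if the underlying PjRt buffer is a
  // tuple or has dynamic dimensions, instead of going through XLA_ERROR branches.
  DLManagedTensor* dlmt = torch_xla::toDLPack(input);
  // fromDLPack validates ndim, locates the PjRt device, and wraps the buffer
  // as device data via PjRtComputationClient::CreateData.
  return torch_xla::fromDLPack(dlmt);
}
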
3 changes: 0 additions & 3 deletions torch_xla/csrc/runtime/computation_client.h
@@ -260,9 +260,6 @@ class ComputationClient {
       std::string device, xla::Shape shape,
       std::optional<xla::OpSharding> sharding = std::nullopt) = 0;
 
-  virtual DataPtr CreateData(std::string device, xla::Shape shape,
-                             std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) = 0;
-
   // Returns data shards. We expect this to be called on PjRtShardedData to
   // retrieve the shards. If other data type is passed, it returns the input
   // wrapped inside a vector.
5 changes: 0 additions & 5 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -33,11 +33,6 @@ class IfrtComputationClient : public ComputationClient {
       std::string device, xla::Shape shape,
       std::optional<xla::OpSharding> sharding = std::nullopt) override;
 
-  DataPtr CreateData(std::string device, xla::Shape shape,
-                     std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) override {
-    XLA_ERROR() << __FUNCTION__ << " not implemented";
-  };
-
   std::vector<DataPtr> GetDataShards(DataPtr data) override;
 
   DataPtr GetDataShard(DataPtr data, size_t index) override;
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -32,8 +32,8 @@ class PjRtComputationClient : public ComputationClient {
       std::string device, xla::Shape shape,
       std::optional<xla::OpSharding> sharding = std::nullopt) override;
 
-  DataPtr CreateData(std::string device, xla::Shape shape,
-                     std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) override;
+  static DataPtr CreateData(std::string device, xla::Shape shape,
+                            std::shared_ptr<xla::PjRtBuffer> pjrt_buffer);
 
   std::vector<DataPtr> GetDataShards(DataPtr data) override;
 
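Since CreateData is now a static member of PjRtComputationClient rather than a pure virtual on ComputationClient, callers can wrap an existing PjRtBuffer without going through the client interface, which is what the dl_convertor.cpp hunk above does. A minimal sketch of that call shape; the wrapper function name is hypothetical and the argument names are illustrative:

#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"

namespace torch_xla {

// Wrap an already-created PjRt buffer as computation-client data on `device`
// (sketch mirroring the fromDLPack call site above; not part of this commit).
runtime::ComputationClient::DataPtr WrapPjRtBuffer(
    xla::PjRtDevice* device, const xla::Shape& shape,
    std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) {
  return runtime::PjRtComputationClient::CreateData(
      runtime::GetComputationClient()->PjRtDeviceToString(device), shape,
      std::move(pjrt_buffer));
}

}  // namespace torch_xla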
