Skip to content

Commit

Permalink
get rid of PjRtBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed Jun 28, 2024
1 parent 4747718 commit 1d61699
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 178 deletions.
11 changes: 0 additions & 11 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ cc_library(
":env_hash",
":env_vars",
":operation_manager",
":pjrt_compile_only",
":pjrt_registry",
":profiler",
":stablehlo_helper",
Expand Down Expand Up @@ -226,16 +225,6 @@ cc_test(
],
)

cc_library(
name = "pjrt_compile_only",
srcs = ["pjrt_compile_only.cc"],
hdrs = ["pjrt_compile_only.h"],
deps = [
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:pjrt_future",
],
)

cc_library(
name = "pjrt_registry",
srcs = ["pjrt_registry.cc"],
Expand Down
42 changes: 6 additions & 36 deletions torch_xla/csrc/runtime/pjrt_compilation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,6 @@ std::unordered_map<int, int> build_index_map(
return device_index;
}

// Builds the xla::Shape of the output xla::Literal on the host.
xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
  xla::Shape device_shape = xla::ShapeUtil::MakeShape(
      buffer->element_type(), buffer->logical_dimensions().value());
  // Preserve the buffer's on-device layout before converting to host form.
  *device_shape.mutable_layout() = xla::GetXlaLayoutUnsafe(buffer->layout());
  return xla::ShapeUtil::DeviceShapeToHostShape(device_shape);
}

torch::lazy::hash_t hash_comp_env() {
// TODO(piz): since the client is nullptr, we can't retrieve all information
// like PjRtComputationClient. Think about a way to construct the hashing.
Expand Down Expand Up @@ -196,10 +187,9 @@ ComputationClient::DataPtr PjRtCompilationClient::CreateDataPlaceholder(
}

ComputationClient::DataPtr PjRtCompilationClient::CreateData(
std::string device, xla::Shape shape,
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) {
std::string device, xla::Shape shape, std::shared_ptr<Buffer> buffer) {
return std::make_shared<PjRtData>(std::move(device), std::move(shape),
pjrt_buffer);
buffer);
}

std::vector<ComputationClient::DataPtr> PjRtCompilationClient::GetDataShards(
Expand Down Expand Up @@ -271,8 +261,7 @@ std::vector<ComputationClient::DataPtr> PjRtCompilationClient::TransferToDevice(
absl::Span<const bool> dynamic_dimensions;
xla::Shape shape(tensor->primitive_type(), tensor->dimensions(),
dynamic_dimensions, tuple_shape);
std::shared_ptr<xla::PjRtBuffer> buffer =
std::make_shared<xla::CompileOnlyPjRtBuffer>(shape);
std::shared_ptr<Buffer> buffer = std::make_shared<Buffer>(shape);
ComputationClient::DataPtr data =
std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
datas.push_back(data);
Expand Down Expand Up @@ -314,15 +303,7 @@ ComputationClient::DataPtr PjRtCompilationClient::CopyToDevice(

xla::PjRtDevice* dst_device = StringToPjRtDevice(dst);
XLA_CHECK(dst_device->IsAddressable()) << dst << "is not addressable.";

// Returns error if the buffer is already on `dst_device`.
xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> status_or =
pjrt_data->buffer->CopyToDevice(dst_device);
if (!status_or.ok()) {
return data;
}
return std::make_shared<PjRtData>(dst, pjrt_data->shape(),
std::move(status_or.value()));
return std::make_shared<PjRtData>(dst, pjrt_data->shape(), pjrt_data->buffer);
}

std::shared_ptr<PjRtCompilationClient::PjRtData>
Expand Down Expand Up @@ -473,19 +454,8 @@ std::uintptr_t PjRtCompilationClient::UnsafeBufferPointer(

std::shared_ptr<xla::PjRtBuffer> PjRtCompilationClient::GetPjRtBuffer(
const DataPtr handle) {
std::shared_ptr<PjRtData> pjrt_data =
std::dynamic_pointer_cast<PjRtData>(handle);

XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = pjrt_data->buffer;
if (pjrt_buffer != nullptr) {
return pjrt_buffer;
} else {
TF_VLOG(3) << "The pjrt buffer is null so we need to wait for device ops "
"to finish.";
WaitDeviceOps({});
return std::dynamic_pointer_cast<PjRtData>(handle)->buffer;
}
TF_LOG(ERROR) << "AOT compilation is unable to get buffer data from device";
return std::shared_ptr<xla::PjRtBuffer>(nullptr);
}

std::vector<xla::Literal> PjRtCompilationClient::TransferFromDevice(
Expand Down
22 changes: 9 additions & 13 deletions torch_xla/csrc/runtime/pjrt_compilation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/operation_manager.h"
#include "torch_xla/csrc/runtime/pjrt_compile_only.h" // PEI: leave only pjrtbuffer inside
#include "torch_xla/csrc/runtime/util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"
Expand All @@ -24,6 +23,11 @@
namespace torch_xla {
namespace runtime {

// Lightweight stand-in for xla::PjRtBuffer used by the AOT/compile-only
// client: no real device buffer exists, so only the logical shape of the
// data is recorded.
struct Buffer {
  xla::Shape shape;  // Host-visible shape of the (virtual) device buffer.

  // `explicit` prevents accidental implicit xla::Shape -> Buffer conversion;
  // the by-value parameter is moved into place to avoid a copy.
  explicit Buffer(xla::Shape shape) : shape(std::move(shape)) {}
};

class PjRtCompilationClient : public ComputationClient {
public:
PjRtCompilationClient(std::string& virtual_topology_str);
Expand All @@ -34,7 +38,7 @@ class PjRtCompilationClient : public ComputationClient {
std::optional<xla::OpSharding> sharding = std::nullopt) override;

static DataPtr CreateData(std::string device, xla::Shape shape,
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer);
std::shared_ptr<Buffer> buffer);

std::vector<DataPtr> GetDataShards(DataPtr data) override;

Expand Down Expand Up @@ -175,25 +179,17 @@ class PjRtCompilationClient : public ComputationClient {
: Data(std::move(device), std::move(device_shape)) {}

PjRtData(std::string device, xla::Shape device_shape,
std::shared_ptr<xla::PjRtBuffer> buffer)
std::shared_ptr<Buffer> buffer)
: Data(std::move(device), std::move(device_shape)), buffer(buffer) {}

PjRtData(std::string device, std::shared_ptr<xla::PjRtBuffer> buffer)
: Data(std::move(device),
xla::Shape(buffer->element_type(), buffer->dimensions(),
buffer->is_dynamic_dimension(), {})),
buffer(buffer) {}

// Returns an opaque handle identifying this data object; the raw address of
// the underlying Buffer is reused as the handle value.
Handle GetHandle() override {
  // NOTE(review): with HasValue() reduced to a null check, the " is deleted"
  // branch of this message looks unreachable — confirm after the
  // PjRtBuffer -> Buffer migration.
  XLA_CHECK(HasValue())
      << "buffer with shape " << shape().ToString() << " on device "
      << device() << (buffer == nullptr ? " is null" : " is deleted");
  return reinterpret_cast<std::uintptr_t>(buffer.get());
};
void Assign(const torch::lazy::BackendData& data) override;
bool HasValue() const override {
return buffer != nullptr && !buffer->IsDeleted();
};
bool HasValue() const override { return buffer != nullptr; };

bool HasSharding() const override { return false; }

Expand All @@ -217,7 +213,7 @@ class PjRtCompilationClient : public ComputationClient {
return ss.str();
}

std::shared_ptr<xla::PjRtBuffer> buffer;
std::shared_ptr<Buffer> buffer;
};

struct PjRtShardedData : public Data {
Expand Down
70 changes: 0 additions & 70 deletions torch_xla/csrc/runtime/pjrt_compile_only.cc

This file was deleted.

48 changes: 0 additions & 48 deletions torch_xla/csrc/runtime/pjrt_compile_only.h

This file was deleted.

0 comments on commit 1d61699

Please sign in to comment.