diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 17710f3d48a..74667c67028 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -147,7 +147,6 @@ cc_library( ":env_hash", ":env_vars", ":operation_manager", - ":pjrt_compile_only", ":pjrt_registry", ":profiler", ":stablehlo_helper", @@ -226,16 +225,6 @@ cc_test( ], ) -cc_library( - name = "pjrt_compile_only", - srcs = ["pjrt_compile_only.cc"], - hdrs = ["pjrt_compile_only.h"], - deps = [ - "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:pjrt_future", - ], -) - cc_library( name = "pjrt_registry", srcs = ["pjrt_registry.cc"], diff --git a/torch_xla/csrc/runtime/pjrt_compilation_client.cc b/torch_xla/csrc/runtime/pjrt_compilation_client.cc index b821b4351fe..dfb10ffab43 100644 --- a/torch_xla/csrc/runtime/pjrt_compilation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_compilation_client.cc @@ -53,15 +53,6 @@ std::unordered_map build_index_map( return device_index; } -// Builds the xla::Shape of the output xla::Literal on the host. -xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { - xla::Shape shape = xla::ShapeUtil::MakeShape( - buffer->element_type(), buffer->logical_dimensions().value()); - *shape.mutable_layout() = xla::GetXlaLayoutUnsafe(buffer->layout()); - - return xla::ShapeUtil::DeviceShapeToHostShape(shape); -} - torch::lazy::hash_t hash_comp_env() { // TODO(piz): since the client is nullptr, we can't retrive all information // like PjRtComputationClient. Think about a way to construct the hashing. @@ -196,10 +187,9 @@ ComputationClient::DataPtr PjRtCompilationClient::CreateDataPlaceholder( } ComputationClient::DataPtr PjRtCompilationClient::CreateData( - std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer) { + std::string device, xla::Shape shape, std::shared_ptr buffer) { return std::make_shared(std::move(device), std::move(shape), - pjrt_buffer); + buffer); } std::vector PjRtCompilationClient::GetDataShards( @@ -271,8 +261,7 @@ std::vector PjRtCompilationClient::TransferToDevice( absl::Span dynamic_dimensions; xla::Shape shape(tensor->primitive_type(), tensor->dimensions(), dynamic_dimensions, tuple_shape); - std::shared_ptr buffer = - std::make_shared(shape); + std::shared_ptr buffer = std::make_shared(shape); ComputationClient::DataPtr data = std::make_shared(tensor->device(), tensor->shape(), buffer); datas.push_back(data); @@ -314,15 +303,7 @@ ComputationClient::DataPtr PjRtCompilationClient::CopyToDevice( xla::PjRtDevice* dst_device = StringToPjRtDevice(dst); XLA_CHECK(dst_device->IsAddressable()) << dst << "is not addressable."; - - // Returns error if the buffer is already on `dst_device`. - xla::StatusOr> status_or = - pjrt_data->buffer->CopyToDevice(dst_device); - if (!status_or.ok()) { - return data; - } - return std::make_shared(dst, pjrt_data->shape(), - std::move(status_or.value())); + return std::make_shared(dst, pjrt_data->shape(), pjrt_data->buffer); } std::shared_ptr @@ -473,19 +454,8 @@ std::uintptr_t PjRtCompilationClient::UnsafeBufferPointer( std::shared_ptr PjRtCompilationClient::GetPjRtBuffer( const DataPtr handle) { - std::shared_ptr pjrt_data = - std::dynamic_pointer_cast(handle); - - XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString(); - std::shared_ptr pjrt_buffer = pjrt_data->buffer; - if (pjrt_buffer != nullptr) { - return pjrt_buffer; - } else { - TF_VLOG(3) << "The pjrt buffer is null so we need to wait for device ops " - "to finish."; - WaitDeviceOps({}); - return std::dynamic_pointer_cast(handle)->buffer; - } + TF_LOG(ERROR) << "AOT compilation is unable to get buffer data from device"; + return std::shared_ptr(nullptr); } std::vector PjRtCompilationClient::TransferFromDevice( diff --git a/torch_xla/csrc/runtime/pjrt_compilation_client.h b/torch_xla/csrc/runtime/pjrt_compilation_client.h index b232d88410f..234a8363861 100644 --- a/torch_xla/csrc/runtime/pjrt_compilation_client.h +++ b/torch_xla/csrc/runtime/pjrt_compilation_client.h @@ -11,7 +11,6 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/operation_manager.h" -#include "torch_xla/csrc/runtime/pjrt_compile_only.h" // PEI: leave only pjrtbuffer inside #include "torch_xla/csrc/runtime/util.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" @@ -24,6 +23,11 @@ namespace torch_xla { namespace runtime { +struct Buffer { + xla::Shape shape; + Buffer(xla::Shape shape) : shape(shape) {} +}; + class PjRtCompilationClient : public ComputationClient { public: PjRtCompilationClient(std::string& virtual_topology_str); @@ -34,7 +38,7 @@ class PjRtCompilationClient : public ComputationClient { std::optional sharding = std::nullopt) override; static DataPtr CreateData(std::string device, xla::Shape shape, - std::shared_ptr pjrt_buffer); + std::shared_ptr buffer); std::vector GetDataShards(DataPtr data) override; @@ -175,15 +179,9 @@ class PjRtCompilationClient : public ComputationClient { : Data(std::move(device), std::move(device_shape)) {} PjRtData(std::string device, xla::Shape device_shape, - std::shared_ptr buffer) + std::shared_ptr buffer) : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} - PjRtData(std::string device, std::shared_ptr buffer) - : Data(std::move(device), - xla::Shape(buffer->element_type(), buffer->dimensions(), - buffer->is_dynamic_dimension(), {})), - buffer(buffer) {} - Handle GetHandle() override { XLA_CHECK(HasValue()) << "buffer with shape " << shape().ToString() << " on device " @@ -191,9 +189,7 @@ class PjRtCompilationClient : public ComputationClient { return reinterpret_cast(buffer.get()); }; void Assign(const torch::lazy::BackendData& data) override; - bool HasValue() const override { - return buffer != nullptr && !buffer->IsDeleted(); - }; + bool HasValue() const override { return buffer != nullptr; }; bool HasSharding() const override { return false; } @@ -217,7 +213,7 @@ class PjRtCompilationClient : public ComputationClient { return ss.str(); } - std::shared_ptr buffer; + std::shared_ptr buffer; }; struct PjRtShardedData : public Data { diff --git a/torch_xla/csrc/runtime/pjrt_compile_only.cc b/torch_xla/csrc/runtime/pjrt_compile_only.cc deleted file mode 100644 index d00f8f11e72..00000000000 --- a/torch_xla/csrc/runtime/pjrt_compile_only.cc +++ /dev/null @@ -1,70 +0,0 @@ -#include "torch_xla/csrc/runtime/pjrt_compile_only.h" - -namespace xla { - -const Shape& CompileOnlyPjRtBuffer::on_device_shape() const { return shape_; } - -PjRtMemorySpace* CompileOnlyPjRtBuffer::memory_space() const { return nullptr; } - -PjRtDevice* CompileOnlyPjRtBuffer::device() const { return nullptr; } - -PjRtClient* CompileOnlyPjRtBuffer::client() const { return nullptr; } - -StatusOr> -CompileOnlyPjRtBuffer::AcquireExternalReference() { - return Unimplemented(""); -} - -PjRtFuture<> CompileOnlyPjRtBuffer::ToLiteral(MutableLiteralBase* literal) { - return PjRtFuture<>(); -} - -PjRtFuture<> CompileOnlyPjRtBuffer::LazyToLiteral( - absl::AnyInvocable() &&> generator) { - return PjRtFuture<>(); -} - -StatusOr CompileOnlyPjRtBuffer::GetOnDeviceSizeInBytes() const { - return Unimplemented(""); -} -PjRtFuture<> CompileOnlyPjRtBuffer::CopyRawToHost(void* dst, int64_t offset, - int64_t transfer_size) { - return PjRtFuture<>(); -} - -void CompileOnlyPjRtBuffer::Delete() { return; } - -StatusOr> -CompileOnlyPjRtBuffer::ReleaseDeviceMemoryOwnership( - bool wait_for_operations_to_complete) { - return Unimplemented(""); -} - -bool CompileOnlyPjRtBuffer::IsDeleted() { return false; } - -StatusOr> CompileOnlyPjRtBuffer::CopyToDevice( - PjRtDevice* dst_device) { - return Unimplemented(""); -} -StatusOr> CompileOnlyPjRtBuffer::CopyToMemorySpace( - PjRtMemorySpace* dst_memory_space) { - return Unimplemented(""); -} - -void CompileOnlyPjRtBuffer::CopyToRemoteDevice( - PjRtFuture serialized_descriptor, RemoteSendCallback on_done) { - return; -} - -void CompileOnlyPjRtBuffer::CopyToRemoteDeviceScattered( - PjRtFuture> serialized_descriptors, - std::vector callbacks, - const ScatterDetails& scatter_details) { - return; -} - -PjRtFuture<> CompileOnlyPjRtBuffer::GetReadyFuture() { return PjRtFuture<>(); } - -bool CompileOnlyPjRtBuffer::IsOnCpu() const { return false; } - -} // namespace xla diff --git a/torch_xla/csrc/runtime/pjrt_compile_only.h b/torch_xla/csrc/runtime/pjrt_compile_only.h deleted file mode 100644 index 55e93ca1ae3..00000000000 --- a/torch_xla/csrc/runtime/pjrt_compile_only.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef XLA_CLIENT_PJRT_COMPILE_ONLY_H_ -#define XLA_CLIENT_PJRT_COMPILE_ONLY_H_ - -#include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_future.h" - -namespace xla { - -class CompileOnlyPjRtBuffer final : public PjRtBuffer { - public: - CompileOnlyPjRtBuffer(const Shape& shape) { shape_ = shape; } - const Shape& on_device_shape() const override; - PjRtMemorySpace* memory_space() const override; - PjRtDevice* device() const override; - PjRtClient* client() const override; - StatusOr> AcquireExternalReference() - override; - PjRtFuture<> ToLiteral(MutableLiteralBase* literal) override; - PjRtFuture<> LazyToLiteral( - absl::AnyInvocable() &&> generator) - override; - StatusOr GetOnDeviceSizeInBytes() const override; - PjRtFuture<> CopyRawToHost(void* dst, int64_t offset, - int64_t transfer_size) override; - void Delete() override; - StatusOr> ReleaseDeviceMemoryOwnership( - bool wait_for_operations_to_complete) override; - bool IsDeleted() override; - StatusOr> CopyToDevice( - PjRtDevice* dst_device) override; - StatusOr> CopyToMemorySpace( - PjRtMemorySpace* dst_memory_space) override; - void CopyToRemoteDevice(PjRtFuture serialized_descriptor, - RemoteSendCallback on_done) override; - void CopyToRemoteDeviceScattered( - PjRtFuture> serialized_descriptors, - std::vector callbacks, - const ScatterDetails& scatter_details) override; - PjRtFuture<> GetReadyFuture() override; - bool IsOnCpu() const override; - - private: - Shape shape_; -}; - -} // namespace xla - -#endif