Skip to content

Commit

Permalink
refining the AOT compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed Jun 26, 2024
1 parent c094293 commit c5d4626
Show file tree
Hide file tree
Showing 11 changed files with 1,224 additions and 15 deletions.
6 changes: 6 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,12 @@ void InitXlaModuleBindings(py::module m) {
"_xla_get_rng_seed",
[](const std::string& device) { return GetRngSeed(device); },
py::arg("device") = "");
m.def(
"_xla_set_virtual_topology",
[](std::string& topology) {
torch_xla::runtime::SetVirtualTopology(topology);
},
py::arg("topology") = "");
m.def(
"_xla_set_should_alias_with_buffer_donor_config",
[](bool should_alias, const std::string& device_str) {
Expand Down
48 changes: 48 additions & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ cc_library(
":env_vars",
":ifrt_computation_client",
":pjrt_computation_client",
":pjrt_compilation_client",
"@tsl//tsl/platform:stacktrace",
],
)
Expand Down Expand Up @@ -132,6 +133,43 @@ cc_library(
],
)

cc_library(
name = "pjrt_compilation_client",
srcs = [
"pjrt_compilation_client.cc",
],
hdrs = [
"pjrt_compilation_client.h",
],
deps = [
":computation_client",
":debug_macros",
":env_hash",
":env_vars",
":operation_manager",
":pjrt_compile_only",
":pjrt_registry",
":profiler",
":stablehlo_helper",
":tensor_source",
":tf_logging",
":xla_coordinator",
"//torch_xla/csrc:thread_pool",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform/cloud:gcs_file_system",
"@tsl//tsl/profiler/lib:traceme",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/distributed",
],
)

cc_library(
name = "cache",
hdrs = ["cache.h"],
Expand Down Expand Up @@ -188,6 +226,16 @@ cc_test(
],
)

cc_library(
name = "pjrt_compile_only",
srcs = ["pjrt_compile_only.cc"],
hdrs = ["pjrt_compile_only.h"],
deps = [
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt:pjrt_future",
],
)

cc_library(
name = "pjrt_registry",
srcs = ["pjrt_registry.cc"],
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ComputationClient {
// of device.
// 3. xla::XlaComputation represent a xla computation, it is generated by the
// xla compiler.
// 4. xla::PjRtComputationClient::PjRtComputation which inherits from
// 4. xla::PjRtCompilationClient::PjRtComputation which inherits from
// runtime::ComputationClient::Computation and contains a handle to represent
// the compiled program.
class Computation : public torch::lazy::Computation {
Expand Down Expand Up @@ -126,7 +126,7 @@ class ComputationClient {
// ...
// 3. To represent a computation that is already compiled. In this case
// name_ and hash_ are not required. Computation will be a wrapper around
// an executable, PjRtComputationClient::PjRtComputation in our case. It
// an executable, PjRtCompilationClient::PjRtComputation in our case. It
// is not ideal to use same class for 3 different purposes but this is
// the path took by upstream ltc.
Computation(xla::XlaComputation computation,
Expand Down
Loading

0 comments on commit c5d4626

Please sign in to comment.