Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dlpack support #7025

Merged
merged 19 commits into from
May 22, 2024
Merged

Add dlpack support #7025

merged 19 commits into from
May 22, 2024

Conversation

vanbasten23
Copy link
Collaborator

@vanbasten23 vanbasten23 commented May 4, 2024

This PR adds basic dlpack support for pytorch/xla.

Execution path:

  • DLPack->XLATensor: dlpack.from_dlpack(python)->_from_dlpack(pybind in torch_xla/csrc/init_python_bindings.cpp)->torch_xla::fromDLPack(torch_xla/csrc/dl_convertor.cpp)
  • XLATensor-.DLPack: dlpack.to_dlpack(python)->_to_dlpack(pybind in torch_xla/csrc/init_python_bindings.cpp)->torch_xla::toDLPack(torch_xla/csrc/dl_convertor.cpp)

Test plans: PJRT_DEVICE=CUDA python pytorch/xla/test/test_operations.py -k TestDLPack.test_

references:

@vanbasten23
Copy link
Collaborator Author

vanbasten23 commented May 7, 2024

Currently, the tests fails with

2024-05-07 17:55:28.847281: F external/xla/xla/pjrt/pjrt_stream_executor_client.cc:1322] Check failed: holds_[i] == 0 (1 vs. 0)
Aborted (core dumped)

callstack https://gist.github.com/vanbasten23/a42b196b2fcd5e985a083752386dd3f8

It fails at https://github.com/openxla/xla/blob/0aa2ab4df32bdb099664b0edeef991f40ff1af49/xla/pjrt/pjrt_stream_executor_client.cc#L1322 when i==1 (kExternalReference). I think we are calling the PjRtStreamExecutorBuffer destructor too early for the pjrt_buffer at https://github.com/pytorch/xla/pull/7025/files#diff-cf3922091c803bce3341e4e55b2c54d277812059adeafa297eb1c7b444213b2aR128.

Edit: i figured it out and fixed it.

@vanbasten23 vanbasten23 changed the title [WIP] Add dlpack support Add dlpack support May 10, 2024
@vanbasten23 vanbasten23 requested a review from JackCaoG May 10, 2024 21:45
@vanbasten23 vanbasten23 marked this pull request as ready for review May 10, 2024 21:46
Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly lgtm, minor nits.

test/test_operations.py Outdated Show resolved Hide resolved
torch_xla/csrc/dl_convertor.cpp Outdated Show resolved Hide resolved
torch_xla/csrc/dl_convertor.cpp Outdated Show resolved Hide resolved
torch_xla/csrc/dl_convertor.cpp Show resolved Hide resolved
torch_xla/csrc/dl_convertor.cpp Outdated Show resolved Hide resolved
torch_xla/csrc/dl_convertor.cpp Outdated Show resolved Hide resolved
test/test_operations.py Outdated Show resolved Hide resolved
Comment on lines 481 to 496
std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
const DataPtr handle) {
std::shared_ptr<PjRtData> pjrt_data =
std::dynamic_pointer_cast<PjRtData>(handle);
XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
return pjrt_data->buffer;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this method has to be a class method of PjRtComputationClient? It doesn;t access any class private members. If it is just a helper, you don't need this to b a class method. @will-cromar wdyt?

Copy link
Collaborator Author

@vanbasten23 vanbasten23 May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The struct PjRtData is private to PjRtComputationClient. That's why I have to make it a class method.

Comment on lines +188 to +193
ComputationClient::DataPtr PjRtComputationClient::CreateData(
std::string device, xla::Shape shape,
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) {
return std::make_shared<PjRtData>(std::move(device), std::move(shape),
pjrt_buffer);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you just need a helper, not the class method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The struct PjRtData is private to PjRtComputationClient so I can't make it a helper.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the input pjrt_buffer depend on the current instance of PjRtClient? ie do you need to check here that this buffer belongs to this client? If you don't depend on any members of PjRtComputationClient, can you make this a static method to make that clearer?

Also, this should be private and probably not part of ComputationClient. The public API of PjRtComputationClient shouldn't operate on PJRT primitives. It should be a complete wrapper to leave the door open to other runtime interfaces (namely IFRT).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the pjrt_buffer doesn't depend on the current instance of PjRtClient. I have made it a static method.

I also made it not a part of ComputationClient. Not sure how I can make it private and call it in the dl_convertor.cpp at the same time.

@vanbasten23 vanbasten23 requested a review from JackCaoG May 13, 2024 23:20
@vanbasten23
Copy link
Collaborator Author

hey @will-cromar @JackCaoG , it seems the CI skips tests that require PyTorch CUDA support, for example

test_aten_move_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'
test_aten_move_scalar_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'

so the tests I added in this PR are not actually run by the CI.

What do you think of adding a new CI workflow that builds pytorch with CUDA, builds pytorch/xla, and runs the test? By preserving the existing CI workflow (build pytorch with CUDA disabled), people can still get faster feedback.

@will-cromar
Copy link
Collaborator

Yeah, that's a loose end from all of the refactoring I did. I agree with adding a new workflow or branch of the workflow that adds CUDA separately.

If it's okay with you, the easiest option would be to download the pre-built nightly CUDA wheel from https://download.pytorch.org/whl/nightly/cu121. This will be faster and require less maintenance, but it will also break periodically if there's a breaking change on head since the last nightly build.

Copy link
Collaborator

@will-cromar will-cromar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This is really neat

torch_xla/csrc/dl_convertor.cpp Outdated Show resolved Hide resolved
Comment on lines +188 to +193
ComputationClient::DataPtr PjRtComputationClient::CreateData(
std::string device, xla::Shape shape,
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer) {
return std::make_shared<PjRtData>(std::move(device), std::move(shape),
pjrt_buffer);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the input pjrt_buffer depend on the current instance of PjRtClient? ie do you need to check here that this buffer belongs to this client? If you don't depend on any members of PjRtComputationClient, can you make this a static method to make that clearer?

Also, this should be private and probably not part of ComputationClient. The public API of PjRtComputationClient shouldn't operate on PJRT primitives. It should be a complete wrapper to leave the door open to other runtime interfaces (namely IFRT).

@@ -84,6 +91,15 @@ class IfrtComputationClient : public ComputationClient {
absl::AsciiStrToUpper(client_->platform_name()));
};

xla::PjRtPlatformId GetPlatformID() const override {
return client_->platform_id();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heads up: this method doesn't exist in the PJRT C API: https://github.com/openxla/xla/blob/80b854f9ef3e8b913ed9ca2d930c81e32c6d02da/xla/pjrt/pjrt_c_api_client.h#L209-L211

Since dlpack is a very special case, what do you think of just adding a is_cuda or supports_dlpack attribute to PjRtComputationClient and setting it during init?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heads up: this method doesn't exist in the PJRT C API

The client_ here has type xla::ifrt::PjRtClient which implements platform_id at https://github.com/openxla/xla/blob/62568f0ef380c883defd44a060722dd62e81df1b/xla/python/pjrt_ifrt/pjrt_client.h#L147-L150. I wonder how the xla::ifrt::PjRtClient relate to the PJRT C API PjRtCApiTopologyDescription.

what do you think of just adding a is_cuda or supports_dlpack attribute to PjRtComputationClient and setting it during init?

Can you elaborate more? Do you mean adding a supports_dlpack to ComputationClient and set it true for PjRtComputationClient and false to IfRtComputationClient?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PJRT C API is a subset of the PJRT C++ API, so some methods of the C++ -> C wrapper are unimplemented. I would add a field to PjRtComputationClient to indicate if dlpack is expected to work with the underlying PjRtClient (only true when PJRT_DEVICE=CUDA), since that depends on the specific device type, which we should not rely on after initialization.

I'm not concerned about IFRT right now. Also, remember that xla::PjRtClient is different than xla::ifrt::PjRtClient.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. If the concern is dlpack should only be used when PJRT_DEVICE=CUDA, then there is a check in this PR https://github.com/pytorch/xla/pull/7025/files#diff-cf3922091c803bce3341e4e55b2c54d277812059adeafa297eb1c7b444213b2aR45-R53 to make sure dlpack is used for CUDA instead of TPU.

Copy link
Collaborator

@ysiraichi ysiraichi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really great work, @vanbasten23. I have left a few comments in the PR. I do have a couple of questions, though:

  1. Could you add an overview of the functions involved in the DLPack-XLA (to/from) execution paths to the PR description? I feel like that would make your PR easier to follow.

  2. Before converting CUDA to DLPack to XLA, don't we need to call torch.cuda.synchronize()? That's because IIRC, some CUDA computation is also lazy.

test/test_operations.py Outdated Show resolved Hide resolved
test/test_operations.py Show resolved Hide resolved
test/test_operations.py Show resolved Hide resolved
test/test_operations.py Show resolved Hide resolved
test/test_operations.py Show resolved Hide resolved
test/test_operations.py Outdated Show resolved Hide resolved
torch_xla/csrc/dl_convertor.cpp Show resolved Hide resolved
torch_xla/csrc/tensor_util.cpp Outdated Show resolved Hide resolved
@vanbasten23
Copy link
Collaborator Author

Really great work, @vanbasten23. I have left a few comments in the PR. I do have a couple of questions, though:

  1. Could you add an overview of the functions involved in the DLPack-XLA (to/from) execution paths to the PR description? I feel like that would make your PR easier to follow.

Added the execution paths.

  1. Before converting CUDA to DLPack to XLA, don't we need to call torch.cuda.synchronize()? That's because IIRC, some CUDA computation is also lazy.

Good callout. I couldn't find this in the pytorch's documentation. But from their test pytorch/test/test_dlpack.py, it seems it is only needed when stream is used. Since this PR is not supporting stream, so I think it is not needed.

@ysiraichi
Copy link
Collaborator

it seems the CI skips tests that require PyTorch CUDA support, for example

test_aten_move_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'
test_aten_move_scalar_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'

@will-cromar Do you mean that, at some point, we stopped using .circleci/build.sh? I remember adding USE_CUDA=1 to it in #6070.

@ysiraichi
Copy link
Collaborator

it seems it is only needed when stream is used. Since this PR is not supporting stream, so I think it is not needed.

>>> a = torch.rand(1024, 1024, 1024, device="cuda")

# This returns instantly
>>> for i in range(1000):
>>>    a = a @ a

# Does NOT return instantly. Can hear the GPU fans going up.
>>> torch.cuda.synchronize()

I think that, at first, we can just call torch.cuda.synchronize() and then do the transition using DLPack. That said, a better way to handle this would be to use the same stream object (in PyTorch and XLA), so that no explicit synchronization is actually needed.

@will-cromar
Copy link
Collaborator

it seems the CI skips tests that require PyTorch CUDA support, for example

test_aten_move_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'
test_aten_move_scalar_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'

@will-cromar Do you mean that, at some point, we stopped using .circleci/build.sh? I remember adding USE_CUDA=1 to it in #6070.

Yeah, that's right. The TPU CI, CPU/GPU CI, and nightly build were all using different build scripts. While splitting the XLA CUDA build from the torch_xla build, I moved them all to the ansible config at infra/ansible.

@vanbasten23
Copy link
Collaborator Author

vanbasten23 commented May 16, 2024

it seems the CI skips tests that require PyTorch CUDA support, for example

test_aten_move_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'
test_aten_move_scalar_cuda_to_xla (__main__.TestGeneric) ... skipped 'requires PyTorch CUDA support'

@will-cromar Do you mean that, at some point, we stopped using .circleci/build.sh? I remember adding USE_CUDA=1 to it in #6070.

Right I observed the same in #7025 (comment). I'm adding a new workflow that build pytorch with cuda enabled and only exercise those tests requiring pytorch cuda. For now, I'll just run locally and the tests pass.

@vanbasten23
Copy link
Collaborator Author

vanbasten23 commented May 16, 2024

Yeah, that's a loose end from all of the refactoring I did. I agree with adding a new workflow or branch of the workflow that adds CUDA separately.

If it's okay with you, the easiest option would be to download the pre-built nightly CUDA wheel from https://download.pytorch.org/whl/nightly/cu121. This will be faster and require less maintenance, but it will also break periodically if there's a breaking change on head since the last nightly build.

Thanks for the suggestion. I've seen a few incompatible torch wheel and torch_xla wheel recently so I feel it is not very uncommon. If they are incompatible, our CI will be red for a day or two and we'll have to communicate to the team which is kind of a hassle lol. So I'll strive for building pytorch with cuda from source: #7073

Copy link
Collaborator

@ysiraichi ysiraichi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need more comments, e.g.:

  • where each function was copied from?
  • how to use the added API correctly?

Otherwise, LGTM.

torch_xla/csrc/init_python_bindings.cpp Outdated Show resolved Hide resolved
@JackCaoG
Copy link
Collaborator

I will try to take a look tmr

@vanbasten23
Copy link
Collaborator Author

I think we need more comments, e.g.:

  • where each function was copied from?
  • how to use the added API correctly?

Otherwise, LGTM.

Thanks for the review. I've added the comments you suggested.

Comment on lines +2544 to +2547
// If ext_data is the result of an CUDA computation, we should synchronize
// (waits for all kernels in all streams on a CUDA device to complete) if the
// current stream is different from the ext_data's stream. Otherwise, we may
// risk of getting incorrect results.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is _from_dlpack's comment.

@vanbasten23
Copy link
Collaborator Author

hey @JackCaoG could you take a look at the PR when you are available? Thanks!

@vanbasten23
Copy link
Collaborator Author

Thanks for the review Jack and Yukio!

@vanbasten23 vanbasten23 merged commit 6023855 into master May 22, 2024
20 checks passed
qihqi pushed a commit that referenced this pull request May 29, 2024
XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";

runtime::ComputationClient::DataPtr data =
runtime::PjRtComputationClient::CreateData(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Does this only apply to PjRtComputationClient?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Background: I am working on a new ComputationClient and here reports the error:

torch_xla/csrc/dl_convertor.cpp:337:16: error: 'torch_xla::runtime::PjRtComputationClient' has not been declared
  337 |       runtime::PjRtComputationClient::CreateData(
      |                ^~~~~~~~~~~~~~~~~~~~~

Looks like we have to use PjRtComputationClient with this file.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. As far as I understand it, the DLPack API only works with XLA:CUDA, i.e. PJRT.

yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Oct 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants