Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dlpack support #7025

Merged
merged 19 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
parser.add_argument('--verbosity', type=int, default=0)
FLAGS, leftovers = parser.parse_known_args()
sys.argv = [sys.argv[0]] + leftovers
from absl.testing import absltest, parameterized

# Normal imports section starts here.
import collections
Expand All @@ -28,6 +29,11 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.testing._internal.common_device_type import dtypes
from torch.testing._internal.common_dtype import (
all_types_and_complex_and,
all_types_and,
)
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
Expand All @@ -40,6 +46,7 @@
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.dlpack as xdlpack
import torch_xla.utils.utils as xu
import torch_xla.utils.serialization as xser
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -2448,6 +2455,139 @@ def test_unsafe_buffer_pointer(self):
self.assertGreaterEqual(buf_ptr_3, 0)


class TestDLPack(parameterized.TestCase):

def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule
xla_tensor2 = xdlpack.from_dlpack(dlpt)

self.assertEqual(xla_tensor.device, xla_tensor2.device)
self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu()))
self.assertRaisesRegex(RuntimeError,
"DLTensor capsule can be consumed only once",
lambda: xdlpack.from_dlpack(dlpt))

self.assertEqual(
torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
@parameterized.parameters(*all_types_and(torch.half, torch.bfloat16))
def test_dlpack_roundtrip_tensor(self, dtype):
xla_device = xm.xla_device()
# xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
# xla_tensor_2 uses XLANativeFunctions::_to_copy
xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device)
self._test_dlpack_capsule_conversion_helper(xla_tensor_2)

# xla_tensor_3 uses arange_out IR node.
xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
xm.mark_step()
self._test_dlpack_capsule_conversion_helper(xla_tensor_3)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
@parameterized.parameters(*all_types_and_complex_and(torch.half,
torch.bfloat16,
torch.bool, torch.uint16,
torch.uint32,
torch.uint64))
def test_dlpack_roundtrip_scalar(self, dtype):
xla_device = xm.xla_device()
xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
# `mark_step` ensures xtensor->CurrentDataHandle() != nullptr
xm.mark_step()
self._test_dlpack_capsule_conversion_helper(xla_tensor_0)

xla_tensor_1 = torch.tensor(42, dtype=dtype).to(xla_device)
# xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
self._test_dlpack_capsule_conversion_helper(xla_tensor_1)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_roundtrip_bool(self):
xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device())
self._test_dlpack_capsule_conversion_helper(xla_tensor)
ysiraichi marked this conversation as resolved.
Show resolved Hide resolved

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_pytorch_cuda_to_xla(self):
t1_cuda = torch.arange(5).cuda()
dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda)
xla_t1 = xdlpack.from_dlpack(dlt1)
self.assertEqual(xla_t1.device.type, 'xla')
self.assertEqual(xla_t1.device.index, t1_cuda.device.index)
ysiraichi marked this conversation as resolved.
Show resolved Hide resolved
t1_cuda[0] = t1_cuda[0] + 20
self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu()))
vanbasten23 marked this conversation as resolved.
Show resolved Hide resolved

t2_cuda = torch.tensor(5).cuda()
dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda)
xla_t2 = xdlpack.from_dlpack(dlt2)
self.assertEqual(xla_t2.device.type, 'xla')
self.assertEqual(xla_t2.device.index, t2_cuda.device.index)
t2_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))
vanbasten23 marked this conversation as resolved.
Show resolved Hide resolved

cuda1 = torch.device('cuda:1')
t3_cuda = torch.tensor(5, device=cuda1)
dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
xla_t3 = xdlpack.from_dlpack(dlt3)
self.assertEqual(xla_t3.device.type, 'xla')
self.assertEqual(
xla_t3.device.index,
t3_cuda.device.index,
msg='both value should 1. xla_t3.device should be xla:1.')
t3_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_xla_to_pytorch_cuda(self):
xla_t1 = torch.arange(5).to(xm.xla_device())
dlt1 = xdlpack.to_dlpack(xla_t1)
cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1)
self.assertEqual(cuda_t1.device.type, 'cuda')
self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
cuda_t1[0] = cuda_t1[0] + 20
self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
vanbasten23 marked this conversation as resolved.
Show resolved Hide resolved

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_non_default_layout(self):
cuda_t = torch.arange(25, device=torch.device('cuda')).reshape(5, 5)

t1 = cuda_t.t()
xla_t1 = xdlpack.from_dlpack(t1.__dlpack__())
self.assertEqual(xla_t1.device.type, 'xla')
self.assertEqual(xla_t1.device.index, 0)
self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu()))

t2 = cuda_t[0]
xla_t2 = xdlpack.from_dlpack(t2.__dlpack__())
self.assertEqual(xla_t2.device.type, 'xla')
self.assertEqual(xla_t2.device.index, 0)
self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu()))

t3 = cuda_t[:, 0]
self.assertRaisesRegex(
RuntimeError,
r"Only DLPack tensors with trivial \(compact\) striding are supported",
lambda: xdlpack.from_dlpack(t3.__dlpack__()))

t4 = cuda_t[1, :]
xla_t4 = xdlpack.from_dlpack(t4.__dlpack__())
self.assertEqual(xla_t4.device.type, 'xla')
self.assertEqual(xla_t4.device.index, 0)
self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu()))

t5 = cuda_t[1]
xla_t5 = xdlpack.from_dlpack(t5.__dlpack__())
self.assertEqual(xla_t5.device.type, 'xla')
self.assertEqual(xla_t5.device.index, 0)
self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu()))


class SimpleModelWithDropout(torch.nn.Module):

def __init__(self):
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ ptxla_cc_library(
"cross_replica_reduces.cpp",
"data_ops.cpp",
"debug_util.cpp",
"dl_convertor.cpp",
"elementwise.cpp",
"helpers.cpp",
"ir_dump_util.cpp",
Expand Down Expand Up @@ -81,6 +82,7 @@ ptxla_cc_library(
"cross_replica_reduces.h",
"data_ops.h",
"debug_util.h",
"dl_convertor.h",
"elementwise.h",
"generated_file_include.h",
"helpers.h",
Expand Down
Loading
Loading