Handle multiple inplace update input output aliasing (#7023)
JackCaoG authored May 3, 2024
1 parent d123585 commit e3fc033
Showing 3 changed files with 78 additions and 1 deletion.
44 changes: 44 additions & 0 deletions test/test_input_output_aliases.py
@@ -38,6 +38,50 @@ def test_aliasing_with_cloned(self):
    torch.allclose(t1 - 1, t1_cloned)
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)

  def test_aliasing_across_mark_step(self):
    xla_device = xm.xla_device()
    met.clear_all()
    t1 = torch.randn(4, 5).to(xla_device)
    t1 += 1
    xm.mark_step()
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
    t1 *= 100
    xm.mark_step()
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)

  def test_aliasing_with_multiple_inplace_update(self):
    BATCH_SIZE = 1
    SEQ_LEN = 128
    NUM_KV_HEADS = 16
    HEAD_SIZE = 256
    BLOCK_SIZE = 16
    DTYPE = torch.bfloat16
    num_blocks = 1024
    device = xm.xla_device()
    key = torch.randn(
        BATCH_SIZE * SEQ_LEN,
        NUM_KV_HEADS,
        HEAD_SIZE,
        device=device,
        dtype=DTYPE)
    k_cache = torch.randn(
        num_blocks * BLOCK_SIZE,
        NUM_KV_HEADS,
        HEAD_SIZE,
        device=device,
        dtype=DTYPE)
    slot_mapping = torch.randint(
        0, num_blocks, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.int64)
    # materialize k_cache to device data
    xm.mark_step()
    met.clear_all()
    for _ in range(10):
      k_cache.index_copy_(0, slot_mapping.flatten(), key)
      xm.mark_step()
    xm.wait_device_ops()
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
    self.assertTrue(
        torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu()))


if __name__ == '__main__':
  test = unittest.main()
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -21,6 +21,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
python3 test/test_pallas.py
python3 test/test_input_output_aliases.py
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
34 changes: 33 additions & 1 deletion torch_xla/csrc/aten_xla_type.cpp
@@ -29,6 +29,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/as_strided.h"
#include "torch_xla/csrc/ops/as_strided_view_update.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal_view_update.h"
#include "torch_xla/csrc/ops/einsum_utilities.h"
#include "torch_xla/csrc/ops/index_ops.h"
@@ -2538,7 +2539,38 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
  // 1) Aid XLA's InputOutputAlias.
  auto input_tensor = bridge::GetXlaTensor(input);
  auto output_tensor = bridge::GetXlaTensor(output);
  output_tensor->data()->alias_id = input_tensor->GetUniqueId();
  if (input_tensor->CurrentDataHandle() != nullptr ||
      (input_tensor->CurrentIrValue().node != nullptr &&
       torch_xla::DeviceData::Cast(
           input_tensor->CurrentIrValue().node.get()))) {
    /*
    If the input has XLA data or holds a DeviceData node, set alias_id to the
    input's tensor_id. Consider the case:
    // x.tensor_id = 1, x.alias_id = 1
    x = torch.randn(5,5).to(xla_device())
    // x.tensor_id = 2, x.alias_id should be 1
    x += 1
    xm.mark_step()
    // x.tensor_id = 3, x.alias_id should be 2 since the input tensor id will
    // be 2 for this graph
    x *= 2
    */
    output_tensor->data()->alias_id = input_tensor->GetUniqueId();
  } else {
    /*
    Otherwise the input is an intermediate result of the current graph, so
    propagate its alias_id. Consider the case:
    // x.tensor_id = 1, x.alias_id = 1
    x = torch.randn(5,5).to(xla_device())
    // x.tensor_id = 2, x.alias_id should be 1
    x += 1
    // x.tensor_id = 3, x.alias_id should still be 1
    x *= 2
    xm.mark_step()
    */
    output_tensor->data()->alias_id = input_tensor->data()->alias_id;
  }

  // 2) Aid SPMD.
  XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
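The two branches above distinguish whether the in-place op's input is already backed by device data (or a DeviceData node) or is itself an intermediate result of the current trace. Below is a minimal Python sketch of the second case, mirroring the scenario spelled out in the comments; the printed metric value is an expectation based on the new tests above, not asserted here.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(5, 5).to(device)
xm.mark_step()  # x is now backed by device data
met.clear_all()

x += 1  # input is device data, so alias_id takes the input's tensor id
x *= 2  # input is the intermediate from `x += 1`, so its alias_id is propagated
xm.mark_step()  # a single graph whose input buffer can still be donated

# Expected to report one aliased buffer for this graph.
print(met.metric_data("InputOutputAliasCount"))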
