diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index c7c04f781c3..b2c5fc50b21 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -38,6 +38,50 @@ def test_aliasing_with_cloned(self): torch.allclose(t1 - 1, t1_cloned) self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) + def test_aliasing_across_mark_step(self): + xla_device = xm.xla_device() + met.clear_all() + t1 = torch.randn(4, 5).to(xla_device) + t1 += 1 + xm.mark_step() + self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) + t1 *= 100 + xm.mark_step() + self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) + + def test_aliasing_with_multiple_inplace_update(self): + BATCH_SIZE = 1 + SEQ_LEN = 128 + NUM_KV_HEADS = 16 + HEAD_SIZE = 256 + BLOCK_SIZE = 16 + DTYPE = torch.bfloat16 + num_blocks = 1024 + device = xm.xla_device() + key = torch.randn( + BATCH_SIZE * SEQ_LEN, + NUM_KV_HEADS, + HEAD_SIZE, + device=device, + dtype=DTYPE) + k_cache = torch.randn( + num_blocks * BLOCK_SIZE, + NUM_KV_HEADS, + HEAD_SIZE, + device=device, + dtype=DTYPE) + slot_mapping = torch.randint( + 0, num_blocks, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.int64) + # materalize k_cache to device data + xm.mark_step() + met.clear_all() + for _ in range(10): + k_cache.index_copy_(0, slot_mapping.flatten(), key) + xm.mark_step() + xm.wait_device_ops() + self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) + torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu()) + if __name__ == '__main__': test = unittest.main() diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index b2a8fff33dc..a0eddebc3d5 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -21,6 +21,7 @@ python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py python3 test/test_pallas.py +python3 test/test_input_output_aliases.py python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 8464d1320c2..12a49a91ad9 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -29,6 +29,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/as_strided.h" #include "torch_xla/csrc/ops/as_strided_view_update.h" +#include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal_view_update.h" #include "torch_xla/csrc/ops/einsum_utilities.h" #include "torch_xla/csrc/ops/index_ops.h" @@ -2538,7 +2539,38 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, // 1) Aid XLA's InputOutputAlias. auto input_tensor = bridge::GetXlaTensor(input); auto output_tensor = bridge::GetXlaTensor(output); - output_tensor->data()->alias_id = input_tensor->GetUniqueId(); + if (input_tensor->CurrentDataHandle() != nullptr || + (input_tensor->CurrentIrValue().node != nullptr && + torch_xla::DeviceData::Cast( + input_tensor->CurrentIrValue().node.get()))) { + /* + if input has a XLAData or holds a devicedata node, set alias_id to + tensor_id. Consider the case. + + // x.tensor_id = 1, x.alias_id = 1 + x = torch.randn(5,5).to(xla_device()) + // x.tensor_id = 2, x.alias_id should be 1 + x += 1 + xm.mark_step() + // x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2 + // for this graph + x *= 1 of 1 + */ + output_tensor->data()->alias_id = input_tensor->GetUniqueId(); + } else { + /* + Consider the case + + // x.tensor_id = 1, x.alias_id = 1 + x = torch.randn(5,5).to(xla_device()) + // x.tensor_id = 2, x.alias_id should be 1 + x += 1 + // x.tensor_id = 3, x.alias_id should still be 1 + x * = 2 + xm.mark_step() + */ + output_tensor->data()->alias_id = input_tensor->data()->alias_id; + } // 2) Aid SPMD. XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();