diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index aa07381cfda..07e6d8f0e3b 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -47,9 +47,7 @@
     "linalg.tensorinv",
     "linalg.tensorsolve",
     "linalg.vector_norm",
-    "linspace",
     "log_normal",
-    "logspace",
     "lu",
     "lu_solve",
     "lu_unpack",
@@ -261,6 +259,14 @@ def test_reference_eager(self, device, dtype, op):
         continue
       check_output = op.name not in random_ops
 
+      # TODO: this is a workaround to skip the int64 cast for linspace
+      # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments
+      # we have opened a bug in PyTorch: https://github.com/pytorch/pytorch/issues/137546
+      if op.name == "linspace":
+        if 'dtype' in sample_input.kwargs:
+          if sample_input.kwargs['dtype'] == torch.int64:
+            sample_input.kwargs['dtype'] = torch.float
+
       if op.name == "special.polygamma":
         # The polygamma function is inaccurate for values < 1.
         # To avoid errors during testing, replace values below 1 with 1.
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index f671e039839..12e4e10e89d 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -41,7 +41,8 @@
     torch.ops.aten.random_: torch.ops.aten.uniform,
     torch.ops.aten.uniform_: torch.ops.aten.uniform,
     torch.ops.aten.relu_: torch.ops.aten.relu,
-    torch.ops.aten.squeeze_: torch.ops.aten.squeeze,
+    # squeeze_ is expected to change the tensor's shape, so replace the value instead of copying
+    torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
     torch.ops.aten.clamp_: torch.ops.aten.clamp,
     torch.ops.aten.ceil_: torch.ops.aten.ceil,
     torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
@@ -51,6 +52,10 @@
 
 
 def make_mutation(op):
+  if isinstance(mutation_ops_to_functional[op], tuple):
+    return op_base.InplaceOp(mutation_ops_to_functional[op][0],
+                             replace=mutation_ops_to_functional[op][1],
+                             position_to_mutate=0)
   return op_base.InplaceOp(mutation_ops_to_functional[op],
                            position_to_mutate=0)
 
@@ -103,7 +108,13 @@ def _aten_add(x, y, *, alpha=1):
 
 @op(torch.ops.aten.copy_, is_jax_function=False)
 def _aten_copy(x, y, memory_format=None):
-  x._elem = y._elem.astype(x._elem.dtype)
+  if x.ndim == 1 and y.ndim == 0:
+    # case of torch.empty((1,)).copy_(tensor(N))
+    # we need to return the 1D tensor([N]) and not the scalar tensor(N)
+    # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131
+    x._elem = jnp.array([y._elem.astype(x._elem.dtype)])
+  else:
+    x._elem = y._elem.astype(x._elem.dtype)
   return x
 
 
diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/experimental/torch_xla2/torch_xla2/ops/op_base.py
index 2c4176a361d..203ec5a3686 100644
--- a/experimental/torch_xla2/torch_xla2/ops/op_base.py
+++ b/experimental/torch_xla2/torch_xla2/ops/op_base.py
@@ -17,13 +17,17 @@
 
 class InplaceOp:
 
-  def __init__(self, functional_op, position_to_mutate=0):
+  def __init__(self, functional_op, replace=False, position_to_mutate=0):
     self.functional = functional_op
+    self.replace = replace
     self.position_to_mutate = position_to_mutate
 
   def __call__(self, *args, **kwargs):
     to_mutate = args[0]
-    to_mutate.copy_(self.functional(*args, **kwargs))
+    if self.replace:
+      to_mutate._elem = self.functional(*args, **kwargs)._elem
+    else:
+      to_mutate.copy_(self.functional(*args, **kwargs))
     return to_mutate
 
 
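
Illustrative note, using stock PyTorch only (this snippet is not part of the patch and makes no assumptions about torch_xla2 internals): squeeze_ mutates the shape of the tensor it is called on, while copy_ keeps the destination's shape and only requires the source to be broadcastable into it. That semantic mismatch is what motivates the replace=True path for squeeze_ above, which swaps the underlying buffer instead of copying values. The last lines also show the torch.empty((1,)).copy_(tensor(N)) behavior that the _aten_copy change mirrors.

import torch

# squeeze_ changes the shape of the tensor it mutates.
x = torch.zeros(1, 3, 1)
x.squeeze_()
print(x.shape)  # torch.Size([3])

# copy_ keeps the destination's shape, so a copy_-based rewrite of
# squeeze_ fails: a (3,) source cannot be broadcast into a (1, 3, 1) destination.
y = torch.zeros(1, 3, 1)
try:
  y.copy_(y.squeeze())
except RuntimeError as err:
  print("copy_-based lowering fails:", err)

# The case _aten_copy now matches explicitly: a 0-d source copied into a
# 1-element destination yields tensor([N]) with shape (1,), not a scalar.
z = torch.empty((1,))
z.copy_(torch.tensor(5.0))
print(z, z.shape)  # tensor([5.]) torch.Size([1])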