Enable linspace & logspace tests & fix the failures
1. fix _copy to handle a self of shape=(1,) when the copy src is a 0-dim tensor, ref: #7505 (comment)
2. for squeeze_, where self's shape is expected to change, replace the underlying value instead of using copy_ (see the short illustrative sketch after the references below)
3. in test_ops, reset dtype to float when int64 is passed for the linspace
   case. This works around a known pytorch failure: pytorch/pytorch#137546
4. the logspace tests depend on linspace; both are passing now

ref:
* #7505 (comment)
* #7505 (comment)
* #7505 (comment)
* pytorch/pytorch#137546
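For context, a minimal sketch (plain PyTorch, not part of this commit) of the two in-place behaviors these fixes target: copy_ from a 0-dim src into a (1,)-shaped self must keep the (1,) shape, while squeeze_ is allowed to change self's shape, so it cannot be modeled with a plain copy_.

import torch

# copy_ from a 0-dim source into a (1,)-shaped destination keeps shape (1,)
dst = torch.empty((1,))
dst.copy_(torch.tensor(3.0))
print(dst, dst.shape)   # tensor([3.]) torch.Size([1])

# squeeze_ changes the tensor's shape in place
t = torch.zeros(1, 3)
t.squeeze_()
print(t.shape)          # torch.Size([3])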
barney-s committed Oct 9, 2024
1 parent 07d0823 commit b182ea0
Showing 3 changed files with 28 additions and 6 deletions.
11 changes: 9 additions & 2 deletions experimental/torch_xla2/test/test_ops.py
@@ -47,9 +47,7 @@
"linalg.tensorinv",
"linalg.tensorsolve",
"linalg.vector_norm",
"linspace",
"log_normal",
"logspace",
"lu",
"lu_solve",
"lu_unpack",
@@ -261,6 +259,15 @@ def test_reference_eager(self, device, dtype, op):
         continue
       check_output = op.name not in random_ops
 
+      #print("[DEBUG] sample_input: ", sample_input)
+
+      # TODO: this is a workaround to skip int64 cast for linspace
+      # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments
+      # we have opened a bug in pytorch: https://github.com/pytorch/pytorch/issues/137546
+      if op.name == "linspace":
+        if 'dtype' in sample_input.kwargs:
+          if sample_input.kwargs['dtype'] == torch.int64:
+            sample_input.kwargs['dtype'] = torch.float
       if op.name == "special.polygamma":
         # The polygamma function is inaccurate for values < 1.
         # To avoid errors during testing, replace values below 1 with 1.
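Below is an illustration of the dtype override added above, using a hypothetical stand-in for the OpInfo sample object (the real test iterates the op's sample inputs); when an int64 dtype is requested for linspace, the test now downgrades it to float before running the comparison.

import torch

class FakeSample:  # hypothetical stand-in for an OpInfo SampleInput
  def __init__(self, **kwargs):
    self.kwargs = kwargs

sample_input = FakeSample(dtype=torch.int64)
if sample_input.kwargs.get('dtype') == torch.int64:
  sample_input.kwargs['dtype'] = torch.float
print(sample_input.kwargs)  # {'dtype': torch.float32}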
15 changes: 13 additions & 2 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -41,7 +41,8 @@
     torch.ops.aten.random_: torch.ops.aten.uniform,
     torch.ops.aten.uniform_: torch.ops.aten.uniform,
     torch.ops.aten.relu_: torch.ops.aten.relu,
-    torch.ops.aten.squeeze_: torch.ops.aten.squeeze,
+    # squeeze_ is expected to change the tensor's shape, so replace it with the new value
+    torch.ops.aten.squeeze_: (torch.ops.aten.squeeze, True),
     torch.ops.aten.clamp_: torch.ops.aten.clamp,
     torch.ops.aten.ceil_: torch.ops.aten.ceil,
     torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
@@ -51,6 +52,10 @@


 def make_mutation(op):
+  if type(mutation_ops_to_functional[op]) is tuple:
+    return op_base.InplaceOp(mutation_ops_to_functional[op][0],
+                             replace=mutation_ops_to_functional[op][1],
+                             position_to_mutate=0)
   return op_base.InplaceOp(mutation_ops_to_functional[op], position_to_mutate=0)
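Illustrative only: how the two kinds of table entries handled above would be resolved (hypothetical string keys; the real table maps torch.ops.aten overloads).

def resolve(entry):
  # tuple entry: (functional op, replace flag); plain entry: copy_ semantics
  if isinstance(entry, tuple):
    return entry[0], entry[1]
  return entry, False

print(resolve(("aten::squeeze", True)))  # ('aten::squeeze', True)
print(resolve("aten::relu"))             # ('aten::relu', False)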


@@ -103,7 +108,13 @@ def _aten_add(x, y, *, alpha=1):

 @op(torch.ops.aten.copy_, is_jax_function=False)
 def _aten_copy(x, y, memory_format=None):
-  x._elem = y._elem.astype(x._elem.dtype)
+  if x.ndim == 1 and y.ndim == 0:
+    # case of torch.empty((1,)).copy_(tensor(N))
+    # we need to return a 1D tensor([N]) and not a scalar tensor(N)
+    # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131
+    x._elem = jnp.array([y._elem.astype(x._elem.dtype)])
+  else:
+    x._elem = y._elem.astype(x._elem.dtype)
   return x


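A minimal sketch of the copy_ edge case above in plain JAX, assuming x stands in for the backing array of the (1,)-shaped destination and y for the 0-dim source.

import jax.numpy as jnp

x = jnp.zeros((1,), dtype=jnp.float32)     # backing array of the (1,)-shaped destination
y = jnp.asarray(3, dtype=jnp.int32)        # 0-dim source

new_elem = jnp.array([y.astype(x.dtype)])  # wrap so the result keeps shape (1,)
print(new_elem, new_elem.shape)            # [3.] (1,)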
8 changes: 6 additions & 2 deletions experimental/torch_xla2/torch_xla2/ops/op_base.py
@@ -17,13 +17,17 @@

 class InplaceOp:
 
-  def __init__(self, functional_op, position_to_mutate=0):
+  def __init__(self, functional_op, replace=False, position_to_mutate=0):
     self.functional = functional_op
+    self.replace = replace
     self.position_to_mutate = position_to_mutate
 
   def __call__(self, *args, **kwargs):
     to_mutate = args[0]
-    to_mutate.copy_(self.functional(*args, **kwargs))
+    if self.replace:
+      to_mutate._elem = self.functional(*args, **kwargs)._elem
+    else:
+      to_mutate.copy_(self.functional(*args, **kwargs))
     return to_mutate


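A hedged usage sketch of the new replace path (toy tensor wrapper and functional op; only the InplaceOp import reflects the real module layout in this repo): with replace=True the wrapper's backing array is swapped wholesale, so in-place ops like squeeze_ can legitimately change the tensor's shape.

import numpy as np
from torch_xla2.ops.op_base import InplaceOp  # module path as in this repo

class FakeTensor:                 # toy stand-in for the torch_xla2 tensor wrapper
  def __init__(self, elem):
    self._elem = elem

def fake_squeeze(t):              # functional op returning a new wrapper
  return FakeTensor(np.squeeze(t._elem))

squeeze_ = InplaceOp(fake_squeeze, replace=True)
t = FakeTensor(np.zeros((1, 3)))
squeeze_(t)                       # t._elem is replaced with the squeezed array
print(t._elem.shape)              # (3,)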
