diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 91eccaa15d0..aa07381cfda 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -11,14 +11,11 @@
 
 
 skiplist = {
-    "__rpow__", # NOTE: cannot fix because torch test case has undefined behavior
-    # such as 0 to negative power.
     "_segment_reduce",
     "bincount", # NOTE: dtype for int input torch gives float. This is weird.
     "byte",
     "cat",
     "cdist",
-    "ceil",
     "cholesky",
     "cholesky_solve",
     "diagonal_copy",
@@ -116,21 +113,27 @@
     "special.zeta",
     "svd",
     "svd_lowrank",
-    "to_sparse", # We are not supporting sparse tensors yet.
     "unfold_copy",
     "unfold",
     "unique_consecutive",
     "unique",
     "unravel_index",
-    "trunc",
     "var_mean",
     "argwhere",
     "nanmean",
-    "chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
     "nn.functional.upsample_bilinear",
     "randint",
 }
 
+not_support_ops_list = {
+    "chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
+    "__rpow__", # NOTE: cannot fix because torch test case has undefined behavior
+    # such as 0 to negative power.
+    "ceil", # only failed with python 3.9
+    "trunc", # only failed with python 3.9
+    "to_sparse", # We are not supporting sparse tensors yet.
+}
+
 # These inputs are themselves views
 # We cannot know how are the views created so cannot replicate the behavior.
 variant_test_name_to_skip = {
@@ -219,7 +222,7 @@ def run_export_and_compare(testcase,
 ops_to_test = [
     test
     for test in op_db
-    if (test.name not in skiplist and
+    if (test.name not in (skiplist | not_support_ops_list) and
         test.variant_test_name not in variant_test_name_to_skip)
 ]
 
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 8d3f8c6b96a..f671e039839 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -112,6 +112,12 @@ def _aten_clone(x, memory_format=None):
   return x
 
 
+# aten.trunc
+@op(torch.ops.aten.trunc)
+def _aten_trunc(x):
+  return jnp.trunc(x)
+
+
 @op(torch.ops.aten.index_copy)
 def _aten_index_copy(x, dim, indexes, source):
   # return jax.lax.scatter(x, index, dim)
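
For reference, a minimal standalone sanity check, a sketch assuming torch, jax, and numpy are installed and not part of the repository's test suite, illustrating that jnp.trunc follows the same round-toward-zero semantics as torch.trunc, which is what the new _aten_trunc lowering relies on:

# Standalone check (hypothetical example, not in the PR):
# jnp.trunc and torch.trunc both round toward zero for positive and negative inputs.
import jax.numpy as jnp
import numpy as np
import torch

vals = [-1.7, -0.5, 0.0, 0.5, 1.7]
torch_out = torch.trunc(torch.tensor(vals)).numpy()   # [-1., -0., 0., 0., 1.]
jax_out = np.asarray(jnp.trunc(jnp.asarray(vals)))    # [-1., -0., 0., 0., 1.]
np.testing.assert_allclose(jax_out, torch_out)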