[torch_xla2] summary hard ops to non_support_ops_list (#8235)
ManfeiBai authored Oct 8, 2024
1 parent 93b4a90 commit 07d0823
Showing 2 changed files with 16 additions and 7 deletions.
17 changes: 10 additions & 7 deletions experimental/torch_xla2/test/test_ops.py
@@ -11,14 +11,11 @@
 
 
 skiplist = {
-    "__rpow__",  # NOTE: cannot fix because torch test case has undefined behavior
-    # such as 0 to negative power.
     "_segment_reduce",
     "bincount",  # NOTE: dtype for int input torch gives float. This is weird.
     "byte",
     "cat",
     "cdist",
-    "ceil",
     "cholesky",
     "cholesky_solve",
     "diagonal_copy",
@@ -116,21 +113,27 @@
     "special.zeta",
     "svd",
     "svd_lowrank",
-    "to_sparse",  # We are not supporting sparse tensors yet.
     "unfold_copy",
     "unfold",
     "unique_consecutive",
     "unique",
     "unravel_index",
-    "trunc",
     "var_mean",
     "argwhere",
     "nanmean",
-    "chalf",  # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
     "nn.functional.upsample_bilinear",
     "randint",
 }
 
+not_support_ops_list = {
+    "chalf",  # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
+    "__rpow__",  # NOTE: cannot fix because torch test case has undefined behavior
+    # such as 0 to negative power.
+    "ceil",  # only failed with python 3.9
+    "trunc",  # only failed with python 3.9
+    "to_sparse",  # We are not supporting sparse tensors yet.
+}
+
 # These inputs are themselves views
 # We cannot know how are the views created so cannot replicate the behavior.
 variant_test_name_to_skip = {
@@ -219,7 +222,7 @@ def run_export_and_compare(testcase,
 
 ops_to_test = [
     test for test in op_db
-    if (test.name not in skiplist and
+    if (test.name not in (skiplist | not_support_ops_list) and
         test.variant_test_name not in variant_test_name_to_skip)
 ]
 
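With this split, an op is tested only if its name appears in neither set; the union keeps the two lists separate in the source while the filter treats them identically. A minimal sketch of that filtering, assuming the skiplist, not_support_ops_list, and variant_test_name_to_skip sets defined above and the standard op_db from torch.testing._internal.common_methods_invocations:

from torch.testing._internal.common_methods_invocations import op_db

# The filter treats the two sets the same way; the split only documents
# *why* an op is excluded (known-unsupported vs. not yet fixed).
excluded_names = skiplist | not_support_ops_list

tested = [op for op in op_db
          if op.name not in excluded_names
          and op.variant_test_name not in variant_test_name_to_skip]
print(f"testing {len(tested)} of {len(op_db)} OpInfo entries")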
6 changes: 6 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -112,6 +112,12 @@ def _aten_clone(x, memory_format=None):
   return x
 
 
+# aten.trunc
+@op(torch.ops.aten.trunc)
+def _aten_trunc(x):
+  return jnp.trunc(x)
+
+
 @op(torch.ops.aten.index_copy)
 def _aten_index_copy(x, dim, indexes, source):
   # return jax.lax.scatter(x, index, dim)
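The new lowering maps aten.trunc straight onto jnp.trunc, which, like torch.trunc, rounds each element toward zero. A quick standalone check of that correspondence (an illustration only, assuming torch, jax, and numpy are installed; it bypasses the torch_xla2 dispatch machinery and compares the two underlying functions directly):

import jax.numpy as jnp
import numpy as np
import torch

vals = [1.7, -1.7, 0.5, -0.5, 2.0]

# Both functions truncate toward zero, so the results should match elementwise.
torch_out = torch.trunc(torch.tensor(vals)).numpy()
jax_out = np.asarray(jnp.trunc(jnp.array(vals)))

np.testing.assert_allclose(jax_out, torch_out)  # [1., -1., 0., -0., 2.]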
