diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index befaae21517..09d67b87487 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -90,7 +90,6 @@ "svd_lowrank", "unfold_copy", "unfold", - "unravel_index", "nanmean", "nn.functional.upsample_bilinear", "randint", diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index e355094694d..8a211e0aca6 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -245,6 +245,11 @@ def empty_strided( return empty(size, dtype=dtype) +@register_function(torch.unravel_index) +def unravel_index(indices, shape): + return jnp.unravel_index(indices, shape) + + @register_function(torch.rand, is_jax_function=False) def rand( *size, **kwargs