Skip to content

Commit

Permalink
[torch_xla2] Fix geometric and gcd (#8226)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg authored Oct 8, 2024
1 parent e3cf356 commit d50fecb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
9 changes: 4 additions & 5 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
"diagonal_scatter",
"digamma",
"exponential",
"gcd",
"geometric",
"geqrf",
"histogram", # hard op: AssertionError: Tensor-likes are not close!
"histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got <class 'list'> at position 1.
Expand Down Expand Up @@ -154,8 +152,9 @@
'empty_permuted',
'empty_strided',
'bernoulli',
"new_empty",
"new_empty_strided",
'geometric',
'new_empty',
'new_empty_strided',
'randint_like',
'randn',
'randn_like',
Expand Down Expand Up @@ -202,7 +201,7 @@ def run_export_and_compare(testcase,
atol, rtol = (1e-3, 1e-5)
if func.name in atol_dict:
atol, rtol = atol_dict[func.name]

with testcase.subTest("torch_eval"):
res = func(sample_input.input, *sample_input.args, **sample_input.kwargs)
with testcase.subTest("torch_xla2_eval"):
Expand Down
17 changes: 14 additions & 3 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@
torch.ops.aten.ge_: torch.ops.aten.ge,
torch.ops.aten.eq_: torch.ops.aten.eq,
torch.ops.aten.ne_: torch.ops.aten.ne,
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
torch.ops.aten.geometric_: torch.ops.aten.geometric,
torch.ops.aten.normal_: torch.ops.aten.normal,
torch.ops.aten.random_: torch.ops.aten.uniform,
torch.ops.aten.uniform_: torch.ops.aten.uniform,
torch.ops.aten.relu_: torch.ops.aten.relu,
torch.ops.aten.normal_: torch.ops.aten.normal,
torch.ops.aten.squeeze_: torch.ops.aten.squeeze,
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
torch.ops.aten.clamp_: torch.ops.aten.clamp,
torch.ops.aten.random_: torch.ops.aten.uniform,
torch.ops.aten.ceil_: torch.ops.aten.ceil,
torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
Expand Down Expand Up @@ -2258,6 +2259,10 @@ def _aten_linalg_eig(A):
def _aten_linalg_eigh(A, UPLO='L'):
return jnp.linalg.eigh(A, UPLO)

@op(torch.ops.aten.gcd)
def _aten_gcd(input, other):
return jnp.gcd(input, other)

# aten.lcm
@op(torch.ops.aten.lcm)
def _aten_lcm(input, other):
Expand Down Expand Up @@ -2717,6 +2722,12 @@ def _bernoulli(
return res


@op(torch.ops.aten.geometric, needs_env=True)
def geometric(self, p, *, generator=None, env=None):
key = env.get_and_rotate_prng_key(generator)
res = jax.random.geometric(key, p, self.shape)
return res


@op(torch.ops.aten.randn_like, needs_env=True)
@op_base.convert_dtype()
Expand Down

0 comments on commit d50fecb

Please sign in to comment.