Skip to content

Commit

Permalink
rearrange mutations
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg committed Oct 7, 2024
1 parent ce33501 commit 68a66cb
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@
torch.ops.aten.ge_: torch.ops.aten.ge,
torch.ops.aten.eq_: torch.ops.aten.eq,
torch.ops.aten.ne_: torch.ops.aten.ne,
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
torch.ops.aten.geometric_: torch.ops.aten.geometric,
torch.ops.aten.normal_: torch.ops.aten.normal,
torch.ops.aten.random_: torch.ops.aten.uniform,
torch.ops.aten.uniform_: torch.ops.aten.uniform,
torch.ops.aten.relu_: torch.ops.aten.relu,
torch.ops.aten.normal_: torch.ops.aten.normal,
torch.ops.aten.squeeze_: torch.ops.aten.squeeze,
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
torch.ops.aten.clamp_: torch.ops.aten.clamp,
torch.ops.aten.random_: torch.ops.aten.uniform,
torch.ops.aten.ceil_: torch.ops.aten.ceil,
torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
Expand Down Expand Up @@ -2721,7 +2722,7 @@ def _bernoulli(
return res


@op(torch.ops.aten.geometric_, needs_env=True)
@op(torch.ops.aten.geometric, needs_env=True)
def geometric(self, p, *, generator=None, env=None):
key = env.get_and_rotate_prng_key(generator)
res = jax.random.geometric(key, p, self.shape)
Expand Down

0 comments on commit 68a66cb

Please sign in to comment.