diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 7869368c528..72d92280e18 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -35,13 +35,14 @@ torch.ops.aten.ge_: torch.ops.aten.ge, torch.ops.aten.eq_: torch.ops.aten.eq, torch.ops.aten.ne_: torch.ops.aten.ne, + torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p, + torch.ops.aten.geometric_: torch.ops.aten.geometric, + torch.ops.aten.normal_: torch.ops.aten.normal, + torch.ops.aten.random_: torch.ops.aten.uniform, torch.ops.aten.uniform_: torch.ops.aten.uniform, torch.ops.aten.relu_: torch.ops.aten.relu, - torch.ops.aten.normal_: torch.ops.aten.normal, torch.ops.aten.squeeze_: torch.ops.aten.squeeze, - torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p, torch.ops.aten.clamp_: torch.ops.aten.clamp, - torch.ops.aten.random_: torch.ops.aten.uniform, torch.ops.aten.ceil_: torch.ops.aten.ceil, torch.ops.aten.logical_not_: torch.ops.aten.logical_not, torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze, @@ -2721,7 +2722,7 @@ def _bernoulli( return res -@op(torch.ops.aten.geometric_, needs_env=True) +@op(torch.ops.aten.geometric, needs_env=True) def geometric(self, p, *, generator=None, env=None): key = env.get_and_rotate_prng_key(generator) res = jax.random.geometric(key, p, self.shape)