Skip to content

Commit

Permalink
Fix support for torch.var_mean. (#8275)
Browse files Browse the repository at this point in the history
Co-authored-by: mrguenther <[email protected]>
  • Loading branch information
mrguenther and mrguenther authored Oct 18, 2024
1 parent fb34db8 commit 2d73a5f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
"unfold_copy",
"unfold",
"unravel_index",
"var_mean",
"nanmean",
"nn.functional.upsample_bilinear",
"randint",
Expand Down
16 changes: 11 additions & 5 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -3027,11 +3027,17 @@ def _aten_to_dtype_layout(

# Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False
@op(torch.ops.aten.var_mean.correction)
def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False):
return (
jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim),
jnp.mean(self, dim, keepdims=keepdim),
)
def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False):
# The internal API technically has a default `correction` argument of `None`,
# but the public API has a default argument of 1. Therefore, we simply set our
# default argument to 1. However, since the argument is officially supposed to
# be nullable, we still need to check for `None` per the API contract.
if correction is None:
correction = 1
mean = jnp.mean(tensor, axis=dim, keepdims=keepdim)
# TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it.
var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim)
return var, mean


@op(torch.ops.aten.scalar_tensor)
Expand Down

0 comments on commit 2d73a5f

Please sign in to comment.