Skip to content

Commit

Permalink
add missing aten op (#7078)
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore authored May 22, 2024
1 parent 8d35eb0 commit 5e1d454
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 25 deletions.
111 changes: 111 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,6 +2697,117 @@ def test_aten_native_layer_norm_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs)

def test_aten_native_batch_norm_legit(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,2,2)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
False,
0.5,
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,4)).to(torch.float32),
None,
None,
torch.ones(channel),
torch.zeros(channel),
False,
0.5,
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
None,
None,
torch.zeros(channel),
torch.ones(channel),
True,
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_no_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs)

def test_aten_native_batch_norm_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
True,
0.1,
1e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_native_batch_norm_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
None,
None,
torch.zeros(channel),
torch.ones(channel),
True,
0.1,
1e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_native_batch_norm_eval(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
False,
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_ne_Scalar_0(self):
args = (
torch.randint(0, 10, (10, 10)).to(torch.int32),
Expand Down
2 changes: 0 additions & 2 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"__getitem__",
"__rmatmul__",
"__rpow__",
"_native_batch_norm_legit",
"_segment_reduce",
"_upsample_bilinear2d_aa",
"argsort",
Expand Down Expand Up @@ -198,7 +197,6 @@
"nansum",
"narrow_copy",
"narrow",
"native_batch_norm",
"native_layer_norm",
"new_empty",
"new_empty_strided",
Expand Down
91 changes: 68 additions & 23 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
torch.ops.aten.eq_: torch.ops.aten.eq,
torch.ops.aten.ne_: torch.ops.aten.ne,
torch.ops.aten.uniform_: torch.ops.aten.uniform,
torch.ops.aten.relu_: torch.ops.aten.relu,
}


Expand Down Expand Up @@ -545,35 +546,67 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
def _aten__native_batch_norm_legit(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return _aten__native_batch_norm_legit_no_training(
input, weight, bias, running_mean, running_var, momentum, eps
)
"""JAX implementation of batch normalization with optional parameters.
Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.
Args:
input (DeviceArray): Input data (N, C, H, W).
running_mean ([DeviceArray]): Running mean of input (C,).
running_var ([DeviceArray]): Running variance of input (C,).
weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None.
bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None.
training (bool): If True, use batch statistics for normalization.
If False, use running statistics.
momentum (float): Momentum factor for updating running statistics.
eps (float): Small constant for numerical stability.
Returns:
DeviceArray: Normalized output
DeviceArray: Batch mean (C,) or empty if training is False
DeviceArray: Reversed batch variance (C,) or empty if training is False
"""
reduction_dims = [0] + list(range(2, input.ndim))
reshape_dims = [1, -1] + [1]*(input.ndim-2)

if training:
# Calculate batch mean and variance
mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
saved_mean = jnp.squeeze(mean, reduction_dims)
var = jnp.var(input, axis=reduction_dims)
rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)
# Update running statistics using momentum
running_mean = (1 - momentum) * running_mean + momentum * saved_mean
running_var = (1 - momentum) * running_var + momentum * var
saved_rstd = jnp.squeeze(rstd, reduction_dims)
else:
rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
saved_mean = jnp.array([]) # No need to calculate batch statistics in inference mode
saved_rstd = jnp.array([])

# Normalize
if training:
# use batch statistics if training
x_hat = (input - mean) * rstd
else:
# Use running statistics in inference mode
x_hat = (input - running_mean.reshape(reshape_dims)) * rstd

# Scale and shift
if weight is not None:
x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting
if bias is not None:
x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting

return x_hat, saved_mean, saved_rstd



@op(torch.ops.aten._native_batch_norm_legit_no_training)
def _aten__native_batch_norm_legit_no_training(
input, weight, bias, running_mean, running_var, momentum, eps
):
if weight is None:
weight = jnp.ones_like(running_mean)
if bias is None:
bias = jnp.zeros_like(running_mean)

def broadcast(t):
return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,))

if running_mean is not None:
a = input - broadcast(running_mean)
else:
a = input
if running_var is not None:
b = broadcast(jnp.sqrt(running_var + eps))
else:
b = broadcast(jnp.sqrt(eps))
return (
a / b * broadcast(weight) + broadcast(bias),
jnp.array([]),
jnp.array([]),
return _aten__native_batch_norm_legit(
input, weight, bias, running_mean, running_var, False, momentum, eps
)


Expand Down Expand Up @@ -1950,3 +1983,15 @@ def _aten_outer(a, b):
def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
return jnp.allclose(input, other, rtol, atol, equal_nan)

@op(torch.ops.aten.native_batch_norm)
def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):

if running_mean is None:
running_mean = jnp.zeros(input.shape[1]) # Initialize running mean if None
if running_var is None:
running_var = jnp.ones(input.shape[1]) # Initialize running variance if None

if training:
return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
else:
return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)

0 comments on commit 5e1d454

Please sign in to comment.