From 82bcddadf84a5676e861e0c5228dc588934477dd Mon Sep 17 00:00:00 2001
From: Simon Teo
Date: Mon, 14 Oct 2024 15:10:01 +0800
Subject: [PATCH] Implement avg_pool1d, avg_pool2d, and avg_pool3d (#7517) (#8255)

Co-authored-by: Simon Teo
---
 experimental/torch_xla2/test/test_ops.py   |  3 --
 .../torch_xla2/torch_xla2/ops/jaten.py     | 46 +++++++++++++++++--
 2 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 811f8499199..4cd02340675 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -49,9 +49,6 @@
     "nn.functional.adaptive_max_pool2d",
     "nn.functional.adaptive_max_pool3d",
     "nn.functional.alpha_dropout",
-    "nn.functional.avg_pool1d",
-    "nn.functional.avg_pool2d",
-    "nn.functional.avg_pool3d",
     "nn.functional.bilinear",
     "nn.functional.conv_transpose1d",
     "nn.functional.conv_transpose2d",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index a8685f68e86..b8afcc35fec 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1772,7 +1772,7 @@ def adaptive_kernel_size(input_shape, output_shape):
   return y


-# aten.avg_pool2d
+@op(torch.ops.aten.avg_pool1d)
 @op(torch.ops.aten.avg_pool2d)
 @op(torch.ops.aten.avg_pool3d)
 def _aten_avg_pool(
@@ -1787,6 +1787,8 @@ def _aten_avg_pool(
   num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
   kernel_size = tuple(kernel_size)
   strides = tuple(strides) if strides else kernel_size
+  if isinstance(padding, list) and len(padding) == 1:
+    padding = padding[0]
   if isinstance(padding, int):
     padding = [padding for _ in range(len(kernel_size))]

@@ -1800,7 +1802,26 @@ def _aten_avg_pool(
   if divisor_override is not None:
     y = y / jnp.array(divisor_override, y.dtype)
   elif count_include_pad:
-    y = y / jnp.array(np.prod(kernel_size), y.dtype)
+    div_shape = list(y.shape)
+    div_by = jnp.ones(div_shape, y.dtype) * np.prod(kernel_size)
+    unequal_paddings = map(lambda pad: pad[0] != pad[1], padding)
+    unequal_padding_indices = np.where(list(unequal_paddings))[0]
+    if len(unequal_padding_indices) > 0:
+      # indices to update kernel size
+      offset = len(div_shape) - len(padding)
+      skip_indices = list(map(lambda x: x + offset, unequal_padding_indices))
+      indices = _generate_indices(div_shape, skip_dim_indices=skip_indices)
+      # updated kernel size accounting for maximum padding
+      new_kernel_size = list(kernel_size)
+      for j in unequal_padding_indices:
+        new_kernel_size[j] = kernel_size[j] - padding[j][1] + padding[j][0]
+
+      for idx in indices:
+        for j in unequal_padding_indices:
+          idx[j + offset] = -1
+        div_by = div_by.at[tuple(idx)].set(np.prod(new_kernel_size))
+
+    y = y / div_by
   else:
     div_shape = list(inputs.shape)
     div_shape[num_batch_dims] = 1
@@ -1815,8 +1836,25 @@
       strides,
       padding,
   )
-  return y
-
+  return y.astype(inputs.dtype)
+
+# helper function to generate all indices to iterate through ndarray
+def _generate_indices(dims, skip_dim_indices = []):
+  res = []
+  def _helper(curr_dim_idx, sofar):
+    if curr_dim_idx in skip_dim_indices:
+      _helper(curr_dim_idx + 1, sofar[:])
+      return
+    if curr_dim_idx >= len(dims):
+      print(sofar)
+      res.append(sofar)
+      return
+    for i in range(dims[curr_dim_idx]):
+      sofar[curr_dim_idx] = i
+      _helper(curr_dim_idx + 1, sofar[:])
+
+  _helper(0, [0 for _ in dims])
+  return res

 # aten.sym_numel
 # aten.reciprocal
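
For context, the lowering computes a windowed sum and then divides by a per-position
element count; the new code above only adjusts that divisor near asymmetrically
padded edges. A minimal standalone sketch of the underlying reduce_window-based
averaging, assuming an (N, C, L) input, symmetric padding, and
count_include_pad=True semantics (all names below are illustrative, not part of
the patch):

# Sketch of reduce_window-based average pooling; illustrative only.
import jax
import jax.numpy as jnp
import numpy as np

def avg_pool1d_sketch(x, kernel_size, stride, padding):
  window = (1, 1, kernel_size)
  window_strides = (1, 1, stride)
  pads = ((0, 0), (0, 0), (padding, padding))
  # Windowed sum over the last (spatial) dimension...
  summed = jax.lax.reduce_window(x, 0.0, jax.lax.add, window,
                                 window_strides, pads)
  # ...divided by the full kernel size (count_include_pad=True).
  return summed / np.prod(window)

x = jnp.arange(8.0).reshape(1, 1, 8)
print(avg_pool1d_sketch(x, kernel_size=2, stride=2, padding=0))
# [[[0.5 2.5 4.5 6.5]]]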
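
A quick end-to-end check that the newly registered ops dispatch through
torch_xla2; torch_xla2.default_env() is assumed here as the entry point, as
used elsewhere in the project's tests:

import torch
import torch_xla2

env = torch_xla2.default_env()  # assumed entry point, per torch_xla2's tests
with env:
  x = torch.randn(1, 3, 16, 16)
  y = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
  print(y.shape)  # expected: torch.Size([1, 3, 8, 8])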