Implement avg_pool1d, avg_pool2d, and avg_pool3d (#7517) (#8255)
Co-authored-by: Simon Teo <[email protected]>
simonteozw and Simon Teo authored Oct 14, 2024
1 parent a4e295a · commit 82bcdda
Showing 2 changed files with 42 additions and 7 deletions.
3 changes: 0 additions & 3 deletions experimental/torch_xla2/test/test_ops.py
@@ -49,9 +49,6 @@
     "nn.functional.adaptive_max_pool2d",
     "nn.functional.adaptive_max_pool3d",
     "nn.functional.alpha_dropout",
-    "nn.functional.avg_pool1d",
-    "nn.functional.avg_pool2d",
-    "nn.functional.avg_pool3d",
     "nn.functional.bilinear",
     "nn.functional.conv_transpose1d",
     "nn.functional.conv_transpose2d",
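
With those three entries gone from the skip list, the OpInfo tests for avg_pool1d/2d/3d now exercise the new lowering. A hedged usage sketch (not part of this commit; it assumes the torch_xla2.default_env() entry point described in the project README):

import torch
import torch.nn.functional as F
import torch_xla2

# Inside the environment, torch ops dispatch to the JAX implementations
# registered in jaten.py, including the new _aten_avg_pool below.
env = torch_xla2.default_env()
with env:
  x = torch.randn(1, 3, 8)
  y = F.avg_pool1d(x, kernel_size=2, stride=2)
  print(y.shape)  # expected shape: (1, 3, 4)
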
46 changes: 42 additions & 4 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1772,7 +1772,7 @@ def adaptive_kernel_size(input_shape, output_shape):
   return y


-# aten.avg_pool2d
+@op(torch.ops.aten.avg_pool1d)
 @op(torch.ops.aten.avg_pool2d)
 @op(torch.ops.aten.avg_pool3d)
 def _aten_avg_pool(
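
The only change in this hunk is the new avg_pool1d registration: all three variants now route to the shared _aten_avg_pool lowering. @op is the module's existing registration decorator; a simplified sketch of the pattern (hypothetical — the real decorator in jaten.py carries more bookkeeping):

# Hypothetical minimal version of the @op registration pattern.
all_aten_ops = {}

def op(*aten_ops):
  def inner(func):
    for aten_op in aten_ops:
      all_aten_ops[aten_op] = func  # one JAX callable can back several overloads
    return func
  return inner
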
@@ -1787,6 +1787,8 @@ def _aten_avg_pool(
   num_batch_dims = len(inputs.shape) - len(kernel_size) - 1
   kernel_size = tuple(kernel_size)
   strides = tuple(strides) if strides else kernel_size
+  if isinstance(padding, list) and len(padding) == 1:
+    padding = padding[0]
   if isinstance(padding, int):
     padding = [padding for _ in range(len(kernel_size))]

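The two added lines handle the 1-d calling convention: aten.avg_pool1d passes padding as a single-element list, which is unwrapped to an int and then broadcast to one entry per spatial dimension by the pre-existing isinstance(padding, int) branch. A quick trace (illustrative only):

# padding as received from aten.avg_pool1d with kernel_size=(3,)
padding = [1]
kernel_size = (3,)
if isinstance(padding, list) and len(padding) == 1:
  padding = padding[0]                                   # [1] -> 1
if isinstance(padding, int):
  padding = [padding for _ in range(len(kernel_size))]   # 1 -> [1]
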
@@ -1800,7 +1802,26 @@ def _aten_avg_pool(
   if divisor_override is not None:
     y = y / jnp.array(divisor_override, y.dtype)
   elif count_include_pad:
-    y = y / jnp.array(np.prod(kernel_size), y.dtype)
+    div_shape = list(y.shape)
+    div_by = jnp.ones(div_shape, y.dtype) * np.prod(kernel_size)
+    unequal_paddings = map(lambda pad: pad[0] != pad[1], padding)
+    unequal_padding_indices = np.where(list(unequal_paddings))[0]
+    if len(unequal_padding_indices) > 0:
+      # output indices whose divisor must be updated
+      offset = len(div_shape) - len(padding)
+      skip_indices = list(map(lambda x: x + offset, unequal_padding_indices))
+      indices = _generate_indices(div_shape, skip_dim_indices=skip_indices)
+      # effective kernel size at the asymmetrically padded edge
+      new_kernel_size = list(kernel_size)
+      for j in unequal_padding_indices:
+        new_kernel_size[j] = kernel_size[j] - padding[j][1] + padding[j][0]
+
+      for idx in indices:
+        for j in unequal_padding_indices:
+          idx[j + offset] = -1
+        div_by = div_by.at[tuple(idx)].set(np.prod(new_kernel_size))
+
+    y = y / div_by
   else:
     div_shape = list(inputs.shape)
     div_shape[num_batch_dims] = 1
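
Previously, count_include_pad=True always divided by prod(kernel_size). That over-divides windows that run into asymmetric padding: with ceil_mode, lax.reduce_window pads a dimension with a (low, high) pair where high > low, and PyTorch does not count cells beyond the padded extent. The added code therefore builds a per-element divisor div_by and patches the last output position (index -1) of each unequal-padding dimension with the smaller effective kernel size. A worked example of the behavior being matched (values hand-computed from PyTorch's documented rules, so treat them as a sketch):

import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y = F.avg_pool1d(x, kernel_size=3, stride=2, padding=1,
                 ceil_mode=True, count_include_pad=True)
# Windows over the padded input [0, 1, 2, 3, 4 | 0]:
#   [0, 1, 2] / 3 = 1.0
#   [2, 3, 4] / 3 = 3.0
#   [4, 0]    / 2 = 2.0   <- divisor shrinks at the asymmetric right edge
print(y)  # expected: tensor([[1., 3., 2.]])
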
@@ -1815,8 +1836,25 @@
         strides,
         padding,
     )
-  return y
+  return y.astype(inputs.dtype)


+# helper to generate every index tuple for iterating over an ndarray shape
+def _generate_indices(dims, skip_dim_indices=()):
+  res = []
+  def _helper(curr_dim_idx, sofar):
+    if curr_dim_idx in skip_dim_indices:
+      _helper(curr_dim_idx + 1, sofar[:])
+      return
+    if curr_dim_idx >= len(dims):
+      res.append(sofar)
+      return
+    for i in range(dims[curr_dim_idx]):
+      sofar[curr_dim_idx] = i
+      _helper(curr_dim_idx + 1, sofar[:])
+
+  _helper(0, [0 for _ in dims])
+  return res
+
 # aten.sym_numel
 # aten.reciprocal
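
_generate_indices enumerates every index tuple over a shape while pinning the skipped dimensions at 0, so the caller can overwrite those slots (with -1 above) before indexing div_by. An illustrative call:

# dims [2, 3], skipping dim 1: dim 1 stays pinned at 0 in each result
_generate_indices([2, 3], skip_dim_indices=[1])
# -> [[0, 0], [1, 0]]
# The pooling code then sets idx[1] = -1 on each tuple so the divisor is
# only patched at the last (asymmetrically padded) position of dim 1.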
