Commit
Add support for torch.nn.functional.pad. (#8290)
Co-authored-by: mrguenther <[email protected]>
mrguenther authored Oct 19, 2024
1 parent 387a274 commit e52dc4a
Showing 2 changed files with 50 additions and 1 deletion.
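For context, a minimal usage sketch of what this change enables. This is an editor's illustration, not part of the commit; `torch_xla2.default_env()` follows the project's README-style API and is assumed here, so the exact entry point may differ between versions.

import torch
import torch.nn.functional as F
import torch_xla2

# Assumed entry point (per the torch_xla2 README); may differ by version.
env = torch_xla2.default_env()

with env:
  x = torch.arange(6.0).reshape(2, 3)
  # With this commit, F.pad on torch_xla2 tensors dispatches to the
  # jnp.pad-based implementation registered in jtorch.py below.
  y = F.pad(x, (1, 2), mode="constant", value=0.0)
  print(y.shape)  # expected: torch.Size([2, 6])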
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
@@ -61,7 +61,6 @@
"nn.functional.multi_head_attention_forward",
"nn.functional.multi_margin_loss",
"nn.functional.multilabel_margin_loss",
"nn.functional.pad",
"nn.functional.pairwise_distance",
"nn.functional.poisson_nll_loss",
"nn.functional.rrelu",
50 changes: 50 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -156,6 +156,56 @@ def wrap_flash_attention(query, key, value):
  return wrap_flash_attention(query, key, value)


@register_function(torch.nn.functional.pad)
def pad(tensor, pad, mode="constant", value=None):
  # For padding modes that have different names between Torch and NumPy, this
  # dict provides a Torch-to-NumPy translation. Any string not in this dict will
  # be passed through as-is.
  MODE_NAME_TRANSLATION = {
      "circular": "wrap",
      "replicate": "edge",
  }

  numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)

  # The Torch `pad` tuple only covers the last `len(pad) // 2` dimensions;
  # any leading dimensions are left unpadded.
  num_prefix_dims = tensor.ndim - len(pad) // 2

  numpy_pad_width = [(0, 0)] * num_prefix_dims
  nd_slice = [slice(None)] * num_prefix_dims

  # Torch allows negative padding amounts, which crop rather than pad. NumPy's
  # `pad_width` entries must be non-negative, so negative amounts are turned
  # into an explicit slice along the corresponding dimension.
  for i in range(len(pad) - 2, -1, -2):
    pad_start, pad_end = pad[i:i + 2]
    slice_start, slice_end = None, None

    if pad_start < 0:
      slice_start = -pad_start
      pad_start = 0

    if pad_end < 0:
      slice_end = pad_end
      pad_end = 0

    numpy_pad_width.append((pad_start, pad_end))
    nd_slice.append(slice(slice_start, slice_end))

  nd_slice = tuple(nd_slice)

  # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg,
  # even if the value we pass in is `None`. (It treats `None` as `nan`.)
  kwargs = dict()
  if mode == "constant" and value is not None:
    kwargs["constant_values"] = value

  # The "replicate" mode pads first and then slices, whereas the "circular" mode
  # slices first and then pads. The latter approach deals with smaller tensors,
  # so we default to that option in modes where the order of operations doesn't
  # affect the result.
  if mode == "replicate":
    return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice]
  else:
    return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)


@register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
def scaled_dot_product_attention(
query, key, value, attn_mask=None,

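To make the Torch-to-NumPy translation in the new `pad` implementation concrete, here is a small self-contained check (editor's sketch, not part of the commit). Torch's `pad` tuple lists (start, end) amounts from the last dimension backwards, while `jnp.pad`'s `pad_width` lists them from the first dimension forwards, which is why the loop in the new function walks the tuple in reverse.

import jax.numpy as jnp
import numpy as np
import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(2, 3)

# Torch: pad the last dim by (1, 2) and the second-to-last dim by (0, 1).
torch_out = F.pad(x, (1, 2, 0, 1), mode="constant", value=0.0)

# Equivalent JAX call: pad_width runs front-to-back, so the pairs are reversed.
jax_out = jnp.pad(jnp.asarray(x.numpy()), [(0, 1), (1, 2)],
                  mode="constant", constant_values=0.0)

assert torch_out.shape == (3, 6)
assert np.allclose(torch_out.numpy(), np.asarray(jax_out))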