diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 09d67b87487..b242636c5ca 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -61,7 +61,6 @@
     "nn.functional.multi_head_attention_forward",
     "nn.functional.multi_margin_loss",
     "nn.functional.multilabel_margin_loss",
-    "nn.functional.pad",
     "nn.functional.pairwise_distance",
     "nn.functional.poisson_nll_loss",
     "nn.functional.rrelu",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
index 8a211e0aca6..e82621bec3b 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -156,6 +156,56 @@ def wrap_flash_attention(query, key, value):
     return wrap_flash_attention(query, key, value)
 
 
+@register_function(torch.nn.functional.pad)
+def pad(tensor, pad, mode="constant", value=None):
+  # For padding modes whose names differ between Torch and NumPy, this dict
+  # provides a Torch-to-NumPy translation. Any mode not in this dict is passed
+  # through as-is.
+  MODE_NAME_TRANSLATION = {
+      "circular": "wrap",
+      "replicate": "edge",
+  }
+
+  numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode)
+
+  num_prefix_dims = tensor.ndim - len(pad) // 2
+
+  numpy_pad_width = [(0, 0)] * num_prefix_dims
+  nd_slice = [slice(None)] * num_prefix_dims
+
+  for i in range(len(pad) - 2, -1, -2):
+    pad_start, pad_end = pad[i:i + 2]
+    slice_start, slice_end = None, None
+
+    if pad_start < 0:
+      slice_start = -pad_start
+      pad_start = 0
+
+    if pad_end < 0:
+      slice_end = pad_end
+      pad_end = 0
+
+    numpy_pad_width.append((pad_start, pad_end))
+    nd_slice.append(slice(slice_start, slice_end))
+
+  nd_slice = tuple(nd_slice)
+
+  # `jax.numpy.pad` complains if we provide an irrelevant `constant_values`
+  # arg, even if the value we pass in is `None`. (It treats `None` as `nan`.)
+  kwargs = dict()
+  if mode == "constant" and value is not None:
+    kwargs["constant_values"] = value
+
+  # The "replicate" mode pads first and then slices, whereas the "circular"
+  # mode slices first and then pads. The latter approach operates on smaller
+  # intermediate tensors, so we default to it in modes where the order of
+  # operations does not affect the result.
+  if mode == "replicate":
+    return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice]
+  else:
+    return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs)
+
+
 @register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
 def scaled_dot_product_attention(
     query, key, value, attn_mask=None,
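
Note (not part of the patch): a minimal standalone sketch of the pad-spec translation the loop above performs, i.e. turning a PyTorch-style flat pad tuple (ordered from the last dimension inward) into a NumPy-style per-dimension pad_width, with negative entries clamped to zero since the patch handles those via slicing instead. The helper name to_numpy_pad_width is ours, for illustration only.

def to_numpy_pad_width(ndim, pad):
  # Leading dimensions not covered by the pad spec get no padding.
  num_prefix_dims = ndim - len(pad) // 2
  width = [(0, 0)] * num_prefix_dims
  # Torch lists (start, end) pairs starting from the *last* dimension, so walk
  # the spec backwards to build pad widths in natural dimension order.
  for i in range(len(pad) - 2, -1, -2):
    width.append((max(pad[i], 0), max(pad[i + 1], 0)))
  return width

# (left, right) = (1, 2) on the last dim, (top, bottom) = (3, 4) on the middle
# dim of a rank-3 tensor; the leading dim is untouched.
assert to_numpy_pad_width(3, (1, 2, 3, 4)) == [(0, 0), (3, 4), (1, 2)]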