Skip to content

Commit

Permalink
Add diagonal_scatter aten ops (#8232)
Browse files Browse the repository at this point in the history
  • Loading branch information
nupurbaghel authored Oct 8, 2024
1 parent 7aa996c commit c29cccf
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"cholesky_solve",
"complex",
"diagonal_copy",
"diagonal_scatter",
"digamma",
"exponential",
"geqrf",
Expand Down
34 changes: 34 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,6 +2070,40 @@ def _aten_diagonal(input, offset=0, dim1=0, dim2=1):
return jnp.diagonal(input, offset, dim1, dim2)


def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1):
input_len = len(input_shape)
if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len):
raise ValueError("dim1 and dim2 must be different and in range [0, " + str(input_len-1)+ "]")

size1, size2 = input_shape[dim1], input_shape[dim2]
if offset >= 0:
indices1 = jnp.arange(min(size1, size2 - offset))
indices2 = jnp.arange(offset, offset + len(indices1))
else:
indices2 = jnp.arange(min(size1 + offset, size2 ))
indices1 = jnp.arange(-offset, -offset + len(indices2))
return [indices1, indices2]

@op(torch.ops.aten.diagonal_scatter)
def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
indexes = diag_indices_with_offset(input.shape, offset, dim1, dim2)

if input.ndim == 2:
return input.at[tuple(indexes)].set(src)
else:
# src has the same shape as the output of
# jnp.diagonal(input, offset, dim1, dim2).
# Last dimension always contains the diagonal elements,
# while the preceding dimensions represent the "slices"
# from which these diagonals are extracted. Thus,
# we alter input axes to match this assumption, write src
# and then move the axes back to the original state.
input = jnp.moveaxis(input, (dim1, dim2), (-2,-1))
multi_indexes = [slice(None)]*(input.ndim-2) + indexes
input = input.at[tuple(multi_indexes)].set(src)
return jnp.moveaxis(input, (-2,-1), (dim1, dim2))


# aten.diagflat
@op(torch.ops.aten.diagflat)
def _aten_diagflat(input, offset=0):
Expand Down

0 comments on commit c29cccf

Please sign in to comment.