Skip to content

Commit

Permalink
Implement lu_unpack in jax
Browse files Browse the repository at this point in the history
- For Lower and Upper matrices, use jnp.tril, jnp.triu
- Shape both triangle matrices and add 1's to the lower triangle to
  match the pytorch behavior
- For 2D permutation matrix start with an identity matrix and mutate it
  based on pivots
  - start with sequential indices and apply the pivot operations to it
  - finally use the pivoted indices to index the identity matrix to
    generate the final permutation matrix for that pivot.
- For 2D inputs and 1D pivot the above logic would work
- For >=3d inputs, we first reshape the inputs to become 3D and then
  call vmap along the first dim with the 2d logic for each 2d matrix
  • Loading branch information
barney-s committed Oct 18, 2024
1 parent f71c02d commit 9a292c5
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"linalg.matrix_norm",
"linalg.matrix_power",
"linalg.tensorsolve",
"lu_unpack",
"masked.median",
"max_pool2d_with_indices_backward",
"nn.functional.adaptive_avg_pool3d",
Expand Down
102 changes: 102 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4560,3 +4560,105 @@ def _euclidean_direct(x1, x2):
dist = jnp.sqrt(dist_sq).astype(jnp.float32)

return dist

@op(torch.ops.aten.lu_unpack)
def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
# lu_unpack doesnt exist in jax.
# Get commonly used data shape variables
n = LU_data.shape[-2]
m = LU_data.shape[-1]
dim = min(n,m)

### Compute the Lower and Upper triangle
if unpack_data:
# Extract lower triangle
L = jnp.tril(LU_data, k=-1)

#emulate pytorch behavior: Add ones to the diagonal of L
eye = jnp.eye(n, m, dtype=LU_data.dtype)
L = L + eye

# emulate pytorch behavior: Reshape lower triangle to match pivot
start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
limit_indices = list(LU_data.shape)
limit_indices[-1] = dim
L = jax.lax.slice(L, start_indices, limit_indices)

# Extract upper triangle
U = jnp.triu(LU_data)

# emulate pytorch behavior: Reshape upper triangle to match pivot
start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
limit_indices = list(LU_data.shape)
limit_indices[-2] = dim
U = jax.lax.slice(U, start_indices, limit_indices)
else:
# emulate pytroch behavior: return empty tensors
L = torch.empty(torch.Size([0]))
U = torch.empty(torch.Size([0]))

### Compute the Permutation matrix
if unpack_pivots:
# We should return a permutation matrix (2D) for each pivot array (1D)
# The shape of the final Permutation matrix depends on the shape of the input
# data and the pivots

# start with a 2D identity matrix and tile it to the other dims of input data
identity2d = jnp.identity(n, dtype=jnp.float32)
tile_shape = list(LU_data.shape)
tile_shape[-1] = 1
tile_shape[-2] = 1
P = jnp.tile(identity2d, tile_shape)
#print("debug: start permutation matrix:", P)

# closure to be called for each input 2D matrix.
def _lu_unpack_2d(p, pivot):
jax.debug.print("unpack2d: {} {} {}", p , pivot, pivot.size)
_pivot = pivot - 1 # pivots are offset by 1 in jax
indices = jnp.array([*range(n)], dtype=jnp.int32)
def update_indices(i, _indices):
#jax.debug.print("fori <<: {} {} {} {}", i, _indices, _pivot, p)
tmp = _indices[i]
_indices = _indices.at[i].set(_indices[_pivot[i]])
_indices = _indices.at[_pivot[i]].set(tmp)
#jax.debug.print("fori >>: {} {} {} {}", i, _indices, _pivot, p)
return _indices
indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices)
#jax.debug.print("indices {}", indices)
p = p[jnp.array(indices)]
p = jnp.transpose(p)
return p

if len(LU_pivots.shape) == 1:
# if we are dealing with a simple 2D input and 1D pivot, call the closure directly
P = _lu_unpack_2d(P, LU_pivots)
else:
# We are dealing with >=3D inputs. Flatten inputs to 3D and use vmap to call the
# closure for each 2D matrix. Finally unflatten the result to match the input data
# shape.

# reshape permutation matrix to 3d
dim_size = jnp.prod(jnp.array(P.shape[:-2]))
newPshape = (dim_size, P.shape[-2], P.shape[-1])
reshapedP = P.reshape(newPshape)

# reshape pivots to 3d
dim_size = jnp.prod(jnp.array(LU_pivots.shape[:-1]))
newPivotshape = (dim_size, LU_pivots.shape[-1])
reshapedPivot = LU_pivots.reshape(newPivotshape)

# vmap the reshaped 3d tensors
v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0,0))
unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot)

# reshape result back to P's shape
newRetshape = (*P.shape[:-2], unpackedP.shape[-2], unpackedP.shape[-1])
#print("newshape: {} {}", newRetshape, unpackedP.shape, dim)
P = unpackedP.reshape(newRetshape)
#print("permutation after: ", P)
else:
# emulate pytroch behavior: return empty tensors
P = torch.empty(torch.Size([0]))

#print("debug output:", P, L, U)
return P, L, U

0 comments on commit 9a292c5

Please sign in to comment.