diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index aa0d84e71dc..3db89611eea 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -38,7 +38,6 @@ "linalg.matrix_norm", "linalg.matrix_power", "linalg.tensorsolve", - "lu_unpack", "masked.median", "max_pool2d_with_indices_backward", "nn.functional.adaptive_avg_pool3d", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 5ef39d40cc8..f4b7882e8cd 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4560,3 +4560,105 @@ def _euclidean_direct(x1, x2): dist = jnp.sqrt(dist_sq).astype(jnp.float32) return dist + +@op(torch.ops.aten.lu_unpack) +def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): + # lu_unpack doesnt exist in jax. + # Get commonly used data shape variables + n = LU_data.shape[-2] + m = LU_data.shape[-1] + dim = min(n,m) + + ### Compute the Lower and Upper triangle + if unpack_data: + # Extract lower triangle + L = jnp.tril(LU_data, k=-1) + + #emulate pytorch behavior: Add ones to the diagonal of L + eye = jnp.eye(n, m, dtype=LU_data.dtype) + L = L + eye + + # emulate pytorch behavior: Reshape lower triangle to match pivot + start_indices = jnp.zeros(len(LU_data.shape), dtype=int) + limit_indices = list(LU_data.shape) + limit_indices[-1] = dim + L = jax.lax.slice(L, start_indices, limit_indices) + + # Extract upper triangle + U = jnp.triu(LU_data) + + # emulate pytorch behavior: Reshape upper triangle to match pivot + start_indices = jnp.zeros(len(LU_data.shape), dtype=int) + limit_indices = list(LU_data.shape) + limit_indices[-2] = dim + U = jax.lax.slice(U, start_indices, limit_indices) + else: + # emulate pytroch behavior: return empty tensors + L = torch.empty(torch.Size([0])) + U = torch.empty(torch.Size([0])) + + ### Compute the Permutation matrix + if unpack_pivots: + # We should return a permutation matrix (2D) for each pivot array (1D) + # The shape of the final Permutation matrix depends on the shape of the input + # data and the pivots + + # start with a 2D identity matrix and tile it to the other dims of input data + identity2d = jnp.identity(n, dtype=jnp.float32) + tile_shape = list(LU_data.shape) + tile_shape[-1] = 1 + tile_shape[-2] = 1 + P = jnp.tile(identity2d, tile_shape) + #print("debug: start permutation matrix:", P) + + # closure to be called for each input 2D matrix. + def _lu_unpack_2d(p, pivot): + jax.debug.print("unpack2d: {} {} {}", p , pivot, pivot.size) + _pivot = pivot - 1 # pivots are offset by 1 in jax + indices = jnp.array([*range(n)], dtype=jnp.int32) + def update_indices(i, _indices): + #jax.debug.print("fori <<: {} {} {} {}", i, _indices, _pivot, p) + tmp = _indices[i] + _indices = _indices.at[i].set(_indices[_pivot[i]]) + _indices = _indices.at[_pivot[i]].set(tmp) + #jax.debug.print("fori >>: {} {} {} {}", i, _indices, _pivot, p) + return _indices + indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices) + #jax.debug.print("indices {}", indices) + p = p[jnp.array(indices)] + p = jnp.transpose(p) + return p + + if len(LU_pivots.shape) == 1: + # if we are dealing with a simple 2D input and 1D pivot, call the closure directly + P = _lu_unpack_2d(P, LU_pivots) + else: + # We are dealing with >=3D inputs. Flatten inputs to 3D and use vmap to call the + # closure for each 2D matrix. Finally unflatten the result to match the input data + # shape. + + # reshape permutation matrix to 3d + dim_size = jnp.prod(jnp.array(P.shape[:-2])) + newPshape = (dim_size, P.shape[-2], P.shape[-1]) + reshapedP = P.reshape(newPshape) + + # reshape pivots to 3d + dim_size = jnp.prod(jnp.array(LU_pivots.shape[:-1])) + newPivotshape = (dim_size, LU_pivots.shape[-1]) + reshapedPivot = LU_pivots.reshape(newPivotshape) + + # vmap the reshaped 3d tensors + v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0,0)) + unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot) + + # reshape result back to P's shape + newRetshape = (*P.shape[:-2], unpackedP.shape[-2], unpackedP.shape[-1]) + #print("newshape: {} {}", newRetshape, unpackedP.shape, dim) + P = unpackedP.reshape(newRetshape) + #print("permutation after: ", P) + else: + # emulate pytroch behavior: return empty tensors + P = torch.empty(torch.Size([0])) + + #print("debug output:", P, L, U) + return P, L, U \ No newline at end of file