diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index aa0d84e71dc..f93ae21c997 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -28,12 +28,8 @@
     "linalg.cholesky",
     "linalg.cholesky_ex",
     "linalg.det",
-    "linalg.ldl_factor",
-    "linalg.ldl_factor_ex",
     "linalg.ldl_solve",
     "linalg.lstsq",
-    "linalg.lu_factor",
-    "linalg.lu_factor_ex",
     "linalg.lu_solve",
     "linalg.matrix_norm",
     "linalg.matrix_power",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 5ef39d40cc8..948b86813c0 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2412,6 +2412,30 @@ def _aten_linalg_eigh(A, UPLO='L'):
   return jnp.linalg.eigh(A, UPLO)
 
 
+@op(torch.ops.aten.linalg_ldl_factor_ex)
+def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False):
+  # TODO: Replace with native LDL when available:
+  # https://github.com/jax-ml/jax/issues/12779
+  # TODO: Not tested for complex inputs. Does not support hermitian=True
+  pivots = jnp.broadcast_to(
+      jnp.arange(1, A.shape[-1]+1, dtype=jnp.int32), A.shape[:-1]
+  )
+  info = jnp.zeros(A.shape[:-2], jnp.int32)
+  C = jnp.linalg.cholesky(A)
+  if C.size == 0:
+    return C, pivots, info
+
+  # Fill diagonals of stacked matrices
+  @functools.partial(jnp.vectorize, signature='(k,k),(k,k)->(k,k)')
+  def fill_diagonal_batch(x, y):
+    return jnp.fill_diagonal(x, jnp.diag(y), inplace=False)
+
+  D = C * jnp.eye(C.shape[-1], dtype=A.dtype)
+  LD = C @ jnp.linalg.inv(D)
+  LD = fill_diagonal_batch(LD, D*D)
+  return LD, pivots, info
+
+
 @op(torch.ops.aten.linalg_lu)
 def _aten_linalg_lu(A, pivot=True, out=None):
   dtype = A.dtype
@@ -2442,6 +2466,15 @@ def perm_to_P(perm):
   return P,L,U
 
 
+@op(torch.ops.aten.linalg_lu_factor_ex)
+def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False):
+  lu, pivots, _ = jax.lax.linalg.lu(A)
+  # PyTorch's pivots vector is 1-indexed
+  pivots = pivots + 1
+  info = jnp.zeros(A.shape[:-2], jnp.int32)
+  return lu, pivots, info
+
+
 @op(torch.ops.aten.gcd)
 def _aten_gcd(input, other):
   return jnp.gcd(input, other)
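
For reference, the `_aten_linalg_ldl_factor_ex` lowering above leans on a standard identity: if Cholesky gives `A = C @ C.T` and `Dc = diag(C)`, then `L = C @ inv(Dc)` is unit lower triangular and `A = L @ Dc**2 @ L.T`, so the LDL factors are `L` and `D = Dc**2`. Below is a minimal standalone sketch of that identity for a single real positive-definite matrix (matching the TODO's caveats); `ldl_from_cholesky` is an illustrative name, not part of the patch:

```python
import jax
import jax.numpy as jnp

def ldl_from_cholesky(A):
  # Cholesky: A = C @ C.T. With Dc = diag(C), L = C @ inv(Dc) is unit
  # lower triangular and A = L @ Dc**2 @ L.T, so the LDL "D" is Dc**2.
  C = jnp.linalg.cholesky(A)
  Dc = C * jnp.eye(C.shape[-1], dtype=A.dtype)
  L = C @ jnp.linalg.inv(Dc)
  d = jnp.diagonal(Dc) ** 2
  return L, d

M = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
A = M @ M.T + 4.0 * jnp.eye(4)  # safely positive definite
L, d = ldl_from_cholesky(A)

assert jnp.allclose(L @ jnp.diag(d) @ L.T, A, atol=1e-4)
assert jnp.allclose(jnp.diagonal(L), 1.0)  # unit diagonal, as LDL requires
```

The patch additionally packs `L` and `Dc**2` into PyTorch's compact LD layout (`D` on the diagonal, the strict lower triangle of `L` below it); that is what `fill_diagonal_batch` handles for batched inputs.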
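Likewise, the `pivots + 1` in `_aten_linalg_lu_factor_ex` exists because `jax.lax.linalg.lu` returns 0-based LAPACK-style row-swap indices, while `torch.linalg.lu_factor_ex` reports the same swaps 1-based. A quick sketch of the convention (standalone, not part of the patch):

```python
import jax

A = jax.random.normal(jax.random.PRNGKey(1), (4, 4))
lu, pivots, perm = jax.lax.linalg.lu(A)

# `pivots[i]` is the row swapped with row i (LAPACK ipiv, but 0-based);
# PyTorch would report these same indices + 1.
rows = list(range(4))
for i, p in enumerate(pivots.tolist()):
  rows[i], rows[p] = rows[p], rows[i]

# Replaying the swaps reproduces the permutation JAX also returns.
assert rows == perm.tolist()
```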