diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index aa0d84e71dc..7dffb7020c7 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -38,7 +38,6 @@
     "linalg.matrix_norm",
     "linalg.matrix_power",
     "linalg.tensorsolve",
-    "lu_unpack",
     "masked.median",
     "max_pool2d_with_indices_backward",
     "nn.functional.adaptive_avg_pool3d",
@@ -249,7 +248,7 @@ def test_reference_eager(self, device, dtype, op):
         continue
 
       check_output = op.name not in random_ops
-      #print("[DEBUG] sample_input: ", sample_input)
+      print("[DEBUG] sample_input: ", sample_input)
 
       # TODO: this is a workaround to skip int64 cast for linspace
       # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments
diff --git a/experimental/torch_xla2/test_logspace.py b/experimental/torch_xla2/test_logspace.py
new file mode 100644
index 00000000000..d53b2cdf1bd
--- /dev/null
+++ b/experimental/torch_xla2/test_logspace.py
@@ -0,0 +1,270 @@
+import torch
+import torch_xla2
+import jax.numpy as jnp
+from jax import vmap
+
+env = torch_xla2.default_env()
+#env.config.debug_print_each_op = True
+#env.config.debug_accuracy_for_each_op = True
+
+
+def squeeze():
+  with env:
+    t1 = torch.tensor([-3.5])
+    r1 = t1.squeeze_(-1)
+    print("xla | torch.squeeze :", r1)
+  t1 = torch.tensor([-3.5])
+  r1 = t1.squeeze_(-1)
+  print("native| torch.squeeze :", r1)
+
+
+def nanquantile():
+  with env:
+    t1 = torch.tensor([-7.0, 0.0, torch.nan])
+    r1 = t1.nanquantile(0.5)
+    print("xla | torch.nanquantile(", t1, ") :", r1)
+  t1 = torch.tensor([-7.0, 0.0, torch.nan])
+  r1 = t1.nanquantile(0.5)
+  print("native| torch.nanquantile(", t1, ") :", r1)
+
+
+def empty():
+  with env:
+    #print("xla torch.full: ", torch.full((2,), 3))
+    start1 = torch.tensor(-2)
+    print("xla | torch.tensor(-2) ->", start1)
+    start2 = torch.tensor([-2])
+    print("xla | torch.tensor([-2]) ->", start2)
+    emp1 = torch.empty((1,))
+    ret1 = emp1.copy_(start1)
+    print("xla | torch.empty((1,)).copy_(tensor(-2)) :", ret1)
+    emp2 = torch.empty((1,))
+    ret2 = emp2.copy_(start2)
+    print("xla | torch.empty((1,)).copy_(tensor([-2])):", ret2)
+
+  #print("native torch.full: ", torch.full((2,), 3))
+  start1 = torch.tensor(-2)
+  print("native| torch.tensor(-2) ->", start1)
+  start2 = torch.tensor([-2])
+  print("native| torch.tensor([-2]) ->", start2)
+  emp1 = torch.empty((1,))
+  ret1 = emp1.copy_(start1)
+  print("native| torch.empty((1,)).copy_(tensor(-2)) :", ret1)
+  emp2 = torch.empty((1,))
+  ret2 = emp2.copy_(start2)
+  print("native| torch.empty((1,)).copy_(tensor([-2])):", ret2)
+
+
+def casting():
+  t = torch.tensor([ 4.3000,  4.1510,  4.0020,  3.8531,  3.7041,  3.5551,  3.4061,  3.2571,
+                     3.1082,  2.9592,  2.8102,  2.6612,  2.5122,  2.3633,  2.2143,  2.0653,
+                     1.9163,  1.7673,  1.6184,  1.4694,  1.3204,  1.1714,  1.0224,  0.8735,
+                     0.7245,  0.5755,  0.4265,  0.2776,  0.1286, -0.0204, -0.1694, -0.3184,
+                    -0.4673, -0.6163, -0.7653, -0.9143, -1.0633, -1.2122, -1.3612, -1.5102,
+                    -1.6592, -1.8082, -1.9571, -2.1061, -2.2551, -2.4041, -2.5531, -2.7020,
+                    -2.8510, -3.0000])
+  with env:
+    print("xla |", t.type(torch.int64))
+  print("native|", t.type(torch.int64))
+
+
+def linspace():
+  dtype = torch.int64
+  with env:
+    print("xla | torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
+  print("native| torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
+  return
+  with env:
+    print("xla | torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
+  print("native| torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
+  with env:
+    print("xla | torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))
+  print("native| torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))
+
+
+def logspace():
+  with env:
+    print("xla torch.logspace: ", torch.logspace(start=-10, end=10, steps=5))
+  print("native torch.logspace: ", torch.logspace(start=-10, end=10, steps=5))
+
+
+def log_normal():
+  with env:
+    t = torch.tensor([-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134, -0.1783, 7.1360, -0.7987, 2.3815])
+    print("xla |torch.log_normal: ", t.log_normal_(0, 0.25))
+  t = torch.tensor([-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134, -0.1783, 7.1360, -0.7987, 2.3815])
+  print("native |torch.log_normal: ", t.log_normal_(0, 0.25))
+
+
+def linalg_vector_norm():
+  with env:
+    t = torch.tensor(-0.06738138198852539)
+    print("xla | linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)
+  t = torch.tensor(-0.06738138198852539)
+  print("native| linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)
+
+
+def linalg_tensorsolve():
+  with env:
+    A = torch.tensor([[[-0.0674,  4.8280, -7.4074, -6.6235, -3.4664,  2.4134],
+                       [-0.1783,  7.1360, -0.7987,  2.3815, -2.7199, -1.7691],
+                       [-8.5981, -5.9605, -3.7100,  0.3334,  3.5580,  5.4002]],
+                      [[-6.1015, -3.9192,  3.2690,  7.4735, -1.8522,  6.7348],
+                       [-1.4507,  0.9523,  8.1493, -8.3490, -5.6658, -2.2785],
+                       [-3.5082,  7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
+    B = torch.tensor([[-5.2537,  7.7364,  4.0160],
+                      [ 4.3621,  0.4733, -4.6142]])
+    print("xla | linalg.tensorsolve()", torch.linalg.tensorsolve(A, B))
+  A = torch.tensor([[[-0.0674,  4.8280, -7.4074, -6.6235, -3.4664,  2.4134],
+                     [-0.1783,  7.1360, -0.7987,  2.3815, -2.7199, -1.7691],
+                     [-8.5981, -5.9605, -3.7100,  0.3334,  3.5580,  5.4002]],
+                    [[-6.1015, -3.9192,  3.2690,  7.4735, -1.8522,  6.7348],
+                     [-1.4507,  0.9523,  8.1493, -8.3490, -5.6658, -2.2785],
+                     [-3.5082,  7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
+  B = torch.tensor([[-5.2537,  7.7364,  4.0160],
+                    [ 4.3621,  0.4733, -4.6142]])
+  print("native| linalg.tensorsolve()", torch.linalg.tensorsolve(A, B))
+
+
+def test_lu():
+  A = torch.tensor([[ 0.0437,  0.6733, -0.7089, -0.4736, -0.3145],
+                    [ 0.2206, -0.3749,  0.8442, -0.5197,  0.2332],
+                    [-0.2896, -0.6009, -0.6085, -0.9129, -0.3178]])
+  print("native| lu()", torch.lu(A, pivot=True, get_infos=True))
+  with env:
+    A = torch.tensor([[ 0.0437,  0.6733, -0.7089, -0.4736, -0.3145],
+                      [ 0.2206, -0.3749,  0.8442, -0.5197,  0.2332],
+                      [-0.2896, -0.6009, -0.6085, -0.9129, -0.3178]])
+    print("xla | lu()", torch.lu(A, pivot=True, get_infos=True))
+
+
+def test_lu_solve():
+  b = torch.tensor([[ 2.3815, -2.7199, -1.7691, -8.5981],
+                    [-5.9605, -3.7100,  0.3334,  3.5580],
+                    [ 5.4002, -6.1015, -3.9192,  3.2690]])
+  LU = torch.tensor([[-0.7679, -0.4551,  0.3539],
+                     [ 0.0390,  1.2674,  0.2928],
+                     [-0.0856,  0.2779, -1.2844]])
+  pivots = torch.tensor([2, 3, 3], dtype=torch.int32)
+  print("native| lu_solve()", torch.lu_solve(b, LU, pivots))
+  with env:
+    b = torch.tensor([[ 2.3815, -2.7199, -1.7691, -8.5981],
+                      [-5.9605, -3.7100,  0.3334,  3.5580],
+                      [ 5.4002, -6.1015, -3.9192,  3.2690]])
+    LU = torch.tensor([[-0.7679, -0.4551,  0.3539],
+                       [ 0.0390,  1.2674,  0.2928],
+                       [-0.0856,  0.2779, -1.2844]])
+    pivots = torch.tensor([2, 3, 3], dtype=torch.int32)
+    print("xla | lu_solve()", torch.lu_solve(b, LU, pivots))
+
+
+def test_lu_unpack():
+  unpack_data = True
+  unpack_pivots = True
+  if False:
+    lu = torch.tensor([[-2.7199, -1.7691, -8.5981, -5.9605, -3.7100],
+                       [ 0.0248,  4.8718, -7.1944, -6.4758, -3.3745],
+                       [-0.8873, -0.3588, -3.0746, -8.4111, -2.1212]])
+    pivots = torch.tensor([3, 3, 3], dtype=torch.int32)
+    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
+    with env:
+      lu = torch.tensor([[-2.7199, -1.7691, -8.5981, -5.9605, -3.7100],
+                         [ 0.0248,  4.8718, -7.1944, -6.4758, -3.3745],
+                         [-0.8873, -0.3588, -3.0746, -8.4111, -2.1212]])
+      pivots = torch.tensor([3, 3, 3], dtype=torch.int32)
+      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
+
+  if False:
+    lu = torch.tensor([[-8.3876,  7.9964,  6.8432, -8.9778,  1.6845],
+                       [ 0.8269, -9.9104, -2.1215, 14.8806,  6.4389],
+                       [ 0.1808,  0.2953, -4.7303,  0.6897, -7.5366],
+                       [-0.4855, -0.7570,  0.7641,  9.0972, 16.3916],
+                       [ 0.1354,  0.0746, -0.2784,  0.6465, -4.7616],
+                       [-0.9468, -0.9447,  0.7085,  0.6482,  0.6800]])
+    pivots = torch.tensor([5, 3, 6, 5, 6], dtype=torch.int32)
+    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
+    with env:
+      lu = torch.tensor([[-8.3876,  7.9964,  6.8432, -8.9778,  1.6845],
+                         [ 0.8269, -9.9104, -2.1215, 14.8806,  6.4389],
+                         [ 0.1808,  0.2953, -4.7303,  0.6897, -7.5366],
+                         [-0.4855, -0.7570,  0.7641,  9.0972, 16.3916],
+                         [ 0.1354,  0.0746, -0.2784,  0.6465, -4.7616],
+                         [-0.9468, -0.9447,  0.7085,  0.6482,  0.6800]])
+      pivots = torch.tensor([5, 3, 6, 5, 6], dtype=torch.int32)
+      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
+
+  if True:
+    lu = torch.tensor([[[ -5.3344,  -2.2530,  -4.3840,  -3.1485,  -7.3766],
+                        [  0.3589,   2.7324,  -4.2898,   0.6681,   9.0900],
+                        [  0.1734,   0.2346,   0.9901,   2.2108,   4.8699]],
+
+                       [[  8.5252,   5.7155,   8.5447,  -0.6509,  -8.0849],
+                        [ -0.5005,   8.9886,   4.2181,  -4.7992, -10.9431],
+                        [ -0.9880,  -0.2169,   7.5312,   3.2518,  -5.4951]],
+
+                       [[ -8.6799,   5.6140,  -7.0426,  -1.9027,  -3.6493],
+                        [ -0.0134,  -4.0132,   3.2959,  -8.1260,  -0.6563],
+                        [  0.1997,   0.7197,  -9.0417,  -1.5426,  -0.2071]]])
+    pivots = torch.tensor([[1, 2, 3],
+                           [1, 2, 3],
+                           [1, 3, 3]], dtype=torch.int32)
+    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
+    with env:
+      lu = torch.tensor([[[ -5.3344,  -2.2530,  -4.3840,  -3.1485,  -7.3766],
+                          [  0.3589,   2.7324,  -4.2898,   0.6681,   9.0900],
+                          [  0.1734,   0.2346,   0.9901,   2.2108,   4.8699]],
+
+                         [[  8.5252,   5.7155,   8.5447,  -0.6509,  -8.0849],
+                          [ -0.5005,   8.9886,   4.2181,  -4.7992, -10.9431],
+                          [ -0.9880,  -0.2169,   7.5312,   3.2518,  -5.4951]],
+
+                         [[ -8.6799,   5.6140,  -7.0426,  -1.9027,  -3.6493],
+                          [ -0.0134,  -4.0132,   3.2959,  -8.1260,  -0.6563],
+                          [  0.1997,   0.7197,  -9.0417,  -1.5426,  -0.2071]]])
+      pivots = torch.tensor([[1, 2, 3],
+                             [1, 2, 3],
+                             [1, 3, 3]], dtype=torch.int32)
+      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
+
+
+def pivot_to_permutation():
+  n = 3
+  with env:
+    P = torch.tensor([
+        [[1., 0., 0.],
+         [0., 1., 0.],
+         [0., 0., 1.]],
+        [[1., 0., 0.],
+         [0., 1., 0.],
+         [0., 0., 1.]],
+        [[1., 0., 0.],
+         [0., 1., 0.],
+         [0., 0., 1.]],
+    ])
+    print("debug: start permutation matrix:", P)
+    pivots = torch.tensor([[1, 2, 3],
+                           [1, 2, 3],
+                           [1, 3, 3]], dtype=torch.int32)
+    pivot_size = pivots.shape[-1]
+
+    def _lu_unpack_2d(p, pivot):
+      #print("treearg:", tree_arg)
+      #p, pivot = tree_arg
+      print("args:", p, pivot)
+      _pivot = pivot - 1  # pivots are offset by 1 in jax
+      indices = [*range(n)]
+      for i in range(_pivot.size):
+        indices[i], indices[_pivot[i]] = indices[_pivot[i]], indices[i]
+        #print("[debug]: i, pivot[i], indices:", i, _pivot[i], indices)
+      p = p[jnp.array(indices)]
+      p = jnp.transpose(p)
+      #print("permutation:", p)
+      return p
+
+    tree = (P, pivots)
+    v_lu_unpack_2d = vmap(_lu_unpack_2d, in_axes=(0, 0))
+    ret = v_lu_unpack_2d(P, pivots)
+    return ret
+
+
+#nanquantile()
+#squeeze()
+#linspace()
+#casting()
+#log_normal()
+#linalg_vector_norm()
+#linalg_tensorsolve()
+#test_lu()
+#test_lu_solve()
+#test_lu_unpack()
+pivot_to_permutation()
\ No newline at end of file
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 5ef39d40cc8..62fdc9eb6b5 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -4560,3 +4560,100 @@ def _euclidean_direct(x1, x2):
     dist = jnp.sqrt(dist_sq).astype(jnp.float32)
 
     return dist
+
+@op(torch.ops.aten.lu_unpack)
+def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
+  n = LU_data.shape[-2]
+  m = LU_data.shape[-1]
+  dim = min(n, m)
+  #print("input: lu.shape ", LU_data.shape, LU_pivots.shape)
+  if unpack_data:
+    L = jnp.tril(LU_data, k=-1)  # Extract the strict lower triangle
+    #print("lower:", L)
+    eye = jnp.eye(n, m, dtype=LU_data.dtype)
+    #print("eye:", eye)
+    L = L + eye  # Add ones to the diagonal of L
+    #print("lower+eye:", L)
+    start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
+    limit_indices = list(LU_data.shape)
+    limit_indices[-1] = dim
+    #print("indices: ", start_indices, limit_indices)
+    L = jax.lax.slice(L, start_indices, limit_indices)  # Trim L to n x min(n, m)
+
+    U = jnp.triu(LU_data)  # Extract the upper triangle
+    start_indices = jnp.zeros(len(LU_data.shape), dtype=int)
+    limit_indices = list(LU_data.shape)
+    limit_indices[-2] = dim
+    U = jax.lax.slice(U, start_indices, limit_indices)  # Trim U to min(n, m) x m
+    #print("upper:", U)
+  else:
+    L = torch.empty(torch.Size([0]))
+    U = torch.empty(torch.Size([0]))
+
+  if unpack_pivots:
+    if True:
+      pivots = LU_pivots - 1  # pivots are offset by 1 in jax
+      pivot_size = pivots.shape[-1]
+      #? m = pivot_size  # In LU decomposition, k = min(m, n), and here we assume m <= n
+
+      # use an identity to create permutation matrix
+      # tile the 2D identity matrix in other dims
+      leading_dims = LU_data.shape[:-2]
+      tile_shape = (*leading_dims, 1, 1)
+      identity = jnp.eye(n, dtype=jnp.float32)
+      identity = jnp.expand_dims(identity, axis=tuple(range(len(leading_dims))))
+      P = jnp.tile(identity, tile_shape)
+      print("debug: start permutation matrix:", P)
+
+      # Apply the swaps iteratively
+      for i in range(pivot_size):
+        # Get the swap indices
+        row_idx = jnp.arange(n)  #? row_idx = jnp.arange(m) or (pivot_size)
+        col_idx = pivots[..., i]
+        print("r c :", row_idx, col_idx, i)
+
+        #print("debug: (1, 1):", P[..., 1, 1])
+        #print("debug: (row_idx, i):", P[..., row_idx, i])
+        #print("debug: (0, col_idx):", P[..., 0, col_idx])
+        #print("debug: (row_idx, col_idx):", P[..., row_idx, col_idx])
+        #print("debug: (row_idx, i):", P[..., row_idx, i])
+        #return
+
+        # Swap the columns in the permutation matrix
+        pi = P[..., i, row_idx]
+        pp = jnp.zeros(pi.shape, dtype=pi.dtype)
+        for j in range(len(col_idx)):
+          pp_row = P[..., col_idx[j], row_idx][j]
+          print("debug: pp_row:", j, pp_row)
+          pp = pp.at[j].set(pp_row)
+        print("debug: (i, row_idx):", pi)
+        print("debug: (row_idx, col_idx):", pp)
+        P = P.at[..., i, row_idx].set(pp)
+        print("debug: p1:", P)
+        P = P.at[..., col_idx, row_idx].set(pi)
+        print("debug: p2:", P)
+    else:
+      _pivots = LU_pivots - 1  # pivots are offset by 1 in jax
+
+      # use an identity to create permutation matrix
+      # tile the 2D identity matrix in other dims
+      tile_shape = list(LU_data.shape)
+      tile_shape[-1] = 1
+      tile_shape[-2] = 1
+      P = jnp.tile(jnp.identity(n, dtype=jnp.float32), tile_shape)
+      #print("debug: start permutation matrix:", P)
+
+      indices = [*range(n)]
+      for i in range(_pivots.size):
+        indices[i], indices[_pivots[i]] = indices[_pivots[i]], indices[i]
+        #print("[debug]: i, pivot[i], indices:", i, _pivots[i], indices)
+      P = P[jnp.array(indices)]
+      P = jnp.transpose(P)
+      #print("permutation:", P)
+  else:
+    P = torch.empty(torch.Size([0]))
+  print("permutation:", P)
+
+  #print("debug output:", P, L, U)
+  return P, L, U
\ No newline at end of file
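
For reference, the unpack_data branch above is the standard triangle split of a packed LU factor: strict lower triangle plus a unit diagonal for L, upper triangle for U, each trimmed to min(n, m). A minimal standalone sketch of that technique (illustrative only, not part of the patch; stand-in values, and plain slicing where the patch uses jax.lax.slice):

import jax.numpy as jnp

LU = jnp.arange(12.0).reshape(3, 4)  # stand-in 3x4 packed LU factor
k = min(LU.shape[-2], LU.shape[-1])

# L: strict lower triangle plus a unit diagonal, trimmed to (3, k)
L = (jnp.tril(LU, k=-1) + jnp.eye(LU.shape[-2], LU.shape[-1]))[..., :, :k]
# U: upper triangle including the diagonal, trimmed to (k, 4)
U = jnp.triu(LU)[..., :k, :]
print(L.shape, U.shape)  # (3, 3) (3, 4)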
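
The pivot handling rests on the LAPACK convention that LU_pivots is 1-based and records sequential row swaps: during factorization, row i was exchanged with row pivots[i]-1. A minimal sketch of how that replay yields the P with A == P @ L @ U (the torch.lu_unpack convention, matching the patch's else-branch; pivots_to_perm_matrix is a hypothetical helper name):

import jax.numpy as jnp

def pivots_to_perm_matrix(pivots, n):
  # Replay the recorded swaps on [0, 1, ..., n-1] to get the row permutation.
  indices = list(range(n))
  for i, p in enumerate(int(x) - 1 for x in pivots):  # 1-based -> 0-based
    indices[i], indices[p] = indices[p], indices[i]
  # I[indices] @ A row-permutes A into L @ U; transposing inverts that
  # permutation, so A == P @ L @ U.
  return jnp.eye(n)[jnp.array(indices)].T

print(pivots_to_perm_matrix([2, 3, 3], 3))  # pivots from the test_lu_solve case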
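
For batched pivots (the if True: case in test_lu_unpack), the same swap replay can be lifted over the batch dimension with vmap, provided the per-element function stays traceable: Python-list swaps on traced indices do not trace, but functional .at[...].set updates with a static trip count do. A sketch under those assumptions, for the square case where the pivot length equals n (perm_matrix_traceable is a hypothetical helper, not part of the patch):

import jax.numpy as jnp
from jax import vmap

def perm_matrix_traceable(pivots_row):
  n = pivots_row.shape[0]
  P = jnp.eye(n)
  for i in range(n):         # static trip count, unrolled at trace time
    p = pivots_row[i] - 1    # 1-based -> 0-based; a traced scalar under vmap
    row_i, row_p = P[i], P[p]
    P = P.at[i].set(row_p)   # swap rows i and p via functional updates
    P = P.at[p].set(row_i)
  return P.T                 # transpose so that A == P @ L @ U

pivots = jnp.array([[1, 2, 3],
                    [1, 3, 3]], dtype=jnp.int32)
print(vmap(perm_matrix_traceable)(pivots))  # shape (2, 3, 3)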