WIP: lu_unpack
barney-s committed Oct 15, 2024
1 parent f71c02d commit e48001e
Showing 3 changed files with 447 additions and 2 deletions.
3 changes: 1 addition & 2 deletions experimental/torch_xla2/test/test_ops.py
@@ -38,7 +38,6 @@
     "linalg.matrix_norm",
     "linalg.matrix_power",
     "linalg.tensorsolve",
-    "lu_unpack",
     "masked.median",
     "max_pool2d_with_indices_backward",
     "nn.functional.adaptive_avg_pool3d",
@@ -249,7 +248,7 @@ def test_reference_eager(self, device, dtype, op):
         continue
       check_output = op.name not in random_ops

-      #print("[DEBUG] sample_input: ", sample_input)
+      print("[DEBUG] sample_input: ", sample_input)

       # TODO: this is a workaround to skip int64 cast for linspace
       # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments
266 changes: 266 additions & 0 deletions experimental/torch_xla2/test_logspace.py
@@ -0,0 +1,266 @@
import torch
import torch_xla2
import jax.numpy as jnp

env = torch_xla2.default_env()
#env.config.debug_print_each_op = True
#env.config.debug_accuracy_for_each_op = True
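
# Ops run under `with env:` go through torch_xla2 (the JAX backend), while the
# same calls outside the block use native PyTorch; each helper below prints an
# "xla" / "native" pair so the two backends can be compared by eye.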


def squeeze():
  with env:
    t1 = torch.tensor([-3.5])
    r1 = t1.squeeze_(-1)
    print("xla | torch.squeeze :", r1)
  t1 = torch.tensor([-3.5])
  r1 = t1.squeeze_(-1)
  print("native| torch.squeeze :", r1)

def nanquantile():
  with env:
    t1 = torch.tensor([-7.0, 0.0, torch.nan])
    r1 = t1.nanquantile(0.5)
    print("xla | torch.nanquantile(", t1, ") :", r1)
  t1 = torch.tensor([-7.0, 0.0, torch.nan])
  r1 = t1.nanquantile(0.5)
  print("native| torch.nanquantile(", t1, ") :", r1)

def empty():
  with env:
    #print("xla torch.full: ", torch.full((2,), 3))
    start1 = torch.tensor(-2)
    print("xla | torch.tensor(-2) ->", start1)
    start2 = torch.tensor([-2])
    print("xla | torch.tensor([-2]) ->", start2)
    emp1 = torch.empty((1,))
    ret1 = emp1.copy_(start1)
    print("xla | torch.empty((1,)).copy_(tensor(-2)) :", ret1)
    emp2 = torch.empty((1,))
    ret2 = emp2.copy_(start2)
    print("xla | torch.empty((1,)).copy_(tensor([-2])):", ret2)

  #print("native torch.full: ", torch.full((2,), 3))
  start1 = torch.tensor(-2)
  print("native| torch.tensor(-2) ->", start1)
  start2 = torch.tensor([-2])
  print("native| torch.tensor([-2]) ->", start2)
  emp1 = torch.empty((1,))
  ret1 = emp1.copy_(start1)
  print("native| torch.empty((1,)).copy_(tensor(-2)) :", ret1)
  emp2 = torch.empty((1,))
  ret2 = emp2.copy_(start2)
  print("native| torch.empty((1,)).copy_(tensor([-2])):", ret2)

def casting():
  t = torch.tensor([ 4.3000,  4.1510,  4.0020,  3.8531,  3.7041,  3.5551,  3.4061,  3.2571,
                     3.1082,  2.9592,  2.8102,  2.6612,  2.5122,  2.3633,  2.2143,  2.0653,
                     1.9163,  1.7673,  1.6184,  1.4694,  1.3204,  1.1714,  1.0224,  0.8735,
                     0.7245,  0.5755,  0.4265,  0.2776,  0.1286, -0.0204, -0.1694, -0.3184,
                    -0.4673, -0.6163, -0.7653, -0.9143, -1.0633, -1.2122, -1.3612, -1.5102,
                    -1.6592, -1.8082, -1.9571, -2.1061, -2.2551, -2.4041, -2.5531, -2.7020,
                    -2.8510, -3.0000])
  with env:
    print("xla |", t.type(torch.int64))
  print("native|", t.type(torch.int64))

def linspace():
  dtype = torch.int64
  with env:
    print("xla | torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
  return  # early exit; the cases below are kept for reference
  with env:
    print("xla | torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
  with env:
    print("xla | torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))

def logspace():
  with env:
    print("xla torch.logspace: ", torch.logspace(start=-10, end=10, steps=5))
  print("native torch.logspace: ", torch.logspace(start=-10, end=10, steps=5))

def log_normal():
  with env:
    t = torch.tensor([-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134, -0.1783, 7.1360, -0.7987, 2.3815])
    print("xla |torch.log_normal: ", t.log_normal_(0, 0.25))
  t = torch.tensor([-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134, -0.1783, 7.1360, -0.7987, 2.3815])
  print("native |torch.log_normal: ", t.log_normal_(0, 0.25))

def linalg_vector_norm():
  with env:
    t = torch.tensor(-0.06738138198852539)
    print("xla | linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)
  t = torch.tensor(-0.06738138198852539)
  print("native| linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)

def linalg_tensorsolve():
  with env:
    A = torch.tensor([[[-0.0674,  4.8280, -7.4074, -6.6235, -3.4664,  2.4134],
                       [-0.1783,  7.1360, -0.7987,  2.3815, -2.7199, -1.7691],
                       [-8.5981, -5.9605, -3.7100,  0.3334,  3.5580,  5.4002]],
                      [[-6.1015, -3.9192,  3.2690,  7.4735, -1.8522,  6.7348],
                       [-1.4507,  0.9523,  8.1493, -8.3490, -5.6658, -2.2785],
                       [-3.5082,  7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
    B = torch.tensor([[-5.2537,  7.7364,  4.0160],
                      [ 4.3621,  0.4733, -4.6142]])
    print("xla | linalg.tensorsolve()", torch.linalg.tensorsolve(A, B))
  A = torch.tensor([[[-0.0674,  4.8280, -7.4074, -6.6235, -3.4664,  2.4134],
                     [-0.1783,  7.1360, -0.7987,  2.3815, -2.7199, -1.7691],
                     [-8.5981, -5.9605, -3.7100,  0.3334,  3.5580,  5.4002]],
                    [[-6.1015, -3.9192,  3.2690,  7.4735, -1.8522,  6.7348],
                     [-1.4507,  0.9523,  8.1493, -8.3490, -5.6658, -2.2785],
                     [-3.5082,  7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
  B = torch.tensor([[-5.2537,  7.7364,  4.0160],
                    [ 4.3621,  0.4733, -4.6142]])
  print("native| linalg.tensorsolve()", torch.linalg.tensorsolve(A, B))

def test_lu():
  A = torch.tensor([[ 0.0437,  0.6733, -0.7089, -0.4736, -0.3145],
                    [ 0.2206, -0.3749,  0.8442, -0.5197,  0.2332],
                    [-0.2896, -0.6009, -0.6085, -0.9129, -0.3178]])
  print("native| lu()", torch.lu(A, pivot=True, get_infos=True))
  with env:
    A = torch.tensor([[ 0.0437,  0.6733, -0.7089, -0.4736, -0.3145],
                      [ 0.2206, -0.3749,  0.8442, -0.5197,  0.2332],
                      [-0.2896, -0.6009, -0.6085, -0.9129, -0.3178]])
    print("xla | lu()", torch.lu(A, pivot=True, get_infos=True))

def test_lu_solve():
  b = torch.tensor([[ 2.3815, -2.7199, -1.7691, -8.5981],
                    [-5.9605, -3.7100,  0.3334,  3.5580],
                    [ 5.4002, -6.1015, -3.9192,  3.2690]])
  LU = torch.tensor([[-0.7679, -0.4551,  0.3539],
                     [ 0.0390,  1.2674,  0.2928],
                     [-0.0856,  0.2779, -1.2844]])
  pivots = torch.tensor([2, 3, 3], dtype=torch.int32)
  print("native| lu_solve()", torch.lu_solve(b, LU, pivots))
  with env:
    b = torch.tensor([[ 2.3815, -2.7199, -1.7691, -8.5981],
                      [-5.9605, -3.7100,  0.3334,  3.5580],
                      [ 5.4002, -6.1015, -3.9192,  3.2690]])
    LU = torch.tensor([[-0.7679, -0.4551,  0.3539],
                       [ 0.0390,  1.2674,  0.2928],
                       [-0.0856,  0.2779, -1.2844]])
    pivots = torch.tensor([2, 3, 3], dtype=torch.int32)
    print("xla | lu_solve()", torch.lu_solve(b, LU, pivots))

def test_lu_unpack():
  unpack_data = True
  unpack_pivots = True
  if False:
    lu = torch.tensor([[-2.7199, -1.7691, -8.5981, -5.9605, -3.7100],
                       [ 0.0248,  4.8718, -7.1944, -6.4758, -3.3745],
                       [-0.8873, -0.3588, -3.0746, -8.4111, -2.1212]])
    pivots = torch.tensor([3, 3, 3], dtype=torch.int32)
    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
    with env:
      lu = torch.tensor([[-2.7199, -1.7691, -8.5981, -5.9605, -3.7100],
                         [ 0.0248,  4.8718, -7.1944, -6.4758, -3.3745],
                         [-0.8873, -0.3588, -3.0746, -8.4111, -2.1212]])
      pivots = torch.tensor([3, 3, 3], dtype=torch.int32)
      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))

  if False:
    lu = torch.tensor([[-8.3876,  7.9964,  6.8432, -8.9778,  1.6845],
                       [ 0.8269, -9.9104, -2.1215, 14.8806,  6.4389],
                       [ 0.1808,  0.2953, -4.7303,  0.6897, -7.5366],
                       [-0.4855, -0.7570,  0.7641,  9.0972, 16.3916],
                       [ 0.1354,  0.0746, -0.2784,  0.6465, -4.7616],
                       [-0.9468, -0.9447,  0.7085,  0.6482,  0.6800]])
    pivots = torch.tensor([5, 3, 6, 5, 6], dtype=torch.int32)
    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
    with env:
      lu = torch.tensor([[-8.3876,  7.9964,  6.8432, -8.9778,  1.6845],
                         [ 0.8269, -9.9104, -2.1215, 14.8806,  6.4389],
                         [ 0.1808,  0.2953, -4.7303,  0.6897, -7.5366],
                         [-0.4855, -0.7570,  0.7641,  9.0972, 16.3916],
                         [ 0.1354,  0.0746, -0.2784,  0.6465, -4.7616],
                         [-0.9468, -0.9447,  0.7085,  0.6482,  0.6800]])
      pivots = torch.tensor([5, 3, 6, 5, 6], dtype=torch.int32)
      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))

  if True:
    lu = torch.tensor([[[ -5.3344,  -2.2530,  -4.3840,  -3.1485,  -7.3766],
                        [  0.3589,   2.7324,  -4.2898,   0.6681,   9.0900],
                        [  0.1734,   0.2346,   0.9901,   2.2108,   4.8699]],

                       [[  8.5252,   5.7155,   8.5447,  -0.6509,  -8.0849],
                        [ -0.5005,   8.9886,   4.2181,  -4.7992, -10.9431],
                        [ -0.9880,  -0.2169,   7.5312,   3.2518,  -5.4951]],

                       [[ -8.6799,   5.6140,  -7.0426,  -1.9027,  -3.6493],
                        [ -0.0134,  -4.0132,   3.2959,  -8.1260,  -0.6563],
                        [  0.1997,   0.7197,  -9.0417,  -1.5426,  -0.2071]]])
    pivots = torch.tensor([[1, 2, 3],
                           [1, 2, 3],
                           [1, 3, 3]], dtype=torch.int32)
    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
    with env:
      lu = torch.tensor([[[ -5.3344,  -2.2530,  -4.3840,  -3.1485,  -7.3766],
                          [  0.3589,   2.7324,  -4.2898,   0.6681,   9.0900],
                          [  0.1734,   0.2346,   0.9901,   2.2108,   4.8699]],

                         [[  8.5252,   5.7155,   8.5447,  -0.6509,  -8.0849],
                          [ -0.5005,   8.9886,   4.2181,  -4.7992, -10.9431],
                          [ -0.9880,  -0.2169,   7.5312,   3.2518,  -5.4951]],

                         [[ -8.6799,   5.6140,  -7.0426,  -1.9027,  -3.6493],
                          [ -0.0134,  -4.0132,   3.2959,  -8.1260,  -0.6563],
                          [  0.1997,   0.7197,  -9.0417,  -1.5426,  -0.2071]]])
      pivots = torch.tensor([[1, 2, 3],
                             [1, 2, 3],
                             [1, 3, 3]], dtype=torch.int32)
      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))


def pivot_to_permutation():
  with env:
    P = torch.tensor([
        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],
        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],
        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],
    ])
    print("debug: start permutation matrix:", P)
    pivots = torch.tensor([[1, 2, 3],
                           [1, 2, 3],
                           [1, 3, 3]], dtype=torch.int32)
    pivot_size = pivots.shape[-1]

    row_idx = jnp.arange(3)  #? row_idx = jnp.arange(m) or (pivot_size)
    i = 1
    col_idx = pivots[..., i]
    indices = torch.tensor(row_idx)
    #c1 = P.gather(2, indices)  # fails: gather needs an index with the same ndim as P
    c1 = P.index_select(2, indices)
    print("debug:", c1)
    return

    # Apply the swaps iteratively (unreachable while the early return above is in place)
    for i in range(pivot_size):
      # Get the swap indices
      row_idx = jnp.arange(3)  #? row_idx = jnp.arange(m) or (pivot_size)
      col_idx = pivots[..., i]
      print("r c :", row_idx, col_idx, i)
      indices = torch.tensor([[0, 0]])
      c1 = P.gather(2, indices)
      print("debug:", c1)

#nanquantile()
#squeeze()
#linspace()
#casting()
#log_normal()
#linalg_vector_norm()
#linalg_tensorsolve()
#test_lu()
#test_lu_solve()
#test_lu_unpack()
pivot_to_permutation()