WIP: lu_unpack
barney-s committed Oct 16, 2024
1 parent f71c02d commit 6eb8218
Showing 3 changed files with 368 additions and 2 deletions.
3 changes: 1 addition & 2 deletions experimental/torch_xla2/test/test_ops.py
@@ -38,7 +38,6 @@
"linalg.matrix_norm",
"linalg.matrix_power",
"linalg.tensorsolve",
"lu_unpack",
"masked.median",
"max_pool2d_with_indices_backward",
"nn.functional.adaptive_avg_pool3d",
@@ -249,7 +248,7 @@ def test_reference_eager(self, device, dtype, op):
 continue
 check_output = op.name not in random_ops

-#print("[DEBUG] sample_input: ", sample_input)
+print("[DEBUG] sample_input: ", sample_input)

 # TODO: this is a workaround to skip int64 cast for linspace
 # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments
270 changes: 270 additions & 0 deletions experimental/torch_xla2/test_logspace.py
@@ -0,0 +1,270 @@
import torch
import torch_xla2
import jax.numpy as jnp
from jax import vmap

env = torch_xla2.default_env()
#env.config.debug_print_each_op = True
#env.config.debug_accuracy_for_each_op = True
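# Ops run under `with env:` execute through the torch_xla2 (jax) backend and
# are labeled "xla" below; the same ops outside the block run on native torch
# and are labeled "native", so each pair of printouts can be compared.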


def squeeze():
  with env:
    t1 = torch.tensor([-3.5])
    r1 = t1.squeeze_(-1)
    print("xla | torch.squeeze :", r1)
  t1 = torch.tensor([-3.5])
  r1 = t1.squeeze_(-1)
  print("native| torch.squeeze :", r1)

def nanquantile():
  with env:
    t1 = torch.tensor([-7.0, 0.0, torch.nan])
    r1 = t1.nanquantile(0.5)
    print("xla | torch.nanquantile(", t1, ") :", r1)
  t1 = torch.tensor([-7.0, 0.0, torch.nan])
  r1 = t1.nanquantile(0.5)
  print("native| torch.nanquantile(", t1, ") :", r1)

def empty():
  with env:
    #print("xla torch.full: ", torch.full((2,), 3))
    start1 = torch.tensor(-2)
    print("xla | torch.tensor(-2) ->", start1)
    start2 = torch.tensor([-2])
    print("xla | torch.tensor([-2]) ->", start2)
    emp1 = torch.empty((1,))
    ret1 = emp1.copy_(start1)
    print("xla | torch.empty((1,)).copy_(tensor(-2)) :", ret1)
    emp2 = torch.empty((1,))
    ret2 = emp2.copy_(start2)
    print("xla | torch.empty((1,)).copy_(tensor([-2])):", ret2)

  #print("native torch.full: ", torch.full((2,), 3))
  start1 = torch.tensor(-2)
  print("native| torch.tensor(-2) ->", start1)
  start2 = torch.tensor([-2])
  print("native| torch.tensor([-2]) ->", start2)
  emp1 = torch.empty((1,))
  ret1 = emp1.copy_(start1)
  print("native| torch.empty((1,)).copy_(tensor(-2)) :", ret1)
  emp2 = torch.empty((1,))
  ret2 = emp2.copy_(start2)
  print("native| torch.empty((1,)).copy_(tensor([-2])):", ret2)

def casting():
  t = torch.tensor([ 4.3000, 4.1510, 4.0020, 3.8531, 3.7041, 3.5551, 3.4061, 3.2571,
                     3.1082, 2.9592, 2.8102, 2.6612, 2.5122, 2.3633, 2.2143, 2.0653,
                     1.9163, 1.7673, 1.6184, 1.4694, 1.3204, 1.1714, 1.0224, 0.8735,
                     0.7245, 0.5755, 0.4265, 0.2776, 0.1286, -0.0204, -0.1694, -0.3184,
                     -0.4673, -0.6163, -0.7653, -0.9143, -1.0633, -1.2122, -1.3612, -1.5102,
                     -1.6592, -1.8082, -1.9571, -2.1061, -2.2551, -2.4041, -2.5531, -2.7020,
                     -2.8510, -3.0000])
  with env:
    print("xla |", t.type(torch.int64))
  print("native|", t.type(torch.int64))

def linspace():
  dtype = torch.int64
  with env:
    print("xla | torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
  return  # WIP: the remaining cases below are disabled for now
  with env:
    print("xla | torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
  with env:
    print("xla | torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))

def logspace():
  with env:
    print("xla torch.logspace: ", torch.logspace(start=-10, end=10, steps=5))
  print("native torch.logspace: ", torch.logspace(start=-10, end=10, steps=5))

def log_normal():
  with env:
    t = torch.tensor([-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134, -0.1783, 7.1360, -0.7987, 2.3815])
    print("xla |torch.log_normal: ", t.log_normal_(0, 0.25))
  t = torch.tensor([-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134, -0.1783, 7.1360, -0.7987, 2.3815])
  print("native |torch.log_normal: ", t.log_normal_(0, 0.25))

def linalg_vector_norm():
  with env:
    t = torch.tensor(-0.06738138198852539)
    print("xla | linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)
  t = torch.tensor(-0.06738138198852539)
  print("native| linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)

def linalg_tensorsolve():
  with env:
    A = torch.tensor([[[-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134],
                       [-0.1783, 7.1360, -0.7987, 2.3815, -2.7199, -1.7691],
                       [-8.5981, -5.9605, -3.7100, 0.3334, 3.5580, 5.4002]],
                      [[-6.1015, -3.9192, 3.2690, 7.4735, -1.8522, 6.7348],
                       [-1.4507, 0.9523, 8.1493, -8.3490, -5.6658, -2.2785],
                       [-3.5082, 7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
    B = torch.tensor([[-5.2537, 7.7364, 4.0160],
                      [ 4.3621, 0.4733, -4.6142]])
    print("xla | linalg.tensorsolve()", torch.linalg.tensorsolve(A, B))
  A = torch.tensor([[[-0.0674, 4.8280, -7.4074, -6.6235, -3.4664, 2.4134],
                     [-0.1783, 7.1360, -0.7987, 2.3815, -2.7199, -1.7691],
                     [-8.5981, -5.9605, -3.7100, 0.3334, 3.5580, 5.4002]],
                    [[-6.1015, -3.9192, 3.2690, 7.4735, -1.8522, 6.7348],
                     [-1.4507, 0.9523, 8.1493, -8.3490, -5.6658, -2.2785],
                     [-3.5082, 7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
  B = torch.tensor([[-5.2537, 7.7364, 4.0160],
                    [ 4.3621, 0.4733, -4.6142]])
  print("native| linalg.tensorsolve()", torch.linalg.tensorsolve(A, B))

def test_lu():
  A = torch.tensor([[ 0.0437, 0.6733, -0.7089, -0.4736, -0.3145],
                    [ 0.2206, -0.3749, 0.8442, -0.5197, 0.2332],
                    [-0.2896, -0.6009, -0.6085, -0.9129, -0.3178]])
  print("native| lu()", torch.lu(A, pivot=True, get_infos=True))
  with env:
    A = torch.tensor([[ 0.0437, 0.6733, -0.7089, -0.4736, -0.3145],
                      [ 0.2206, -0.3749, 0.8442, -0.5197, 0.2332],
                      [-0.2896, -0.6009, -0.6085, -0.9129, -0.3178]])
    print("xla | lu()", torch.lu(A, pivot=True, get_infos=True))

def test_lu_solve():
  b = torch.tensor([[ 2.3815, -2.7199, -1.7691, -8.5981],
                    [-5.9605, -3.7100, 0.3334, 3.5580],
                    [ 5.4002, -6.1015, -3.9192, 3.2690]])
  LU = torch.tensor([[-0.7679, -0.4551, 0.3539],
                     [ 0.0390, 1.2674, 0.2928],
                     [-0.0856, 0.2779, -1.2844]])
  pivots = torch.tensor([2, 3, 3], dtype=torch.int32)
  print("native| lu_solve()", torch.lu_solve(b, LU, pivots))
  with env:
    b = torch.tensor([[ 2.3815, -2.7199, -1.7691, -8.5981],
                      [-5.9605, -3.7100, 0.3334, 3.5580],
                      [ 5.4002, -6.1015, -3.9192, 3.2690]])
    LU = torch.tensor([[-0.7679, -0.4551, 0.3539],
                       [ 0.0390, 1.2674, 0.2928],
                       [-0.0856, 0.2779, -1.2844]])
    pivots = torch.tensor([2, 3, 3], dtype=torch.int32)
    print("xla | lu_solve()", torch.lu_solve(b, LU, pivots))

def test_lu_unpack():
  unpack_data = True
  unpack_pivots = True
  if False:
    lu = torch.tensor([[-2.7199, -1.7691, -8.5981, -5.9605, -3.7100],
                       [ 0.0248, 4.8718, -7.1944, -6.4758, -3.3745],
                       [-0.8873, -0.3588, -3.0746, -8.4111, -2.1212]])
    pivots = torch.tensor([3, 3, 3], dtype=torch.int32)
    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
    with env:
      lu = torch.tensor([[-2.7199, -1.7691, -8.5981, -5.9605, -3.7100],
                         [ 0.0248, 4.8718, -7.1944, -6.4758, -3.3745],
                         [-0.8873, -0.3588, -3.0746, -8.4111, -2.1212]])
      pivots = torch.tensor([3, 3, 3], dtype=torch.int32)
      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))

  if False:
    lu = torch.tensor([[-8.3876, 7.9964, 6.8432, -8.9778, 1.6845],
                       [ 0.8269, -9.9104, -2.1215, 14.8806, 6.4389],
                       [ 0.1808, 0.2953, -4.7303, 0.6897, -7.5366],
                       [-0.4855, -0.7570, 0.7641, 9.0972, 16.3916],
                       [ 0.1354, 0.0746, -0.2784, 0.6465, -4.7616],
                       [-0.9468, -0.9447, 0.7085, 0.6482, 0.6800]])
    pivots = torch.tensor([5, 3, 6, 5, 6], dtype=torch.int32)
    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
    with env:
      lu = torch.tensor([[-8.3876, 7.9964, 6.8432, -8.9778, 1.6845],
                         [ 0.8269, -9.9104, -2.1215, 14.8806, 6.4389],
                         [ 0.1808, 0.2953, -4.7303, 0.6897, -7.5366],
                         [-0.4855, -0.7570, 0.7641, 9.0972, 16.3916],
                         [ 0.1354, 0.0746, -0.2784, 0.6465, -4.7616],
                         [-0.9468, -0.9447, 0.7085, 0.6482, 0.6800]])
      pivots = torch.tensor([5, 3, 6, 5, 6], dtype=torch.int32)
      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))

  if True:
    lu = torch.tensor([[[ -5.3344, -2.2530, -4.3840, -3.1485, -7.3766],
                        [ 0.3589, 2.7324, -4.2898, 0.6681, 9.0900],
                        [ 0.1734, 0.2346, 0.9901, 2.2108, 4.8699]],

                       [[ 8.5252, 5.7155, 8.5447, -0.6509, -8.0849],
                        [ -0.5005, 8.9886, 4.2181, -4.7992, -10.9431],
                        [ -0.9880, -0.2169, 7.5312, 3.2518, -5.4951]],

                       [[ -8.6799, 5.6140, -7.0426, -1.9027, -3.6493],
                        [ -0.0134, -4.0132, 3.2959, -8.1260, -0.6563],
                        [ 0.1997, 0.7197, -9.0417, -1.5426, -0.2071]]])
    pivots = torch.tensor([[1, 2, 3],
                           [1, 2, 3],
                           [1, 3, 3]], dtype=torch.int32)
    print("native| lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))
    with env:
      lu = torch.tensor([[[ -5.3344, -2.2530, -4.3840, -3.1485, -7.3766],
                          [ 0.3589, 2.7324, -4.2898, 0.6681, 9.0900],
                          [ 0.1734, 0.2346, 0.9901, 2.2108, 4.8699]],

                         [[ 8.5252, 5.7155, 8.5447, -0.6509, -8.0849],
                          [ -0.5005, 8.9886, 4.2181, -4.7992, -10.9431],
                          [ -0.9880, -0.2169, 7.5312, 3.2518, -5.4951]],

                         [[ -8.6799, 5.6140, -7.0426, -1.9027, -3.6493],
                          [ -0.0134, -4.0132, 3.2959, -8.1260, -0.6563],
                          [ 0.1997, 0.7197, -9.0417, -1.5426, -0.2071]]])
      pivots = torch.tensor([[1, 2, 3],
                             [1, 2, 3],
                             [1, 3, 3]], dtype=torch.int32)
      print("xla | lu_unpack()", torch.lu_unpack(lu, pivots, unpack_data=unpack_data, unpack_pivots=unpack_pivots))


def pivot_to_permutation():
  n = 3
  with env:
    P = torch.tensor([
        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],
        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],
        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],
    ])
    print("debug: start permutation matrix:", P)
    pivots = torch.tensor([[1, 2, 3],
                           [1, 2, 3],
                           [1, 3, 3]], dtype=torch.int32)
    pivot_size = pivots.shape[-1]

    def _lu_unpack_2d(p, pivot):
      #print("treearg:", tree_arg)
      #p, pivot = tree_arg
      print("args:", p, pivot)
      _pivot = pivot - 1  # torch pivots are 1-indexed; jax indexing is 0-based
      indices = [*range(n)]
      # NOTE: list indexing with _pivot[i] needs concrete values; under vmap
      # the pivots are traced, so this loop may not work as-is.
      for i in range(_pivot.size):
        indices[i], indices[_pivot[i]] = indices[_pivot[i]], indices[i]
        #print("[debug]: i, pivot[i], indices:", i, _pivot[i], indices)
      p = p[jnp.array(indices)]
      p = jnp.transpose(p)
      #print("permutation:", p)
      return p

    tree = (P, pivots)
    v_lu_unpack_2d = vmap(_lu_unpack_2d, in_axes=((0, 0)))
    ret = v_lu_unpack_2d(P, pivots)
    return ret
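# A minimal standalone sketch (no vmap; hypothetical helper, not used above)
# of what _lu_unpack_2d computes for one batch entry, assuming 1-indexed
# LAPACK-style pivots:
def _expected_permutation(pivots_1d, n):
  indices = list(range(n))
  for i, piv in enumerate(p - 1 for p in pivots_1d):  # shift to 0-indexed
    indices[i], indices[piv] = indices[piv], indices[i]
  return jnp.transpose(jnp.eye(n)[jnp.array(indices)])

# _expected_permutation([1, 3, 3], 3) applies swaps (0,0), (1,2), (2,2),
# giving row order [0, 2, 1], i.e. the matrix that exchanges rows 1 and 2:
# [[1., 0., 0.],
#  [0., 0., 1.],
#  [0., 1., 0.]]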



#nanquantile()
#squeeze()
#linspace()
#casting()
#log_normal()
#linalg_vector_norm()
#linalg_tensorsolve()
#test_lu()
#test_lu_solve()
#test_lu_unpack()
pivot_to_permutation()
97 changes: 97 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -4560,3 +4560,100 @@ def _euclidean_direct(x1, x2):
  dist = jnp.sqrt(dist_sq).astype(jnp.float32)

  return dist

@op(torch.ops.aten.lu_unpack)
def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
  n = LU_data.shape[-2]
  m = LU_data.shape[-1]
  dim = min(n, m)
  #print("input: lu.shape ", LU_data.shape, LU_pivots.shape)
  if unpack_data:
    L = jnp.tril(LU_data, k=-1)  # extract the strictly lower triangle
    #print("lower:", L)
    eye = jnp.eye(n, m, dtype=LU_data.dtype)
    #print("eye:", eye)
    L = L + eye  # add ones to the diagonal of L
    #print("lower+eye:", L)
    start_indices = [0] * len(LU_data.shape)  # static ints, as lax.slice requires
    limit_indices = list(LU_data.shape)
    limit_indices[-1] = dim
    #print("indices: ", start_indices, limit_indices)
    L = jax.lax.slice(L, start_indices, limit_indices)  # trim lower triangle to (..., n, dim)

    U = jnp.triu(LU_data)  # extract the upper triangle
    start_indices = [0] * len(LU_data.shape)
    limit_indices = list(LU_data.shape)
    limit_indices[-2] = dim
    U = jax.lax.slice(U, start_indices, limit_indices)  # trim upper triangle to (..., dim, m)
    #print("upper:", U)
  else:
    # jaten ops operate on jax arrays, so return empty jnp arrays (not torch
    # tensors) when data unpacking is skipped
    L = jnp.empty((0,))
    U = jnp.empty((0,))

  if unpack_pivots:
    if True:
      pivots = LU_pivots - 1  # torch pivots are 1-indexed; jax indexing is 0-based

      pivot_size = pivots.shape[-1]
      #? m = pivot_size # In LU decomposition, k = min(m, n), and here we assume m <= n

      # use an identity matrix to build the permutation matrix:
      # tile the 2D identity across the leading (batch) dims
      leading_dims = LU_data.shape[:-2]
      tile_shape = (*leading_dims, 1, 1)
      identity = jnp.eye(n, dtype=jnp.float32)
      identity = jnp.expand_dims(identity, axis=tuple(range(len(leading_dims))))
      P = jnp.tile(identity, tile_shape)
      print("debug: start permutation matrix:", P)

      # apply the row swaps iteratively
      for i in range(pivot_size):
        # get the swap indices
        row_idx = jnp.arange(n)  #? row_idx = jnp.arange(m) or (pivot_size)
        col_idx = pivots[..., i]
        print("r c :", row_idx, col_idx, i)

        #print("debug: (1, 1):", P[..., 1, 1])
        #print("debug: (row_idx, i):", P[..., row_idx, i])
        #print("debug: (0, col_idx):", P[..., 0, col_idx])
        #print("debug: (row_idx, col_idx):", P[..., row_idx, col_idx])
        #print("debug: (row_idx, i):", P[..., row_idx, i])
        #return
        # swap rows i and pivots[..., i] of the permutation matrix
        pi = P[..., i, row_idx]
        pp = jnp.zeros(pi.shape, dtype=pi.dtype)
        # NOTE: assumes batched input; col_idx needs a leading batch dim here
        for j in range(len(col_idx)):
          pp_row = P[..., col_idx[j], row_idx][j]
          print("debug: pp_row:", j, pp_row)
          pp = pp.at[j].set(pp_row)
        print("debug: (i, row_idx):", pi)
        print("debug: (row_idx, col_idx):", pp)
        P = P.at[..., i, row_idx].set(pp)
        print("debug: p1:", P)
        P = P.at[..., col_idx, row_idx].set(pi)
        print("debug: p2:", P)
    else:
      _pivots = LU_pivots - 1  # torch pivots are 1-indexed; jax indexing is 0-based

      # use an identity matrix to build the permutation matrix:
      # tile the 2D identity across the leading (batch) dims
      tile_shape = list(LU_data.shape)
      tile_shape[-1] = 1
      tile_shape[-2] = 1
      P = jnp.tile(jnp.identity(n, dtype=jnp.float32), tile_shape)
      #print("debug: start permutation matrix:", P)

      indices = [*range(n)]
      for i in range(_pivots.size):
        indices[i], indices[_pivots[i]] = indices[_pivots[i]], indices[i]
        #print("[debug]: i, pivot[i], indices:", i, _pivots[i], indices)
      P = P[jnp.array(indices)]
      P = jnp.transpose(P)
      #print("permutation:", P)
  else:
    # jaten ops operate on jax arrays, so return an empty jnp array here too
    P = jnp.empty((0,))
  print("permutation:", P)

  #print("debug output:", P, L, U)
  return P, L, U
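# Intended contract (a hedged sketch, not exercised in this commit): mirror
# torch.lu_unpack, which returns P, L, U with A == P @ L @ U for an LU
# factorization of A. Assuming 0-indexed scipy-style pivots from
# jax.scipy.linalg.lu_factor, shifted to torch's 1-indexed convention:
#
#   from jax.scipy.linalg import lu_factor
#   a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
#   lu, piv = lu_factor(a)
#   P, L, U = _aten_lu_unpack(lu, piv + 1)
#   assert jnp.allclose(P @ L @ U, a)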
