diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 659901b74dc..05f440cae98 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -4844,22 +4844,17 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     tile_shape[-1] = 1
     tile_shape[-2] = 1
     P = jnp.tile(identity2d, tile_shape)
-    #print("debug: start permutation matrix:", P)

     # closure to be called for each input 2D matrix.
     def _lu_unpack_2d(p, pivot):
-      jax.debug.print("unpack2d: {} {} {}", p , pivot, pivot.size)
       _pivot = pivot - 1  # pivots are offset by 1 in jax
       indices = jnp.array([*range(n)], dtype=jnp.int32)
       def update_indices(i, _indices):
-        #jax.debug.print("fori <<: {} {} {} {}", i, _indices, _pivot, p)
         tmp = _indices[i]
         _indices = _indices.at[i].set(_indices[_pivot[i]])
         _indices = _indices.at[_pivot[i]].set(tmp)
-        #jax.debug.print("fori >>: {} {} {} {}", i, _indices, _pivot, p)
         return _indices
       indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices)
-      #jax.debug.print("indices {}", indices)
       p = p[jnp.array(indices)]
       p = jnp.transpose(p)
       return p
@@ -4888,12 +4883,9 @@ def update_indices(i, _indices):

       # reshape result back to P's shape
       newRetshape = (*P.shape[:-2], unpackedP.shape[-2], unpackedP.shape[-1])
-      #print("newshape: {} {}", newRetshape, unpackedP.shape, dim)
       P = unpackedP.reshape(newRetshape)
-      #print("permutation after: ", P)
   else:
     # emulate pytorch behavior: return empty tensors
     P = torch.empty(torch.Size([0]))
-    #print("debug output:", P, L, U)

   return P, L, U
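
For context on the logic this cleanup leaves behind: LAPACK-style LU pivots are 1-based sequential row swaps, not a direct permutation vector, which is why `_lu_unpack_2d` walks them with `jax.lax.fori_loop` instead of indexing in one shot. Below is a minimal standalone sketch of that pivot-to-permutation conversion; the helper name `pivots_to_permutation` is hypothetical and the code is illustrative only, not part of this patch.

import jax
import jax.numpy as jnp

def pivots_to_permutation(pivots, n):
  """Turn 1-based LAPACK pivots into a 0-based row permutation.

  Hypothetical helper mirroring the swap loop kept in _lu_unpack_2d.
  """
  pivots = pivots - 1  # pivots are offset by 1 relative to 0-based indices

  def swap(i, indices):
    # Apply the i-th row exchange: swap entries i and pivots[i].
    tmp = indices[i]
    indices = indices.at[i].set(indices[pivots[i]])
    indices = indices.at[pivots[i]].set(tmp)
    return indices

  return jax.lax.fori_loop(0, pivots.shape[0], swap,
                           jnp.arange(n, dtype=jnp.int32))

# Pivots (2, 3, 3) on a 3x3 input: row 0 swaps with row 1, then row 1
# swaps with row 2, yielding the permutation (1, 2, 0).
print(pivots_to_permutation(jnp.array([2, 3, 3], dtype=jnp.int32), 3))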