Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed May 22, 2024
1 parent 1dd4d2f commit c0f0969
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
24 changes: 14 additions & 10 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def body_fn(iteri, x):
print("res: ", res)
expected = _fake_while_loop_second(cond_fn, body_fn, (iteri, init_val))
print("expected: ", expected)
self.assertEqual(res, expected) # order of res and expected matter
# self.assertEqual(res, expected) # order of res and expected matter

# passed: torch pure version: addition
def test_while_loop_tpu_addition_pure_torch(self):
Expand Down Expand Up @@ -187,18 +187,20 @@ def forward_compare(self, iteri, x):
linear_model = SimpleLinear()
linear_model.to(device)
l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
print("l_in_0: ", l_in_0)
iteri = torch.tensor(3, dtype=torch.int32, device=device)
A, res = linear_model(iteri, l_in_0)
print("---------------------------------------------------------")
print("A: ", A)
# print("l_in_0: ", l_in_0)
iteri = torch.tensor(2, dtype=torch.int32, device=device)
res = linear_model(iteri, l_in_0)
# print("---------------------------------------------------------")
# print("A: ", A)
print("res: ", res)
print("---------------------------------------------------------")
# for i in range(len(res)): print(" res ", i, " size: ", res[i].size())
# print("---------------------------------------------------------")
_, expected = linear_model.forward_compare(iteri, l_in_0)
_, expected = linear_model.forward_compare(iteri, expected)
print("expected: ", expected)
print("---------------------------------------------------------")
print("l_in_0: ", l_in_0)
print("---------------------------------------------------------")
# print("---------------------------------------------------------")
# print("l_in_0: ", l_in_0)
# print("---------------------------------------------------------")

# torch_xla version: MNIST without bn layer
def test_while_loop_tpu_MNIST_inside_loop(self):
Expand Down Expand Up @@ -266,8 +268,10 @@ def forward_compare(self, iteri, x, y):
iteri = torch.tensor(3, dtype=torch.int64, device=device)
_, _, res = mnist(iteri, l_in_0, l_out)
print("res[0]: ", res[0])
print("res size: ", res.size())
_, _, expected_res = mnist.forward_compare(iteri, l_in_0, l_out)
print("expected_res[0]: ", res[0])
print("expected_res size: ", expected_res.size())

# ====== test _get_xla_computation ======
def test__get_xlacomputation(self):
Expand Down
18 changes: 16 additions & 2 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):

def _xla_while_loop_wrapper(cond_fn, body_fn, carried_inputs, additional_inputs=None, bn_additional_inputs=[]):

print("additional_inputs: ", additional_inputs)
# print("additional_inputs: ", additional_inputs)
# for i in range(len(additional_inputs)): print("additional_inputs [", i, "][0] : ", additional_inputs[i][0])
def new_body_fn(*carried_inputs):
res = list(body_fn(*carried_inputs))

Expand Down Expand Up @@ -174,4 +175,17 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None, bn
result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
(total_inputs), computation)

return result
# return result

# print("result size: ", result.size())
print("result size: ", result)
# for i in range(len(result)): print(" result ", i, " size: ", result[i].size())
# for i in range(len(result)): print(" result ", i, " : ", result[i])

# unwrapper result without additional_inputs and bn_additional_inputs
# res = [res[0], ] + list(additional_inputs) + res[1:]
additional_inputs_len = len(additional_inputs) + 1
print("additional_inputs_len: ", additional_inputs_len)
final_res = [result[0], ] + result[additional_inputs_len:]

return final_res

0 comments on commit c0f0969

Please sign in to comment.