From 8297af5f00417f35de77ad906a1d459d51c77ec1 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 10 Sep 2024 21:23:20 +0300 Subject: [PATCH] Use `assert_allclose` in gemm test (#132) `assert_allclose` will actually show you how many elements mismatched and the relative/absolute difference. --------- Signed-off-by: Ivan Butygin --- tests/kernel/wave/wave_gemm_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 18050d59..f7850d1c 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -8,6 +8,7 @@ from shark_turbine.kernel.lang.global_symbols import * from shark_turbine.kernel.wave.iree_utils import generate_iree_ref import os +from torch.testing import assert_close _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) @@ -93,7 +94,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: gemm(a, b, c) iree_ref = torch.zeros(2048, 10240, dtype=torch.float32) generate_iree_ref("mmt", [a, b], [iree_ref], config) - assert torch.equal(c, iree_ref) + assert_close(c, iree_ref) if __name__ == "__main__":