Skip to content

Commit

Permalink
Use assert_allclose in gemm test (#132)
Browse files Browse the repository at this point in the history
`assert_allclose` will actually show you how many elements mismatched
and the relative/absolute difference.

---------

Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 authored Sep 10, 2024
1 parent 7a405fd commit 8297af5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/kernel/wave/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel.wave.iree_utils import generate_iree_ref
import os
from torch.testing import assert_close

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))

Expand Down Expand Up @@ -93,7 +94,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
gemm(a, b, c)
iree_ref = torch.zeros(2048, 10240, dtype=torch.float32)
generate_iree_ref("mmt", [a, b], [iree_ref], config)
assert torch.equal(c, iree_ref)
assert_close(c, iree_ref)


if __name__ == "__main__":
Expand Down

0 comments on commit 8297af5

Please sign in to comment.