From b55065a2b747ed4a4b755a9f34d04e6bcabdfa4d Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 9 Oct 2024 12:20:55 -0400 Subject: [PATCH] Enable check for sharded Conv2D test (#263) The fix https://github.com/iree-org/iree-turbine/pull/205 solves the issue with this test. Xfail the Unet Resnet block test with maybe low accuracy. --- .../layers/sharded_conv2d_with_iree_test.py | 14 +++++------ .../sharded_resnet_block_with_iree_test.py | 24 ++++++++++++------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py index 2a6ecace2..9b29e5761 100644 --- a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -173,14 +173,12 @@ def run_test_sharded_conv2d_with_iree( ) assert len(actual_result.shards) == len(expected_result.shards) assert actual_result.shard_dim == expected_result.shard_dim - # TODO: reenable this check once numerical issues are resolved. - # See https://github.com/iree-org/iree/issues/18283 - # for actual_shard, expected_shard in zip( - # actual_result.shards, expected_result.shards - # ): - # torch.testing.assert_close( - # unbox_tensor(actual_shard), unbox_tensor(expected_shard) - # ) + for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) def test_sharded_conv2d_with_iree( diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index 86bb41c71..581584369 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -19,6 +19,7 @@ import iree.runtime from typing import List, Optional import os +import pytest vm_context: iree.runtime.VmContext = None @@ -207,19 +208,26 @@ def run_test_sharded_resnet_block_with_iree( parameters_path=parameters_path, ) assert len(actual_result.shards) == len(expected_result.shards) - # TODO: reenable this check once numerical issues are resolved. - # See https://github.com/iree-org/iree/issues/18283 - # for actual_shard, expected_shard in zip( - # actual_result.shards, expected_result.shards - # ): - # torch.testing.assert_close( - # unbox_tensor(actual_shard), unbox_tensor(expected_shard) - # ) + # TODO: reenable this test once numerical issues are resolved. + # The absolute accuracy is > 0.00042. Is this good enough? + # Maybe add a test with fp64, where if the accuracy is high would give us more + # confidence that fp32 is also OK. + for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) global vm_context del vm_context +@pytest.mark.xfail( + reason="Maybe numerical issues with low accuracy.", + strict=True, + raises=AssertionError, +) def test_sharded_resnet_block_with_iree( mlir_path: Optional[Path], module_path: Optional[Path],