Skip to content

Commit

Permalink
Enable check for sharded Conv2D test (#263)
Browse files Browse the repository at this point in the history
The fix iree-org/iree-turbine#205 solves the
issue with this test.

Xfail the Unet Resnet block test with maybe low accuracy.
  • Loading branch information
sogartar authored Oct 9, 2024
1 parent 4e2f351 commit b55065a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
14 changes: 6 additions & 8 deletions sharktank/tests/layers/sharded_conv2d_with_iree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,12 @@ def run_test_sharded_conv2d_with_iree(
)
assert len(actual_result.shards) == len(expected_result.shards)
assert actual_result.shard_dim == expected_result.shard_dim
# TODO: reenable this check once numerical issues are resolved.
# See https://github.com/iree-org/iree/issues/18283
# for actual_shard, expected_shard in zip(
# actual_result.shards, expected_result.shards
# ):
# torch.testing.assert_close(
# unbox_tensor(actual_shard), unbox_tensor(expected_shard)
# )
for actual_shard, expected_shard in zip(
actual_result.shards, expected_result.shards
):
torch.testing.assert_close(
unbox_tensor(actual_shard), unbox_tensor(expected_shard)
)


def test_sharded_conv2d_with_iree(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import iree.runtime
from typing import List, Optional
import os
import pytest

vm_context: iree.runtime.VmContext = None

Expand Down Expand Up @@ -207,19 +208,26 @@ def run_test_sharded_resnet_block_with_iree(
parameters_path=parameters_path,
)
assert len(actual_result.shards) == len(expected_result.shards)
# TODO: reenable this check once numerical issues are resolved.
# See https://github.com/iree-org/iree/issues/18283
# for actual_shard, expected_shard in zip(
# actual_result.shards, expected_result.shards
# ):
# torch.testing.assert_close(
# unbox_tensor(actual_shard), unbox_tensor(expected_shard)
# )
# TODO: reenable this test once numerical issues are resolved.
# The absolute accuracy is > 0.00042. Is this good enough?
# Maybe add a test with fp64, where if the accuracy is high would give us more
# confidence that fp32 is also OK.
for actual_shard, expected_shard in zip(
actual_result.shards, expected_result.shards
):
torch.testing.assert_close(
unbox_tensor(actual_shard), unbox_tensor(expected_shard)
)

global vm_context
del vm_context


@pytest.mark.xfail(
reason="Maybe numerical issues with low accuracy.",
strict=True,
raises=AssertionError,
)
def test_sharded_resnet_block_with_iree(
mlir_path: Optional[Path],
module_path: Optional[Path],
Expand Down

0 comments on commit b55065a

Please sign in to comment.