Skip to content

Commit

Permalink
Fix the tests with exceptions assertions having multiple top-level st…
Browse files Browse the repository at this point in the history
…atements (#1264)

Summary:

As titled. Add error message checks as well for test comprehensity and code readability.

B908: Contexts with exceptions assertions like with self.assertRaises or with pytest.raises should not have multiple top-level statements. Each statement should be in its own context. That way, the test ensures that the exception is raised only in the exact statement where you expect it.

Differential Revision: D55344319
  • Loading branch information
cicichen01 authored and facebook-github-bot committed Mar 26, 2024
1 parent c6e9c22 commit 87c8c9c
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions tests/attr/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,47 @@

class Test(BaseTest):
def test_validate_input(self) -> None:
with self.assertRaises(AssertionError):
_validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),))
# TODO: investigate more about this failed case
# (seems that the original intention of this test case is to verify that
# an assert error is raised when inputs Tensor shape does not match)
# with self.assertRaises(AssertionError):
_validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),))

with self.assertRaises(AssertionError) as err:
_validate_input(
(torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), n_steps=-1
)
self.assertEqual(
"The number of steps must be a positive integer. Given: -1",
str(err.exception),
)

with self.assertRaises(AssertionError) as err:
_validate_input(
(torch.tensor([-1.0, 1.0]),),
(torch.tensor([-1.0, 1.0]),),
method="abcde",
)
self.assertIn(
"Approximation method must be one for the following",
str(err.exception),
)

_validate_input((torch.tensor([-1.0]),), (torch.tensor([-2.0]),))
_validate_input(
(torch.tensor([-1.0]),), (torch.tensor([-2.0]),), method="gausslegendre"
)

def test_validate_nt_type(self) -> None:
with self.assertRaises(AssertionError):
with self.assertRaises(
AssertionError,
) as err:
_validate_noise_tunnel_type("abc", SUPPORTED_NOISE_TUNNEL_TYPES)
self.assertIn(
"Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`.",
str(err.exception),
)

_validate_noise_tunnel_type("smoothgrad", SUPPORTED_NOISE_TUNNEL_TYPES)
_validate_noise_tunnel_type("smoothgrad_sq", SUPPORTED_NOISE_TUNNEL_TYPES)
_validate_noise_tunnel_type("vargrad", SUPPORTED_NOISE_TUNNEL_TYPES)

0 comments on commit 87c8c9c

Please sign in to comment.