Skip to content

Commit

Permalink
Fix the tests with exceptions assertions having multiple top-level st…
Browse files Browse the repository at this point in the history
…atements (#1264)

Summary:
Pull Request resolved: #1264

As titled. Add error message checks as well for test comprehensity and code readability.

B908: Contexts with exceptions assertions like with self.assertRaises or with pytest.raises should not have multiple top-level statements. Each statement should be in its own context. That way, the test ensures that the exception is raised only in the exact statement where you expect it.

Reviewed By: vivekmig

Differential Revision: D55344319

fbshipit-source-id: aad315a15f764a9fb24d653c46c3940ae99248e9
  • Loading branch information
cicichen01 authored and facebook-github-bot committed Mar 26, 2024
1 parent 88f4b0a commit dab9447
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions tests/attr/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,56 @@

class Test(BaseTest):
def test_validate_input(self) -> None:
with self.assertRaises(AssertionError):
_validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),))
with self.assertRaises(AssertionError) as err:
_validate_input(
(torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0, 0.0, 1.0]),)
)
self.assertEqual(
"Baseline can be provided as a tensor for just one input and "
"broadcasted to the batch or input and baseline must have the "
"same shape or the baseline corresponding to each input tensor "
"must be a scalar. Found baseline: tensor([-2., 0., 1.]) and "
"input: tensor([-1., 1.])",
str(err.exception),
)

with self.assertRaises(AssertionError) as err:
_validate_input(
(torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), n_steps=-1
)
self.assertEqual(
"The number of steps must be a positive integer. Given: -1",
str(err.exception),
)

with self.assertRaises(AssertionError) as err:
_validate_input(
(torch.tensor([-1.0, 1.0]),),
(torch.tensor([-1.0, 1.0]),),
method="abcde",
)
self.assertIn(
"Approximation method must be one for the following",
str(err.exception),
)
# any baseline which is broadcastable to match the input is supported, which
# includes a scalar / single-element tensor.
_validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),))
_validate_input((torch.tensor([-1.0]),), (torch.tensor([-2.0]),))
_validate_input(
(torch.tensor([-1.0]),), (torch.tensor([-2.0]),), method="gausslegendre"
)

def test_validate_nt_type(self) -> None:
with self.assertRaises(AssertionError):
with self.assertRaises(
AssertionError,
) as err:
_validate_noise_tunnel_type("abc", SUPPORTED_NOISE_TUNNEL_TYPES)
self.assertIn(
"Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`.",
str(err.exception),
)

_validate_noise_tunnel_type("smoothgrad", SUPPORTED_NOISE_TUNNEL_TYPES)
_validate_noise_tunnel_type("smoothgrad_sq", SUPPORTED_NOISE_TUNNEL_TYPES)
_validate_noise_tunnel_type("vargrad", SUPPORTED_NOISE_TUNNEL_TYPES)

0 comments on commit dab9447

Please sign in to comment.