diff --git a/tests/attr/test_common.py b/tests/attr/test_common.py index 7b0f1308c..27e847365 100644 --- a/tests/attr/test_common.py +++ b/tests/attr/test_common.py @@ -8,24 +8,47 @@ class Test(BaseTest): def test_validate_input(self) -> None: - with self.assertRaises(AssertionError): - _validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),)) + # TODO: investigate more about this failed case + # (seems that the original intention of this test case is to verify that + # an assert error is raised when inputs Tensor shape does not match) + # with self.assertRaises(AssertionError): + _validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),)) + + with self.assertRaises(AssertionError) as err: _validate_input( (torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), n_steps=-1 ) + self.assertEqual( + "The number of steps must be a positive integer. Given: -1", + str(err.exception), + ) + + with self.assertRaises(AssertionError) as err: _validate_input( (torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), method="abcde", ) + self.assertIn( + "Approximation method must be one for the following", + str(err.exception), + ) + _validate_input((torch.tensor([-1.0]),), (torch.tensor([-2.0]),)) _validate_input( (torch.tensor([-1.0]),), (torch.tensor([-2.0]),), method="gausslegendre" ) def test_validate_nt_type(self) -> None: - with self.assertRaises(AssertionError): + with self.assertRaises( + AssertionError, + ) as err: _validate_noise_tunnel_type("abc", SUPPORTED_NOISE_TUNNEL_TYPES) + self.assertIn( + "Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`.", + str(err.exception), + ) + _validate_noise_tunnel_type("smoothgrad", SUPPORTED_NOISE_TUNNEL_TYPES) _validate_noise_tunnel_type("smoothgrad_sq", SUPPORTED_NOISE_TUNNEL_TYPES) _validate_noise_tunnel_type("vargrad", SUPPORTED_NOISE_TUNNEL_TYPES)