diff --git a/haiku/_src/base_test.py b/haiku/_src/base_test.py index f23336ff1..56fa32774 100644 --- a/haiku/_src/base_test.py +++ b/haiku/_src/base_test.py @@ -106,6 +106,12 @@ def replace_rng_sequence_state_example(): class BaseTest(parameterized.TestCase): + def assert_keys_equal(self, a, b): + self.assertEqual(jax.random.key_impl(a), jax.random.key_impl(b)) + np.testing.assert_array_equal( + jax.random.key_data(a), jax.random.key_data(b) + ) + @test_utils.transform_and_run def test_parameter_reuse(self): w1 = base.get_parameter("w", [], init=jnp.zeros) @@ -646,7 +652,7 @@ def test_prng_reserve(self): s.reserve(10) hk_keys = tuple(next(s) for _ in range(10)) jax_keys = tuple(jax.random.split(test_utils.clone(k), num=11)[1:]) - jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys) + jax.tree.map(self.assert_keys_equal, hk_keys, jax_keys) def test_prng_reserve_twice(self): k = jax.random.PRNGKey(42) @@ -657,14 +663,14 @@ def test_prng_reserve_twice(self): k, subkey1, subkey2 = tuple(jax.random.split(test_utils.clone(k), num=3)) _, subkey3, subkey4 = tuple(jax.random.split(k, num=3)) jax_keys = (subkey1, subkey2, subkey3, subkey4) - jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys) + jax.tree.map(self.assert_keys_equal, hk_keys, jax_keys) def test_prng_sequence_split(self): k = jax.random.PRNGKey(42) s = base.PRNGSequence(k) hk_keys = s.take(10) jax_keys = tuple(jax.random.split(test_utils.clone(k), num=11)[1:]) - jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys) + jax.tree.map(self.assert_keys_equal, hk_keys, jax_keys) @parameterized.parameters(42, 28) def test_with_rng(self, seed): @@ -782,7 +788,7 @@ def test_rng_reserve_size(self): for _ in range(2): split_key, *expected_keys = jax.random.split(split_key, size+1) hk_keys = hk.next_rng_keys(size) - np.testing.assert_array_equal(hk_keys, expected_keys) + jax.tree.map(self.assert_keys_equal, list(hk_keys), expected_keys) @parameterized.parameters( base.get_params, base.get_current_state, base.get_initial_state diff --git a/haiku/_src/stateful_test.py b/haiku/_src/stateful_test.py index c4e3f03c2..e3be583cf 100644 --- a/haiku/_src/stateful_test.py +++ b/haiku/_src/stateful_test.py @@ -87,26 +87,37 @@ def wrapper(*a, **kw): class StatefulTest(parameterized.TestCase): + def assert_keys_equal(self, a, b): + self.assertEqual(jax.random.key_impl(a), jax.random.key_impl(b)) + np.testing.assert_array_equal( + jax.random.key_data(a), jax.random.key_data(b) + ) + + def assert_keys_not_equal(self, a, b): + self.assertFalse( + (jax.random.key_impl(a) == jax.random.key_impl(b)) and + (jnp.all(jax.random.key_data(a) == jax.random.key_data(b)))) + @test_utils.transform_and_run def test_grad(self): - x = jnp.array(3.) + x = jnp.array(3.0) g = stateful.grad(SquareModule())(x) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) def test_grad_no_transform(self): - x = jnp.array(3.) + x = jnp.array(3.0) with self.assertRaises(ValueError, msg="Use jax.grad() instead"): stateful.grad(jnp.square)(x) @test_utils.transform_and_run def test_value_and_grad(self): - x = jnp.array(2.) + x = jnp.array(2.0) y, g = stateful.value_and_grad(SquareModule())(x) - self.assertEqual(y, x ** 2) + self.assertEqual(y, x**2) np.testing.assert_allclose(g, 2 * x, rtol=1e-4) def test_value_and_grad_no_transform(self): - x = jnp.array(3.) + x = jnp.array(3.0) with self.assertRaises(ValueError, msg="Use jax.grad() instead"): stateful.value_and_grad(jnp.square)(x) @@ -645,12 +656,12 @@ def test_vmap_no_split_rng(self): x = jnp.arange(4) k1, k2, k3, k4 = f(x) key_after = base.next_rng_key() - np.testing.assert_array_equal(k1, k2) - np.testing.assert_array_equal(k2, k3) - np.testing.assert_array_equal(k3, k4) - self.assertFalse(np.array_equal(key_before, k1)) - self.assertFalse(np.array_equal(key_after, k1)) - self.assertFalse(np.array_equal(key_before, key_after)) + self.assert_keys_equal(k1, k2) + self.assert_keys_equal(k2, k3) + self.assert_keys_equal(k3, k4) + self.assert_keys_not_equal(key_before, k1) + self.assert_keys_not_equal(key_after, k1) + self.assert_keys_not_equal(key_before, key_after) @test_utils.transform_and_run def test_vmap_split_rng(self):