diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index d7417e8f3..9f5eb9843 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -874,3 +874,17 @@ def sign(self, tensor: Tensor) -> Tensor: def item(self, tensor): return tensor.item() + + def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: + """ + Returns the power of tensor a to the value of b. + In the case b is a tensor, then the power is by element + with a as the base and b as the exponent. + In the case b is a scalar, then the power of each value in a + is raised to the exponent of b. + + Args: + a: The tensor that contains the base. + b: The tensor that contains the exponent or a single scalar. + """ + return jnp.power(a, b) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 13c74247c..615bba2e6 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -974,7 +974,6 @@ def matvec_jax(vector, matrix): num_krylov_vecs=100, tol=0.0001) - def test_sum(): np.random.seed(10) backend = jax_backend.JaxBackend() @@ -1240,3 +1239,19 @@ def test_item(dtype): backend = jax_backend.JaxBackend() tensor = backend.randn((1,), dtype=dtype, seed=10) assert backend.item(tensor) == tensor.item() + +@pytest.mark.parametrize("dtype", np_dtypes) +def test_power(dtype): + shape = (4, 3, 2) + backend = jax_backend.JaxBackend() + base_tensor = backend.randn(shape, dtype=dtype, seed=10) + power_tensor = backend.randn(shape, dtype=dtype, seed=10) + actual = backend.power(base_tensor, power_tensor) + expected = jax.numpy.power(base_tensor, power_tensor) + np.testing.assert_allclose(expected, actual) + + power = np.random.rand(1)[0] + actual = backend.power(base_tensor, power) + expected = jax.numpy.power(base_tensor, power) + np.testing.assert_allclose(expected, actual) + \ No newline at end of file