Skip to content
This repository has been archived by the owner on Nov 7, 2024. It is now read-only.

Added power function to Jax backend and test. #884

Merged
merged 13 commits into from
Dec 7, 2020
Merged
14 changes: 14 additions & 0 deletions tensornetwork/backends/jax/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,3 +874,17 @@ def sign(self, tensor: Tensor) -> Tensor:

def item(self, tensor):
return tensor.item()

def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor:
"""
Returns the power of tensor a to the value of b.
In the case b is a tensor, then the power is by element
with a as the base and b as the exponent.
In the case b is a scalar, then the power of each value in a
is raised to the exponent of b.

Args:
a: The tensor that contains the base.
b: The tensor that contains the exponent or a single scalar.
"""
return jnp.power(a, b)
17 changes: 16 additions & 1 deletion tensornetwork/backends/jax/jax_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,6 @@ def matvec_jax(vector, matrix):
num_krylov_vecs=100,
tol=0.0001)


def test_sum():
np.random.seed(10)
backend = jax_backend.JaxBackend()
Expand Down Expand Up @@ -1240,3 +1239,19 @@ def test_item(dtype):
backend = jax_backend.JaxBackend()
tensor = backend.randn((1,), dtype=dtype, seed=10)
assert backend.item(tensor) == tensor.item()

@pytest.mark.parametrize("dtype", np_dtypes)
def test_power(dtype):
shape = (4, 3, 2)
backend = jax_backend.JaxBackend()
base_tensor = backend.randn(shape, dtype=dtype, seed=10)
power_tensor = backend.randn(shape, dtype=dtype, seed=10)
actual = backend.power(base_tensor, power_tensor)
expected = jax.numpy.power(base_tensor, power_tensor)
np.testing.assert_allclose(expected, actual)

power = np.random.rand(1)[0]
actual = backend.power(base_tensor, power)
expected = jax.numpy.power(base_tensor, power)
np.testing.assert_allclose(expected, actual)