Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JVP/VJP type checking in Catalyst frontend #1031

Merged
merged 6 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,29 @@
[(#1020)](https://github.com/PennyLaneAI/catalyst/pull/1020)
[(#1030)](https://github.com/PennyLaneAI/catalyst/pull/1030)

* Add type checking and improve error messaging in the frontend `catalyst.jvp` and
`catalyst.vjp` functions.
[(#1031)](https://github.com/PennyLaneAI/catalyst/pull/1031)

```python
from catalyst import qjit, jvp

def foo(x):
return 2 * x, x * x

@qjit()
def workflow(x: float):
return jvp(foo, (x,), (1,))
# ^
# Expected tangent dtype float, but got int
```

```
TypeError: function params and tangents arguments to catalyst.jvp do not match;
dtypes must be equal. Got function params dtype float64 and so expected tangent
dtype float64, but got tangent dtype int64 instead.
```

<h3>Breaking changes</h3>

* Return values of qjit-compiled functions that were previously `numpy.ndarray` are now of type
Expand Down
52 changes: 52 additions & 0 deletions frontend/catalyst/api_extensions/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from typing import Callable, Iterable, List, Optional, Union

import jax
import jax.numpy as jnp
from jax._src.api import _dtype
from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten
from pennylane import QNode

Expand Down Expand Up @@ -455,6 +457,32 @@ def check_is_iterable(x, hint):
tangents_flatten, _ = tree_flatten(tangents)
grad_params = _check_grad_params(method, scalar_out, h, argnum, len(args_flatten), in_tree)

if len(tangents_flatten) != len(grad_params.expanded_argnum):
raise TypeError(
"number of tangent and number of differentiable parameters in catalyst.jvp do not "
"match; the number of parameters must be equal. "
f"Got {len(grad_params.expanded_argnum)} differentiable parameters and so expected "
f"as many tangents, but got {len(tangents_flatten)} instead."
)

# Only check dtypes and shapes of parameters marked as differentiable by the `argnum` param
args_to_check = [args_flatten[i] for i in grad_params.argnum]

for p, t in zip(args_to_check, tangents_flatten):
if _dtype(p) != _dtype(t):
raise TypeError(
"function params and tangents arguments to catalyst.jvp do not match; "
"dtypes must be equal. "
f"Got function params dtype {_dtype(p)} and so expected tangent dtype "
f"{_dtype(p)}, but got tangent dtype {_dtype(t)} instead."
)

if jnp.shape(p) != jnp.shape(t):
raise ValueError(
"catalyst.jvp called with different function params and tangent shapes; "
f"got function params shape {jnp.shape(p)} and tangent shape {jnp.shape(t)}"
)

jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *params)

results = jvp_p.bind(
Expand Down Expand Up @@ -542,6 +570,30 @@ def check_is_iterable(x, hint):

jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *params)

if len(jaxpr.out_avals) != len(cotangents_flatten):
raise TypeError(
"number of cotangent and number of function output parameters in catalyst.vjp do "
"not match; the number of parameters must be equal. "
f"Got {len(jaxpr.out_avals)} function output parameters and so expected as many "
f"cotangents, but got {len(cotangents_flatten)} instead."
)

for p, t in zip(jaxpr.out_avals, cotangents_flatten):
if _dtype(p) != _dtype(t):
raise TypeError(
"function output params and cotangents arguments to catalyst.vjp do not match; "
"dtypes must be equal. "
f"Got function output params dtype {_dtype(p)} and so expected cotangent dtype "
f"{_dtype(p)}, but got cotangent dtype {_dtype(t)} instead."
)

if jnp.shape(p) != jnp.shape(t):
raise ValueError(
"catalyst.vjp called with different function output params and cotangent "
f"shapes; got function output params shape {jnp.shape(p)} and cotangent shape "
f"{jnp.shape(t)}"
)

cotangents, _ = tree_flatten(cotangents)

results = vjp_p.bind(
Expand Down
164 changes: 164 additions & 0 deletions frontend/test/pytest/test_jvpvjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax import jvp as J_jvp
from jax import vjp as J_vjp
from jax.tree_util import tree_flatten, tree_unflatten
from numpy.testing import assert_allclose

Check notice on line 25 in frontend/test/pytest/test_jvpvjp.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_jvpvjp.py#L25

third party import "from numpy.testing import assert_allclose" should be placed before "import jax" (wrong-import-order)

from catalyst import jvp as C_jvp
from catalyst import qjit
Expand All @@ -39,6 +39,16 @@
return qml.expval(qml.PauliY(0))


def f_R1_to_R2(x):
"""A test function f : R1 -> R2"""
return 2 * x, x * x


def g_R3_to_R2(_n, x):
"""A test function g : R3 -> R2, where `_n` is a dummy, non-differentiable parameter"""
return jnp.stack([1 + x[0] + 2 * x[1] + 3 * x[2], 1 + x[0] + 2 * x[1] ** 2 + 3 * x[2] ** 3])


diff_methods = ["auto", "fd"]


Expand Down Expand Up @@ -832,5 +842,159 @@
assert_allclose(r_j, r_c)


rmoyard marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("diff_method", diff_methods)
def test_jvp_argument_type_checks_correct_inputs(diff_method):
"""Test that Catalyst's jvp can JIT compile when given the correct types."""

@qjit
def C_workflow_f():
x = (1.0,)
tangents = (1.0,)
return C_jvp(f_R1_to_R2, x, tangents, method=diff_method, argnum=[0])

@qjit
def C_workflow_g():
x = jnp.array([2.0, 3.0, 4.0])
tangents = jnp.ones([3], dtype=float)
return C_jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnum=[1])


@pytest.mark.parametrize("diff_method", diff_methods)
def test_jvp_argument_type_checks_incompatible_n_inputs(diff_method):
"""Tests error handling of Catalyst's jvp when the number of differentiable params
and tangent arguments are incompatible.
"""

with pytest.raises(
TypeError,
match=(
"number of tangent and number of differentiable parameters in catalyst.jvp "
"do not match"
),
):

@qjit
def C_workflow():
# If `f` takes one differentiable param (argnum=[0]), then `tangents` must have length 1
x = (1.0,)
tangents = (1.0, 1.0)
return C_jvp(f_R1_to_R2, x, tangents, method=diff_method, argnum=[0])


@pytest.mark.parametrize("diff_method", diff_methods)
def test_jvp_argument_type_checks_incompatible_input_types(diff_method):
"""Tests error handling of Catalyst's jvp when the types of the differentiable
params and tangent arguments are incompatible.
"""

with pytest.raises(
TypeError, match="function params and tangents arguments to catalyst.jvp do not match"
):

@qjit
def C_workflow():
# If `x` has type float, then `tangents` should also have type float
x = (1.0,)
tangents = (1,)
return C_jvp(f_R1_to_R2, x, tangents, method=diff_method, argnum=[0])


@pytest.mark.parametrize("diff_method", diff_methods)
def test_jvp_argument_type_checks_incompatible_input_shapes(diff_method):
"""Tests error handling of Catalyst's jvp when the shapes of the differentiable
params and tangent arguments are incompatible.
"""

with pytest.raises(
ValueError, match="catalyst.jvp called with different function params and tangent shapes"
):

@qjit
def C_workflow():
# If `x` has shape (3,), then `tangents` must also have shape (3,),
# but it has shape (4,)
x = jnp.array([2.0, 3.0, 4.0])
tangents = jnp.ones([4], dtype=float)
return C_jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnum=[1])


@pytest.mark.parametrize("diff_method", diff_methods)
def test_vjp_argument_type_checks_correct_inputs(diff_method):
"""Test that Catalyst's vjp can JIT compile when given the correct types."""

@qjit
def C_workflow_f():
x = (1.0,)
cotangents = (1.0, 1.0)
return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnum=[0])

@qjit
def C_workflow_g():
x = jnp.array([2.0, 3.0, 4.0])
cotangents = jnp.ones([2], dtype=float)
return C_vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnum=[1])


@pytest.mark.parametrize("diff_method", diff_methods)
def test_vjp_argument_type_checks_incompatible_n_inputs(diff_method):
"""Tests error handling of Catalyst's vjp when the number of function output params
and cotangent arguments are incompatible.
"""

with pytest.raises(
TypeError,
match=(
"number of cotangent and number of function output parameters in catalyst.vjp "
"do not match"
),
):

@qjit
def C_workflow():
# If `f` returns two outputs, then `cotangents` must have length 2
x = (1.0,)
cotangents = (1.0,)
return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnum=[0])


@pytest.mark.parametrize("diff_method", diff_methods)
def test_vjp_argument_type_checks_incompatible_input_types(diff_method):
"""Tests error handling of Catalyst's vjp when the types of the function output params
and cotangent arguments are incompatible.
"""

with pytest.raises(
TypeError,
match="function output params and cotangents arguments to catalyst.vjp do not match",
):

@qjit
def C_workflow():
# If `x` has type float, then `cotangents` should also have type float
x = (1.0,)
cotangents = (1, 1)
return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnum=[0])


@pytest.mark.parametrize("diff_method", diff_methods)
def test_vjp_argument_type_checks_incompatible_input_shapes(diff_method):
"""Tests error handling of Catalyst's vjp when the shapes of the function output params
and cotangent arguments are incompatible.
"""

with pytest.raises(
ValueError,
match="catalyst.vjp called with different function output params and cotangent shapes",
):

@qjit
def C_workflow():
# If `f` returns object with shape (2,), then `cotangents` must also have
# shape (2,), but it has shape (3,)
x = jnp.array([2.0, 3.0, 4.0])
cotangents = jnp.ones([3], dtype=float)
return C_vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnum=[1])


if __name__ == "__main__":
pytest.main(["-x", __file__])
Loading