Skip to content

Commit

Permalink
Feature-pass-static-argnums-to-qnode (#932)
Browse files Browse the repository at this point in the history
**Context:**
Static argnum is not correctly passed through QNode. Currently if we use
`@qjit(static_argnums=(1,))` decorator right before `@qml.qnode(dev)`,
the static argument will still be traced within the QNode.

**Description of the Change:**
Passed the static_argnums from the compile options through qnode
__call__() method and then using it in deduce_avals to avoid tracing the
static arguments.
Also improved the verification and preparation of static_argnums.
**Benefits:**
Added ability to use static arguments inside qnodes

**Possible Drawbacks:**
It can potentially cause confusion for jax users since passing
`static_argnums` through nested calls to `jax.jit()` is not supported in
jax. e.g.
```
@partial(jax.jit, static_argnums=(1,))
@jax.jit
def foo(x, c):
    print("Inside QNode:", c)
    return x + c
```

```
>>> foo(0.5, 0.5)  
>>> Inside QNode: Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>
```

which means that parameter c inside foo is still traced.

**Related GitHub Issues:**
#902

[sc-67808]

---------

Co-authored-by: erick-xanadu <[email protected]>
Co-authored-by: David Ittah <[email protected]>
  • Loading branch information
3 people authored Jul 18, 2024
1 parent 83745f5 commit 39f81b9
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 17 deletions.
20 changes: 20 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@
the python function being called back into.
[(#919)](https://github.com/PennyLaneAI/catalyst/pull/919)

* Static_argnums now can be passed through a QNode
[(#932)](https://github.com/PennyLaneAI/catalyst/pull/932)

```python
dev = qml.device("lightning.qubit", wires=1)

@qjit(static_argnums=(1,))
@qml.qnode(dev)
def circuit(x, c):
print("Inside QNode:", c)
qml.RY(c, 0)
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))
```

```pycon
>>> circuit(0.5, 0.5)
>>> "Inside QNode: 0.5"
```

* Autograph now supports in-place array assignments with static slices. [(#843)](https://github.com/PennyLaneAI/catalyst/pull/843)

For example,
Expand Down
2 changes: 2 additions & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def __post_init__(self):
self.static_argnums = ()
elif isinstance(static_argnums, int):
self.static_argnums = (static_argnums,)
elif isinstance(static_argnums, Iterable):
self.static_argnums = tuple(static_argnums)

def __deepcopy__(self, memo):
"""Make a deep copy of all fields of a CompileOptions object except the logfile, which is
Expand Down
9 changes: 7 additions & 2 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@

from catalyst.jax_extras.patches import _gather_shape_rule_dynamic, get_aval2
from catalyst.logging import debug_logger
from catalyst.tracing.type_signatures import verify_static_argnums_type
from catalyst.utils.patching import Patcher

# pylint: disable=protected-access
Expand Down Expand Up @@ -440,17 +441,21 @@ def deduce_signatures(
)


def deduce_avals(f: Callable, args, kwargs):
def deduce_avals(f: Callable, args, kwargs, static_argnums=None):
"""Wraps the callable ``f`` into a WrappedFun container accepting collapsed flatten arguments
and returning expanded flatten results. Calculate input abstract values and output_tree promise.
The promise must be called after the resulting wrapped function is evaluated."""
# TODO: deprecate in favor of `deduce_signatures`
wf = wrap_init(f)
if static_argnums:
verify_static_argnums_type(static_argnums)
dynamic_argnums = [i for i in range(len(args)) if i not in static_argnums]
wf, args = jax._src.api_util.argnums_partial(wf, dynamic_argnums, args)
flat_args, in_tree = tree_flatten((args, kwargs))
abstracted_axes = None
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
in_type = infer_lambda_input_type(axes_specs, flat_args)
in_avals, keep_inputs = unzip2(in_type)
wf = wrap_init(f)
wff, out_tree_promise = flatten_fun(wf, in_tree)
wffa = annotate(wff, in_type)
return wffa, in_avals, keep_inputs, out_tree_promise
Expand Down
6 changes: 4 additions & 2 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ def trace_function(

@debug_logger
def trace_quantum_function(
f: Callable, device: QubitDevice, args, kwargs, qnode
f: Callable, device: QubitDevice, args, kwargs, qnode, static_argnums
) -> Tuple[ClosedJaxpr, Any]:
"""Trace quantum function in a way that allows building a nested quantum tape describing the
quantum algorithm.
Expand All @@ -1090,7 +1090,9 @@ def trace_quantum_function(
# (1) - Classical tracing
quantum_tape = QuantumTape(shots=device.shots)
with EvaluationContext.frame_tracing_context(ctx) as trace:
wffa, in_avals, keep_inputs, out_tree_promise = deduce_avals(f, args, kwargs)
wffa, in_avals, keep_inputs, out_tree_promise = deduce_avals(
f, args, kwargs, static_argnums
)
in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
with QueuingManager.stop_recording(), quantum_tape:
# Quantum tape transformations happen at the end of tracing
Expand Down
19 changes: 11 additions & 8 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
get_type_annotations,
merge_static_args,
promote_arguments,
verify_static_argnums,
)
from catalyst.utils.c_template import mlir_type_to_numpy_type
from catalyst.utils.exceptions import CompileError
Expand Down Expand Up @@ -494,6 +495,10 @@ def __init__(self, fn, compile_options):
def __call__(self, *args, **kwargs):
# Transparantly call Python function in case of nested QJIT calls.
if EvaluationContext.is_tracing():
isQNode = isinstance(self.user_function, qml.QNode)
if isQNode and self.compile_options.static_argnums:
kwargs = {"static_argnums": self.compile_options.static_argnums, **kwargs}

return self.user_function(*args, **kwargs)

requires_promotion = self.jit_compile(args)
Expand Down Expand Up @@ -613,16 +618,20 @@ def capture(self, args):
Tuple[Any]: the dynamic argument signature
"""

self._verify_static_argnums(args)
verify_static_argnums(args, self.compile_options.static_argnums)
static_argnums = self.compile_options.static_argnums
abstracted_axes = self.compile_options.abstracted_axes

dynamic_args = filter_static_args(args, static_argnums)
dynamic_sig = get_abstract_signature(dynamic_args)
full_sig = merge_static_args(dynamic_sig, args, static_argnums)

def closure(*args, **kwargs):
st_argnums = kwargs.pop("static_argnums", static_argnums)
return QFunc.__call__(*args, static_argnums=st_argnums, **kwargs)

with Patcher(
(qml.QNode, "__call__", QFunc.__call__),
(qml.QNode, "__call__", closure),
):
# TODO: improve PyTree handling
jaxpr, out_type, treedef = trace_to_jaxpr(
Expand Down Expand Up @@ -717,12 +726,6 @@ def _validate_configuration(self):
"In order for 'autograph_include' to work, 'autograph' must be set to True"
)

def _verify_static_argnums(self, args):
for argnum in self.compile_options.static_argnums:
if argnum < 0 or argnum >= len(args):
msg = f"argnum {argnum} is beyond the valid range of [0, {len(args)})."
raise CompileError(msg)

def _get_workspace(self):
"""Get or create a workspace to use for compilation."""

Expand Down
15 changes: 11 additions & 4 deletions frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from catalyst.jax_primitives import func_p
from catalyst.jax_tracer import trace_quantum_function
from catalyst.logging import debug_logger
from catalyst.tracing.type_signatures import filter_static_args
from catalyst.utils.toml import DeviceCapabilities, ProgramFeatures

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -135,18 +136,24 @@ def __call__(self, *args, **kwargs):
else:
qjit_device = QJITDevice(self.device, device_capabilities, backend_info)

static_argnums = kwargs.pop("static_argnums", ())

def _eval_quantum(*args):
closed_jaxpr, out_type, out_tree = trace_quantum_function(
self.func, qjit_device, args, kwargs, self
self.func, qjit_device, args, kwargs, self, static_argnums
)
args_expanded = get_implicit_and_explicit_flat_args(None, *args)
dynamic_args = filter_static_args(args, static_argnums)
args_expanded = get_implicit_and_explicit_flat_args(None, *dynamic_args)
res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded)
_, out_keep = unzip2(out_type)
res_flat = [r for r, k in zip(res_expanded, out_keep) if k]
return tree_unflatten(out_tree, res_flat)

flattened_fun, _, _, out_tree_promise = deduce_avals(_eval_quantum, args, {})
args_flat = tree_flatten(args)[0]
flattened_fun, _, _, out_tree_promise = deduce_avals(
_eval_quantum, args, {}, static_argnums
)
dynamic_args = filter_static_args(args, static_argnums)
args_flat = tree_flatten(dynamic_args)[0]
res_flat = func_p.bind(flattened_fun, *args_flat, fn=self)
return tree_unflatten(out_tree_promise(), res_flat)[0]

Expand Down
39 changes: 39 additions & 0 deletions frontend/catalyst/tracing/type_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from jax.tree_util import tree_flatten, tree_unflatten

from catalyst.jax_extras import get_aval2
from catalyst.utils.exceptions import CompileError
from catalyst.utils.patching import Patcher


Expand Down Expand Up @@ -72,6 +73,44 @@ def get_abstract_signature(args):
return tree_unflatten(treedef, abstract_args)


def verify_static_argnums_type(static_argnums):
"""Verify that static_argnums have correct type.
Args:
static_argnums (Iterable[int]): indices to verify
Returns:
None
"""
is_tuple = isinstance(static_argnums, tuple)
is_valid = is_tuple and all(isinstance(arg, int) for arg in static_argnums)
if not is_valid:
raise TypeError(
"The `static_argnums` argument to `qjit` must be an int or convertable to a"
f"tuple of ints, but got value {static_argnums}"
)
return None


def verify_static_argnums(args, static_argnums):
"""Verify that static_argnums have correct type and range.
Args:
args (Iterable): arguments to a compiled function
static_argnums (Iterable[int]): indices to verify
Returns:
None
"""
verify_static_argnums_type(static_argnums)

for argnum in static_argnums:
if argnum < 0 or argnum >= len(args):
msg = f"argnum {argnum} is beyond the valid range of [0, {len(args)})."
raise CompileError(msg)
return None


def filter_static_args(args, static_argnums):
"""Remove static values from arguments using the provided index list.
Expand Down
87 changes: 86 additions & 1 deletion frontend/test/pytest/test_static_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from dataclasses import dataclass

import pennylane as qml
import pytest

from catalyst import qjit
Expand Down Expand Up @@ -53,7 +54,7 @@ def f(x: int):

assert f(1) == 1

@pytest.mark.parametrize("argnums", [-1, 100])
@pytest.mark.parametrize("argnums", [-2, 100])
def test_out_of_bounds_static_argument(self, argnums):
"""Test QJIT with invalid static argument index with respect to provided arguments."""

Expand All @@ -64,6 +65,17 @@ def f(x):
with pytest.raises(CompileError, match="is beyond the valid range"):
f(5)

@pytest.mark.parametrize("argnums", [1.0, [1.0], ["x"]])
def test_unsopported_type_static_argument(self, argnums):
"""Test QJIT with invalid static argument type."""

@qjit(static_argnums=argnums)
def f(x, y):
return x + y

with pytest.raises(TypeError, match="The `static_argnums` argument to"):
f(5, 6)

def test_one_static_argument(self):
"""Test QJIT with one static argument."""

Expand Down Expand Up @@ -136,6 +148,79 @@ def f(x: int, y: MyClass):
assert f(1, my_obj) == 9
assert function != f.compiled_function

def test_qnode_with_static_arguments(self, capsys):
"""Test if QJIT static arguments pass through QNode correctly."""
dev = qml.device("lightning.qubit", wires=1)

@qjit(static_argnums=(1,))
@qml.qnode(dev)
def circuit(x, c):
print("Inside QNode:", c)
qml.RY(c, 0)
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))

circuit(0.5, 0.5)
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"

def test_qnode_nested_with_static_arguments(self, capsys):
"""Test if QJIT static arguments pass through QNode correctly."""
dev = qml.device("lightning.qubit", wires=1)

@qjit(static_argnums=(1,))
@qml.qnode(dev)
def circuit(x, c):
print("Inside QNode:", c)
qml.RY(c, 0)
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))

@qjit(static_argnums=(1,))
def wrapper(x, c):
return circuit(x, c)

wrapper(0.5, 0.5)
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"

def test_qnode_switch_params(self, capsys):
"""Test if QJIT static arguments pass through QNode correctly when params are switched."""
dev = qml.device("lightning.qubit", wires=1)

@qjit(static_argnums=(0,))
@qml.qnode(dev)
def circuit(c, x):
print("Inside QNode:", c)
qml.RY(c, 0)
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))

@qjit(static_argnums=(1,))
def wrapper(x, c):
return circuit(c, x)

wrapper(0.5, 0.5)
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"

def test_qnode_nested_not_qnode(self, capsys):
"""Test if QJIT static arguments pass through nested Qjit calls with no QNodes."""
dev = qml.device("lightning.qubit", wires=1)

@qjit(static_argnums=(0,))
def circuit(c, x):
print("Inside QNode:", c)
return x * c

@qjit(static_argnums=(1,))
def wrapper(x, c):
return circuit(c, x)

wrapper(0.5, 0.5)
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"


if __name__ == "__main__":
pytest.main(["-x", __file__])

0 comments on commit 39f81b9

Please sign in to comment.