From 4182a20e54c115130c9b715b81dcb66c40aa5664 Mon Sep 17 00:00:00 2001 From: paul0403 <79805239+paul0403@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:33:04 -0400 Subject: [PATCH] Add `static_argnames` option to qjit (#1158) **Context:** Adding a `static_argnames` option to qjit for users to configure static arguments by name. **Description of the Change:** Under the hood, this just maps the `static_argnames` to their argument indices and add to `static_argnums`. **Benefits:** Users can specify static arguments to jitted functions by name. **Possible Drawbacks:** Even more keyword arguments to qjit... [sc-41335] --- doc/releases/changelog-dev.md | 24 ++++++ frontend/catalyst/compiler.py | 3 + frontend/catalyst/jit.py | 10 +++ frontend/catalyst/tracing/type_signatures.py | 32 ++++++++ frontend/test/pytest/test_static_arguments.py | 77 ++++++++++++++++++- 5 files changed, 145 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4b25ea8cb5..d817f0d661 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -185,6 +185,30 @@ Array([2, 4, 6], dtype=int64) ``` +* Static arguments of a qjit-compiled function can now be indicated by a `static_argnames` + argument to `qjit`. + [(#1158)](https://github.com/PennyLaneAI/catalyst/pull/1158) + + ```python + @qjit(static_argnames="y") + def f(x, y): + if y < 10: # y needs to be marked as static since its concrete boolean value is needed + return x + y + + @qjit(static_argnames=["x","y"]) + def g(x, y): + if x < 10 and y < 10: + return x + y + + res_f = f(1, 2) + res_g = g(3, 4) + print(res_f, res_g) + ``` + + ```pycon + 3 7 + ``` +

Improvements

* Implement a Catalyst runtime plugin that mocks out all functions in the QuantumDevice interface. diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index e49363c719..59a4634ada 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -70,6 +70,8 @@ class CompileOptions: the main compilation pipeline is complete. Default is ``True``. static_argnums (Optional[Union[int, Iterable[int]]]): indices of static arguments. Default is ``None``. + static_argnames (Optional[Union[str, Iterable[str]]]): names of static arguments. + Default is ``None``. abstracted_axes (Optional[Any]): store the abstracted_axes value. Defaults to ``None``. disable_assertions (Optional[bool]): disables all assertions. Default is ``False``. seed (Optional[int]) : the seed for random operations in a qjit call. @@ -92,6 +94,7 @@ class CompileOptions: autograph_include: Optional[Iterable[str]] = () async_qnodes: Optional[bool] = False static_argnums: Optional[Union[int, Iterable[int]]] = None + static_argnames: Optional[Union[str, Iterable[str]]] = None abstracted_axes: Optional[Union[Iterable[Iterable[str]], Dict[int, str]]] = None lower_to_llvm: Optional[bool] = True checkpoint_stage: Optional[str] = "" diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 27e24d58bf..d03f8305ab 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -45,6 +45,7 @@ filter_static_args, get_abstract_signature, get_type_annotations, + merge_static_argname_into_argnum, merge_static_args, promote_arguments, verify_static_argnums, @@ -82,6 +83,7 @@ def qjit( logfile=None, pipelines=None, static_argnums=None, + static_argnames=None, abstracted_axes=None, disable_assertions=False, seed=None, @@ -124,6 +126,8 @@ def qjit( considered to be used by advanced users for low-level debugging purposes. static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the positions of static arguments. + static_argnames(str or Seqence[str]): a string or a sequence of strings that specifies the + names of static arguments. abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]): An experimental option to specify dynamic tensor shapes. This option affects the compilation of the annotated function. @@ -482,6 +486,12 @@ def __init__(self, fn, compile_options): self.user_sig = get_type_annotations(fn) self._validate_configuration() + # If static_argnames are present, convert them to static_argnums + if compile_options.static_argnames is not None: + compile_options.static_argnums = merge_static_argname_into_argnum( + fn, compile_options.static_argnames, compile_options.static_argnums + ) + # Patch the conversion rules by adding the included modules before the block list include_convertlist = tuple( ag_config.Convert(rule) for rule in self.compile_options.autograph_include diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index e840c79781..8823d7b8fb 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -156,6 +156,38 @@ def split_static_args(args, static_argnums): return tuple(dynamic_args), tuple(static_args) +def merge_static_argname_into_argnum(fn: Callable, static_argnames, static_argnums): + """Map static_argnames of the callable to the corresponding argument indices, + and add them to static_argnums""" + new_static_argnums = [] if (static_argnums is None) else list(static_argnums) + fn_argnames = list(inspect.signature(fn).parameters.keys()) + + # static_argnames can be a single str, or a list/tuple of strs + # convert all of them to list + if isinstance(static_argnames, str): + static_argnames = [static_argnames] + + non_existent_args = [] + for static_argname in static_argnames: + if static_argname in fn_argnames: + new_static_argnums.append(fn_argnames.index(static_argname)) + continue + non_existent_args.append(static_argname) + + if non_existent_args: + non_existent_args_str = "{" + ", ".join(repr(item) for item in non_existent_args) + "}" + + raise ValueError( + f"qjitted function has invalid argname {non_existent_args_str} in static_argnames. " + "Function does not take these args." + ) + + # Remove potential duplicates from static_argnums and static_argnames + new_static_argnums = tuple(sorted(set(new_static_argnums))) + + return new_static_argnums + + def merge_static_args(signature, args, static_argnums): """Merge static arguments back into an abstract signature, retaining the original ordering. diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py index d615022c9f..4b0b05517e 100644 --- a/frontend/test/pytest/test_static_arguments.py +++ b/frontend/test/pytest/test_static_arguments.py @@ -19,7 +19,7 @@ import pennylane as qml import pytest -from catalyst import qjit +from catalyst import grad, qjit from catalyst.utils.exceptions import CompileError @@ -221,6 +221,81 @@ def wrapper(x, c): captured = capsys.readouterr() assert captured.out.strip() == "Inside QNode: 0.5" + def test_static_argnames(self): + # pylint: disable=unused-argument, function-redefined + """Test static arguments specified by names""" + + @qjit(static_argnames="y") + def f(x, y): + return + + assert set(f.compile_options.static_argnums) == {1} + + with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"): + + @qjit(static_argnames="yy") + def f_badname(x, y): + return + + with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"): + + @qjit(static_argnames=["y", "yy"]) + def f_badname_list(x, y): + return + + with pytest.raises(ValueError, match="qjitted function has invalid argname {'xx', 'yy'}"): + + @qjit(static_argnames=["xx", "yy"]) + def f_badname_list(x, y): + return + + @qjit(static_argnames=("x", "y")) + def f(x, y): + return + + assert set(f.compile_options.static_argnums) == {0, 1} + + @qjit(static_argnames=("x"), static_argnums=[1]) + def f(x, y): + return + + assert set(f.compile_options.static_argnums) == {0, 1} + + @qjit(static_argnames=("y"), static_argnums=[0]) + def f(x, y): + return + + assert set(f.compile_options.static_argnums) == {0, 1} + + @qjit(static_argnames=("y"), static_argnums=[1]) + def f(x, y): + return + + assert set(f.compile_options.static_argnums) == {1} + + def test_static_argnames_with_decorator(self): + # pylint: disable=unused-argument, function-redefined + """Test static arguments specified by names + on functions with decorators""" + + dev = qml.device("lightning.qubit", wires=3) + + @qjit(static_argnames="theta") + @qml.qnode(dev) + def f(theta, phi): + qml.RX(theta, wires=0) + qml.RY(phi, wires=1) + return qml.probs() + + assert set(f.compile_options.static_argnums) == {0} + + @qjit(static_argnames=("x", "y")) + @grad + def f(x, y): + return x * y + + assert set(f.compile_options.static_argnums) == {0, 1} + if __name__ == "__main__": pytest.main(["-x", __file__])