From 4182a20e54c115130c9b715b81dcb66c40aa5664 Mon Sep 17 00:00:00 2001
From: paul0403 <79805239+paul0403@users.noreply.github.com>
Date: Tue, 22 Oct 2024 14:33:04 -0400
Subject: [PATCH] Add `static_argnames` option to qjit (#1158)
**Context:**
Adding a `static_argnames` option to qjit for users to configure static
arguments by name.
**Description of the Change:**
Under the hood, this just maps the `static_argnames` to their argument
indices and add to `static_argnums`.
**Benefits:**
Users can specify static arguments to jitted functions by name.
**Possible Drawbacks:**
Even more keyword arguments to qjit...
[sc-41335]
---
doc/releases/changelog-dev.md | 24 ++++++
frontend/catalyst/compiler.py | 3 +
frontend/catalyst/jit.py | 10 +++
frontend/catalyst/tracing/type_signatures.py | 32 ++++++++
frontend/test/pytest/test_static_arguments.py | 77 ++++++++++++++++++-
5 files changed, 145 insertions(+), 1 deletion(-)
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4b25ea8cb5..d817f0d661 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -185,6 +185,30 @@
Array([2, 4, 6], dtype=int64)
```
+* Static arguments of a qjit-compiled function can now be indicated by a `static_argnames`
+ argument to `qjit`.
+ [(#1158)](https://github.com/PennyLaneAI/catalyst/pull/1158)
+
+ ```python
+ @qjit(static_argnames="y")
+ def f(x, y):
+ if y < 10: # y needs to be marked as static since its concrete boolean value is needed
+ return x + y
+
+ @qjit(static_argnames=["x","y"])
+ def g(x, y):
+ if x < 10 and y < 10:
+ return x + y
+
+ res_f = f(1, 2)
+ res_g = g(3, 4)
+ print(res_f, res_g)
+ ```
+
+ ```pycon
+ 3 7
+ ```
+
Improvements
* Implement a Catalyst runtime plugin that mocks out all functions in the QuantumDevice interface.
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index e49363c719..59a4634ada 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -70,6 +70,8 @@ class CompileOptions:
the main compilation pipeline is complete. Default is ``True``.
static_argnums (Optional[Union[int, Iterable[int]]]): indices of static arguments.
Default is ``None``.
+ static_argnames (Optional[Union[str, Iterable[str]]]): names of static arguments.
+ Default is ``None``.
abstracted_axes (Optional[Any]): store the abstracted_axes value. Defaults to ``None``.
disable_assertions (Optional[bool]): disables all assertions. Default is ``False``.
seed (Optional[int]) : the seed for random operations in a qjit call.
@@ -92,6 +94,7 @@ class CompileOptions:
autograph_include: Optional[Iterable[str]] = ()
async_qnodes: Optional[bool] = False
static_argnums: Optional[Union[int, Iterable[int]]] = None
+ static_argnames: Optional[Union[str, Iterable[str]]] = None
abstracted_axes: Optional[Union[Iterable[Iterable[str]], Dict[int, str]]] = None
lower_to_llvm: Optional[bool] = True
checkpoint_stage: Optional[str] = ""
diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py
index 27e24d58bf..d03f8305ab 100644
--- a/frontend/catalyst/jit.py
+++ b/frontend/catalyst/jit.py
@@ -45,6 +45,7 @@
filter_static_args,
get_abstract_signature,
get_type_annotations,
+ merge_static_argname_into_argnum,
merge_static_args,
promote_arguments,
verify_static_argnums,
@@ -82,6 +83,7 @@ def qjit(
logfile=None,
pipelines=None,
static_argnums=None,
+ static_argnames=None,
abstracted_axes=None,
disable_assertions=False,
seed=None,
@@ -124,6 +126,8 @@ def qjit(
considered to be used by advanced users for low-level debugging purposes.
static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the
positions of static arguments.
+ static_argnames(str or Seqence[str]): a string or a sequence of strings that specifies the
+ names of static arguments.
abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]):
An experimental option to specify dynamic tensor shapes.
This option affects the compilation of the annotated function.
@@ -482,6 +486,12 @@ def __init__(self, fn, compile_options):
self.user_sig = get_type_annotations(fn)
self._validate_configuration()
+ # If static_argnames are present, convert them to static_argnums
+ if compile_options.static_argnames is not None:
+ compile_options.static_argnums = merge_static_argname_into_argnum(
+ fn, compile_options.static_argnames, compile_options.static_argnums
+ )
+
# Patch the conversion rules by adding the included modules before the block list
include_convertlist = tuple(
ag_config.Convert(rule) for rule in self.compile_options.autograph_include
diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py
index e840c79781..8823d7b8fb 100644
--- a/frontend/catalyst/tracing/type_signatures.py
+++ b/frontend/catalyst/tracing/type_signatures.py
@@ -156,6 +156,38 @@ def split_static_args(args, static_argnums):
return tuple(dynamic_args), tuple(static_args)
+def merge_static_argname_into_argnum(fn: Callable, static_argnames, static_argnums):
+ """Map static_argnames of the callable to the corresponding argument indices,
+ and add them to static_argnums"""
+ new_static_argnums = [] if (static_argnums is None) else list(static_argnums)
+ fn_argnames = list(inspect.signature(fn).parameters.keys())
+
+ # static_argnames can be a single str, or a list/tuple of strs
+ # convert all of them to list
+ if isinstance(static_argnames, str):
+ static_argnames = [static_argnames]
+
+ non_existent_args = []
+ for static_argname in static_argnames:
+ if static_argname in fn_argnames:
+ new_static_argnums.append(fn_argnames.index(static_argname))
+ continue
+ non_existent_args.append(static_argname)
+
+ if non_existent_args:
+ non_existent_args_str = "{" + ", ".join(repr(item) for item in non_existent_args) + "}"
+
+ raise ValueError(
+ f"qjitted function has invalid argname {non_existent_args_str} in static_argnames. "
+ "Function does not take these args."
+ )
+
+ # Remove potential duplicates from static_argnums and static_argnames
+ new_static_argnums = tuple(sorted(set(new_static_argnums)))
+
+ return new_static_argnums
+
+
def merge_static_args(signature, args, static_argnums):
"""Merge static arguments back into an abstract signature, retaining the original ordering.
diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py
index d615022c9f..4b0b05517e 100644
--- a/frontend/test/pytest/test_static_arguments.py
+++ b/frontend/test/pytest/test_static_arguments.py
@@ -19,7 +19,7 @@
import pennylane as qml
import pytest
-from catalyst import qjit
+from catalyst import grad, qjit
from catalyst.utils.exceptions import CompileError
@@ -221,6 +221,81 @@ def wrapper(x, c):
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"
+ def test_static_argnames(self):
+ # pylint: disable=unused-argument, function-redefined
+ """Test static arguments specified by names"""
+
+ @qjit(static_argnames="y")
+ def f(x, y):
+ return
+
+ assert set(f.compile_options.static_argnums) == {1}
+
+ with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"):
+
+ @qjit(static_argnames="yy")
+ def f_badname(x, y):
+ return
+
+ with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"):
+
+ @qjit(static_argnames=["y", "yy"])
+ def f_badname_list(x, y):
+ return
+
+ with pytest.raises(ValueError, match="qjitted function has invalid argname {'xx', 'yy'}"):
+
+ @qjit(static_argnames=["xx", "yy"])
+ def f_badname_list(x, y):
+ return
+
+ @qjit(static_argnames=("x", "y"))
+ def f(x, y):
+ return
+
+ assert set(f.compile_options.static_argnums) == {0, 1}
+
+ @qjit(static_argnames=("x"), static_argnums=[1])
+ def f(x, y):
+ return
+
+ assert set(f.compile_options.static_argnums) == {0, 1}
+
+ @qjit(static_argnames=("y"), static_argnums=[0])
+ def f(x, y):
+ return
+
+ assert set(f.compile_options.static_argnums) == {0, 1}
+
+ @qjit(static_argnames=("y"), static_argnums=[1])
+ def f(x, y):
+ return
+
+ assert set(f.compile_options.static_argnums) == {1}
+
+ def test_static_argnames_with_decorator(self):
+ # pylint: disable=unused-argument, function-redefined
+ """Test static arguments specified by names
+ on functions with decorators"""
+
+ dev = qml.device("lightning.qubit", wires=3)
+
+ @qjit(static_argnames="theta")
+ @qml.qnode(dev)
+ def f(theta, phi):
+ qml.RX(theta, wires=0)
+ qml.RY(phi, wires=1)
+ return qml.probs()
+
+ assert set(f.compile_options.static_argnums) == {0}
+
+ @qjit(static_argnames=("x", "y"))
+ @grad
+ def f(x, y):
+ return x * y
+
+ assert set(f.compile_options.static_argnums) == {0, 1}
+
if __name__ == "__main__":
pytest.main(["-x", __file__])