Skip to content

Commit

Permalink
Add static_argnames option to qjit (#1158)
Browse files Browse the repository at this point in the history
**Context:**
Adding a `static_argnames` option to qjit for users to configure static
arguments by name.

**Description of the Change:**
Under the hood, this just maps the `static_argnames` to their argument
indices and add to `static_argnums`.

**Benefits:**
Users can specify static arguments to jitted functions by name.

**Possible Drawbacks:**
Even more keyword arguments to qjit...

[sc-41335]
  • Loading branch information
paul0403 authored Oct 22, 2024
1 parent d7c7e39 commit 4182a20
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 1 deletion.
24 changes: 24 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,30 @@
Array([2, 4, 6], dtype=int64)
```

* Static arguments of a qjit-compiled function can now be indicated by a `static_argnames`
argument to `qjit`.
[(#1158)](https://github.com/PennyLaneAI/catalyst/pull/1158)

```python
@qjit(static_argnames="y")
def f(x, y):
if y < 10: # y needs to be marked as static since its concrete boolean value is needed
return x + y

@qjit(static_argnames=["x","y"])
def g(x, y):
if x < 10 and y < 10:
return x + y

res_f = f(1, 2)
res_g = g(3, 4)
print(res_f, res_g)
```

```pycon
3 7
```

<h3>Improvements</h3>

* Implement a Catalyst runtime plugin that mocks out all functions in the QuantumDevice interface.
Expand Down
3 changes: 3 additions & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class CompileOptions:
the main compilation pipeline is complete. Default is ``True``.
static_argnums (Optional[Union[int, Iterable[int]]]): indices of static arguments.
Default is ``None``.
static_argnames (Optional[Union[str, Iterable[str]]]): names of static arguments.
Default is ``None``.
abstracted_axes (Optional[Any]): store the abstracted_axes value. Defaults to ``None``.
disable_assertions (Optional[bool]): disables all assertions. Default is ``False``.
seed (Optional[int]) : the seed for random operations in a qjit call.
Expand All @@ -92,6 +94,7 @@ class CompileOptions:
autograph_include: Optional[Iterable[str]] = ()
async_qnodes: Optional[bool] = False
static_argnums: Optional[Union[int, Iterable[int]]] = None
static_argnames: Optional[Union[str, Iterable[str]]] = None
abstracted_axes: Optional[Union[Iterable[Iterable[str]], Dict[int, str]]] = None
lower_to_llvm: Optional[bool] = True
checkpoint_stage: Optional[str] = ""
Expand Down
10 changes: 10 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
filter_static_args,
get_abstract_signature,
get_type_annotations,
merge_static_argname_into_argnum,
merge_static_args,
promote_arguments,
verify_static_argnums,
Expand Down Expand Up @@ -82,6 +83,7 @@ def qjit(
logfile=None,
pipelines=None,
static_argnums=None,
static_argnames=None,
abstracted_axes=None,
disable_assertions=False,
seed=None,
Expand Down Expand Up @@ -124,6 +126,8 @@ def qjit(
considered to be used by advanced users for low-level debugging purposes.
static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the
positions of static arguments.
static_argnames(str or Seqence[str]): a string or a sequence of strings that specifies the
names of static arguments.
abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]):
An experimental option to specify dynamic tensor shapes.
This option affects the compilation of the annotated function.
Expand Down Expand Up @@ -482,6 +486,12 @@ def __init__(self, fn, compile_options):
self.user_sig = get_type_annotations(fn)
self._validate_configuration()

# If static_argnames are present, convert them to static_argnums
if compile_options.static_argnames is not None:
compile_options.static_argnums = merge_static_argname_into_argnum(
fn, compile_options.static_argnames, compile_options.static_argnums
)

# Patch the conversion rules by adding the included modules before the block list
include_convertlist = tuple(
ag_config.Convert(rule) for rule in self.compile_options.autograph_include
Expand Down
32 changes: 32 additions & 0 deletions frontend/catalyst/tracing/type_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,38 @@ def split_static_args(args, static_argnums):
return tuple(dynamic_args), tuple(static_args)


def merge_static_argname_into_argnum(fn: Callable, static_argnames, static_argnums):
"""Map static_argnames of the callable to the corresponding argument indices,
and add them to static_argnums"""
new_static_argnums = [] if (static_argnums is None) else list(static_argnums)
fn_argnames = list(inspect.signature(fn).parameters.keys())

# static_argnames can be a single str, or a list/tuple of strs
# convert all of them to list
if isinstance(static_argnames, str):
static_argnames = [static_argnames]

non_existent_args = []
for static_argname in static_argnames:
if static_argname in fn_argnames:
new_static_argnums.append(fn_argnames.index(static_argname))
continue
non_existent_args.append(static_argname)

if non_existent_args:
non_existent_args_str = "{" + ", ".join(repr(item) for item in non_existent_args) + "}"

raise ValueError(
f"qjitted function has invalid argname {non_existent_args_str} in static_argnames. "
"Function does not take these args."
)

# Remove potential duplicates from static_argnums and static_argnames
new_static_argnums = tuple(sorted(set(new_static_argnums)))

return new_static_argnums


def merge_static_args(signature, args, static_argnums):
"""Merge static arguments back into an abstract signature, retaining the original ordering.
Expand Down
77 changes: 76 additions & 1 deletion frontend/test/pytest/test_static_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pennylane as qml
import pytest

from catalyst import qjit
from catalyst import grad, qjit
from catalyst.utils.exceptions import CompileError


Expand Down Expand Up @@ -221,6 +221,81 @@ def wrapper(x, c):
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"

def test_static_argnames(self):
# pylint: disable=unused-argument, function-redefined
"""Test static arguments specified by names"""

@qjit(static_argnames="y")
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {1}

with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"):

@qjit(static_argnames="yy")
def f_badname(x, y):
return

with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"):

@qjit(static_argnames=["y", "yy"])
def f_badname_list(x, y):
return

with pytest.raises(ValueError, match="qjitted function has invalid argname {'xx', 'yy'}"):

@qjit(static_argnames=["xx", "yy"])
def f_badname_list(x, y):
return

@qjit(static_argnames=("x", "y"))
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {0, 1}

@qjit(static_argnames=("x"), static_argnums=[1])
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {0, 1}

@qjit(static_argnames=("y"), static_argnums=[0])
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {0, 1}

@qjit(static_argnames=("y"), static_argnums=[1])
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {1}

def test_static_argnames_with_decorator(self):
# pylint: disable=unused-argument, function-redefined
"""Test static arguments specified by names
on functions with decorators"""

dev = qml.device("lightning.qubit", wires=3)

@qjit(static_argnames="theta")
@qml.qnode(dev)
def f(theta, phi):
qml.RX(theta, wires=0)
qml.RY(phi, wires=1)
return qml.probs()

assert set(f.compile_options.static_argnums) == {0}

@qjit(static_argnames=("x", "y"))
@grad
def f(x, y):
return x * y

assert set(f.compile_options.static_argnums) == {0, 1}


if __name__ == "__main__":
pytest.main(["-x", __file__])

0 comments on commit 4182a20

Please sign in to comment.