Skip to content

Commit

Permalink
Add a transform to split sums to QJITDevice preprocess (#927)
Browse files Browse the repository at this point in the history
**Context:**
Some backends do not natively support sums of observables (e.g.
`qml.expval(X(0) + Y(1))`).

This is already taken care of by the `split_non_commuting` transform if
the backend in question
also doesn't support non-commuting observables, but applying
`split_non_commuting` to handle sums on
for devices that support non-commuting measurements will result in
unnecessarily splitting executions into multiple tapes.

**Description of the Change:**
If `split_non_commuting` is not being added to the transform program
already, and the device does not support sum observables, we add the
transform `split_to_single_terms`, which splits the sums but leaves them
all on a single tape. Supporting either `Hamiltonian` or `Sum` is taken
to indicate support for summed observables.
  • Loading branch information
lillian542 authored Jul 16, 2024
1 parent 1e5c6d6 commit 1f2afab
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c
enzyme=v0.0.130

# Always remove custom PL/LQ versions before release.
pennylane=296316654bd6aabb3ff67eb0bac440fa8d706ee8
pennylane=50def12062e4e4d64611246bdbd85ad8bf87cfff
32 changes: 25 additions & 7 deletions frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import pennylane as qml
from pennylane.measurements import MidMeasureMP
from pennylane.transforms import split_non_commuting
from pennylane.transforms import split_non_commuting, split_to_single_terms
from pennylane.transforms.core import TransformProgram

from catalyst.device.decomposition import (
Expand Down Expand Up @@ -466,14 +466,14 @@ def preprocess(
"""

_, config = self.original_device.preprocess(execution_config)

program = TransformProgram()

# measurement transforms (these may change operations on the tape to accommodate
# measurement transformations, so must occur before decomposition of measurements)
if self.qjit_capabilities.non_commuting_observables_flag is False:
program.add_transform(split_non_commuting)
if self.measurement_processes == {"Counts"}:
program.add_transform(measurements_from_counts)
# measurement transforms may change operations on the tape to accommodate
# measurement transformations, so must occur before decomposition
measurement_transforms = self._measurement_transform_program()
config.device_options["transforms_modify_measurements"] = bool(measurement_transforms)
program = program + measurement_transforms

# decomposition to supported ops/measurements
ops_acceptance = partial(catalyst_acceptance, operations=self.operations)
Expand Down Expand Up @@ -501,6 +501,24 @@ def preprocess(

return program, config

def _measurement_transform_program(self):

measurement_program = TransformProgram()

supports_sum_observables = any(
obs in self.qjit_capabilities.native_obs for obs in ("Sum", "Hamiltonian")
)

if self.qjit_capabilities.non_commuting_observables_flag is False:
measurement_program.add_transform(split_non_commuting)
elif not supports_sum_observables:
measurement_program.add_transform(split_to_single_terms)

if self.measurement_processes == {"Counts"}:
measurement_program.add_transform(measurements_from_counts)

return measurement_program

def execute(self, circuits, execution_config):
"""
Raises: RuntimeError
Expand Down
12 changes: 4 additions & 8 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,13 +925,11 @@ def apply_transform(
params = tape.get_parameters(trainable_only=False)
tape.trainable_params = qml.math.get_trainable_indices(params)

is_program_transformed = qnode_program

if is_program_transformed and qnode_program.is_informative:
if qnode_program.is_informative:
msg = "Catalyst does not support informative transforms."
raise CompileError(msg)

if is_program_transformed or device_modify_measurements:
if qnode_program or device_modify_measurements:
is_valid_for_batch = is_transform_valid_for_batch_transforms(tape, flat_results)
total_program = qnode_program + device_program
else:
Expand Down Expand Up @@ -1117,15 +1115,13 @@ def is_leaf(obj):
if isinstance(device, qml.devices.Device):
config = _make_execution_config(qnode)
device_program, config = device.preprocess(ctx, config)
device_modify_measurements = config.device_options["transforms_modify_measurements"]
else:
device_program = TransformProgram()
device_modify_measurements = False # this is only for the new API transform program

qnode_program = qnode.transform_program if qnode else TransformProgram()

device_modify_measurements = "measurements_from_counts" in [
t.transform.__name__ for t in device_program
]

tapes, post_processing = apply_transform(
qnode_program,
device_program,
Expand Down
103 changes: 89 additions & 14 deletions frontend/test/pytest/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from pennylane.devices import Device
from pennylane.devices.execution_config import DefaultExecutionConfig, ExecutionConfig
from pennylane.tape import QuantumScript
from pennylane.transforms import split_non_commuting
from pennylane.transforms import split_non_commuting, split_to_single_terms
from pennylane.transforms.core import TransformProgram

from catalyst import CompileError, ctrl
Expand Down Expand Up @@ -337,32 +337,64 @@ def circuit(theta: float):
assert "expval" not in mlir
assert "counts" in mlir

def test_non_commuting_measurements_are_split(self, mocker):
"""Test that the split_non_commuting transform is added to the transform
program from preprocess when non_commuting_observables_flag is False"""
def test_measurements_are_split(self, mocker):
"""Test that the split_to_single_terms or split_non_commuting transform
are added to the transform program from preprocess as expected, based on the
sum_observables_flag and the non_commuting_observables_flag"""

# dummy device supports non_commuting observables by default
dev = DummyDevice(wires=4, shots=1000)

# Create a qjit device that supports non-commuting observables
dev_capabilities = get_device_capabilities(dev, ProgramFeatures(dev.shots is not None))

# dev1 supports non-commuting observables and sum observables - no splitting
assert "Sum" in dev_capabilities.native_obs
assert "Hamiltonian" in dev_capabilities.native_obs
assert dev_capabilities.non_commuting_observables_flag is True
backend_info = extract_backend_info(dev, dev_capabilities)
qjit_dev1 = QJITDeviceNewAPI(dev, dev_capabilities, backend_info)

# Create a qjit device that does NOT support non-commuting observables
dev_capabilities = replace(dev_capabilities, non_commuting_observables_flag=False)
# dev2 supports non-commuting observables but NOT sums - split_to_single_terms
del dev_capabilities.native_obs["Sum"]
del dev_capabilities.native_obs["Hamiltonian"]
backend_info = extract_backend_info(dev, dev_capabilities)
qjit_dev2 = QJITDeviceNewAPI(dev, dev_capabilities, backend_info)

# dev3 supports does not support non-commuting observables OR sums - split_non_commuting
dev_capabilities = replace(dev_capabilities, non_commuting_observables_flag=False)
backend_info = extract_backend_info(dev, dev_capabilities)
qjit_dev3 = QJITDeviceNewAPI(dev, dev_capabilities, backend_info)

# dev4 supports sums but NOT non-commuting observables - split_non_commuting
dev_capabilities = replace(dev_capabilities, non_commuting_observables_flag=False)
backend_info = extract_backend_info(dev, dev_capabilities)
qjit_dev4 = QJITDeviceNewAPI(dev, dev_capabilities, backend_info)

# Check the preprocess
with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx:
transform_program1, _ = qjit_dev1.preprocess(ctx)
transform_program2, _ = qjit_dev2.preprocess(ctx)
transform_program1, _ = qjit_dev1.preprocess(ctx) # no splitting
transform_program2, _ = qjit_dev2.preprocess(ctx) # split_to_single_terms
transform_program3, _ = qjit_dev3.preprocess(ctx) # split_non_commuting
transform_program4, _ = qjit_dev4.preprocess(ctx) # split_non_commuting

assert split_to_single_terms not in transform_program1
assert split_non_commuting not in transform_program1
assert split_non_commuting in transform_program2

def test_split_non_commuting_execution(self, mocker):
assert split_to_single_terms in transform_program2
assert split_non_commuting not in transform_program2

assert split_non_commuting in transform_program3
assert split_to_single_terms not in transform_program3

assert split_non_commuting in transform_program4
assert split_to_single_terms not in transform_program4

@pytest.mark.parametrize(
"observables",
[
(qml.X(0) @ qml.X(1), qml.Y(0)), # distributed to separate tapes, but no sum splitting
(qml.X(0) + qml.X(1), qml.Y(0)), # split into 3 seperate terms and distributed
],
)
def test_split_non_commuting_execution(self, observables, mocker):
"""Test that the results of the execution for a tape with non-commuting observables is
consistent (on a backend that does, in fact, support non-commuting observables) regardless
of whether split_non_commuting is applied or not as expected"""
Expand All @@ -373,7 +405,7 @@ def test_split_non_commuting_execution(self, mocker):
def unjitted_circuit(theta: float):
qml.RX(theta, 0)
qml.RY(0.89, 1)
return qml.expval(qml.X(0) @ qml.X(1)), qml.expval(qml.Y(0))
return [qml.expval(o) for o in observables]

expected_result = unjitted_circuit(1.2)

Expand All @@ -400,6 +432,49 @@ def unjitted_circuit(theta: float):
transform_program, _ = spy.spy_return
assert split_non_commuting in transform_program

def test_split_to_single_terms_execution(self, mocker):
"""Test that the results of the execution for a tape with multi-term observables is
consistent (on a backend that does, in fact, support multi-term observables) regardless
of whether split_to_single_terms is applied or not"""

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(dev)
def unjitted_circuit(theta: float):
qml.RX(theta, 0)
qml.RY(0.89, 1)
return qml.expval(qml.X(0) + qml.X(1)), qml.expval(qml.Y(0))

expected_result = unjitted_circuit(1.2)

config = get_device_toml_config(dev)
spy = mocker.spy(QJITDeviceNewAPI, "preprocess")

# make sure non_commuting_observables_flag is True - otherwise we use
# split_non_commuting instead of split_to_single_terms
assert config["compilation"]["non_commuting_observables"] is True
# make sure the testing device does in fact support sum observables
assert "Sum" in config["operators"]["observables"]

# test case where transform should not be applied
jitted_circuit = qml.qjit(unjitted_circuit)
assert len(jitted_circuit(1.2)) == len(expected_result) == 2
assert np.allclose(jitted_circuit(1.2), expected_result)

transform_program, _ = spy.spy_return
assert split_to_single_terms not in transform_program

# mock TOML file output to indicate non-commuting observables are NOT supported
del config["operators"]["observables"]["Sum"]
del config["operators"]["observables"]["Hamiltonian"]
with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)):
jitted_circuit = qml.qjit(unjitted_circuit)
assert len(jitted_circuit(1.2)) == len(expected_result) == 2
assert np.allclose(jitted_circuit(1.2), unjitted_circuit(1.2))

transform_program, _ = spy.spy_return
assert split_to_single_terms in transform_program


# tapes and regions for generating HybridOps
tape1 = QuantumScript([qml.X(0), qml.Hadamard(1)])
Expand Down

0 comments on commit 1f2afab

Please sign in to comment.