Fix extrapolation in ZNE function (#1213)
Extrapolation was done against the folding numbers instead of the scale
factors. Since the folding numbers start at 0, extrapolation would
always yield a result very close to the first data point.

---------

Co-authored-by: Romain Moyard <[email protected]>
dime10 and rmoyard authored Oct 23, 2024
1 parent da04467 commit 22a900f
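
To make the bug concrete, here is a small standalone illustration (hypothetical numbers, plain NumPy rather than Catalyst internals): a polynomial fitted against the 0-based fold numbers has its zero-noise intercept essentially at the first data point, whereas fitting against the odd scale factors extrapolates to the noiseless value.

import numpy as np

# Hypothetical expectation values measured at odd noise-scale factors 1, 3, 5
# (equivalently fold numbers 0, 1, 2), decaying linearly with the scale factor.
scale_factors = np.array([1.0, 3.0, 5.0])
fold_numbers = (scale_factors - 1) // 2          # [0., 1., 2.]
measured = np.array([0.95, 0.85, 0.75])

# Correct: fit against the scale factors and evaluate the fit at 0 (zero noise).
zne_ok = np.polyval(np.polyfit(scale_factors, measured, 1), 0.0)    # -> 1.0

# Buggy: fit against the 0-based fold numbers; the intercept at 0 coincides
# with the first data point, so effectively no mitigation happens.
zne_bug = np.polyval(np.polyfit(fold_numbers, measured, 1), 0.0)    # -> 0.95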
Showing 4 changed files with 20 additions and 17 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -298,6 +298,10 @@
 
 <h3>Bug fixes</h3>
 
+* Fix a bug in `catalyst.mitigate_with_zne` that would lead
+  to incorrectly extrapolated results.
+  [(#1213)](https://github.com/PennyLaneAI/catalyst/pull/1213)
+
 * Fix a bug preventing the target of `qml.adjoint` and `qml.ctrl` calls from being transformed by
   AutoGraph.
   [(#1212)](https://github.com/PennyLaneAI/catalyst/pull/1212)
28 changes: 14 additions & 14 deletions frontend/catalyst/api_extensions/error_mitigation.py
@@ -143,9 +143,7 @@ def workflow(weights, s):
     if not _is_odd_positive(scale_factors):
         raise ValueError("The scale factors must be positive odd integers: {scale_factors}")
 
-    num_folds = jnp.array([jnp.floor((s - 1) / 2) for s in scale_factors], dtype=int)
-
-    return ZNECallable(fn, num_folds, extrapolate, folding)
+    return ZNECallable(fn, scale_factors, extrapolate, folding)
 
 
 ## IMPL ##
@@ -164,14 +162,14 @@ class ZNECallable(CatalystCallable):
     def __init__(
         self,
         fn: Callable,
-        num_folds: jnp.ndarray,
+        scale_factors: Sequence[int],
         extrapolate: Callable[[Sequence[float], Sequence[float]], float],
         folding: str,
     ):
         functools.update_wrapper(self, fn)
         self.fn = fn
         self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}"
-        self.num_folds = num_folds
+        self.scale_factors = scale_factors
         self.extrapolate = extrapolate
         self.folding = folding
 
@@ -209,16 +207,18 @@ def __call__(self, *args, **kwargs):
             callable_fn
         ), "expected callable set as param on the first operation in zne target"
 
-        results = zne_p.bind(
-            *args_data, self.num_folds, folding=folding, jaxpr=jaxpr, fn=callable_fn
+        fold_numbers = (jnp.asarray(self.scale_factors, dtype=int) - 1) // 2
+        fold_results = zne_p.bind(
+            *args_data, fold_numbers, folding=folding, jaxpr=jaxpr, fn=callable_fn
         )
-        float_num_folds = jnp.array(self.num_folds, dtype=float)
-        results = self.extrapolate(float_num_folds, results[0])
-        # Single measurement
-        if results.shape == ():
-            return results
-        # Multiple measurements
-        return tuple(res for res in results)
+
+        scale_factors = jnp.asarray(self.scale_factors, dtype=float)
+        zne_results = self.extrapolate(scale_factors, fold_results)
+
+        # if multiple measurement processes, split array back into tuple
+        if len(zne_results.shape):
+            zne_results = tuple(zne_results)
+        return zne_results
 
 
 def polynomial_extrapolation(degree):
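
For context, a hedged usage sketch of the public API touched by this file (the device, circuit, and scale-factor values are illustrative; the keyword name follows the `scale_factors` parameter visible in the diff above):

import pennylane as qml
from catalyst import qjit, mitigate_with_zne

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(dev)
def circuit():
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

@qjit
def mitigated_circuit():
    # Scale factors must be positive odd integers; with this fix the
    # extrapolation is fitted against [1, 3, 5] and evaluated at 0.
    return mitigate_with_zne(circuit, scale_factors=[1, 3, 5])()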
3 changes: 1 addition & 2 deletions frontend/catalyst/jax_primitives.py
@@ -238,7 +238,6 @@ class Folding(Enum):
 ##############
 
 zne_p = core.Primitive("zne")
-zne_p.multiple_results = True
 qdevice_p = core.Primitive("qdevice")
 qdevice_p.multiple_results = True
 qalloc_p = core.Primitive("qalloc")
@@ -1053,7 +1052,7 @@ def _zne_abstract_eval(*args, folding, jaxpr, fn): # pylint: disable=unused-arg
     shape = list(args[-1].shape)
     if len(jaxpr.out_avals) > 1:
         shape.append(len(jaxpr.out_avals))
-    return [core.ShapedArray(shape, jaxpr.out_avals[0].dtype)]
+    return core.ShapedArray(shape, jaxpr.out_avals[0].dtype)
 
 
 def _folding_attribute(ctx, folding):
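
The jax_primitives.py change matches the usual JAX convention, sketched below with a toy primitive (names are illustrative, not Catalyst code): when `multiple_results` is left at its default of `False`, the abstract evaluation rule returns a single `ShapedArray` rather than a list, and `bind` yields a single value instead of a one-element sequence.

from jax import core  # same import style as jax_primitives.py
import numpy as np

toy_p = core.Primitive("toy")      # multiple_results defaults to False

@toy_p.def_abstract_eval
def _toy_abstract_eval(x):
    # Single-result primitive: return one ShapedArray, not [ShapedArray].
    return core.ShapedArray(x.shape, x.dtype)

@toy_p.def_impl
def _toy_impl(x):
    return np.asarray(x) * 2

print(toy_p.bind(np.arange(3.0)))  # [0. 2. 4.]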
2 changes: 1 addition & 1 deletion frontend/test/lit/test_mitigation.py
@@ -37,7 +37,7 @@ def circuit():
 
 
 # CHECK: func.func public @jit_mcm_method_with_zne() -> tensor<f64>
-# CHECK: mitigation.zne @one_shot_wrapper(%c_0) folding( global) numFolds(%6 : tensor<2xi64>) : (tensor<5xi1>) -> tensor<2xf64>
+# CHECK: mitigation.zne @one_shot_wrapper(%c) folding( global) numFolds(%2 : tensor<2xi64>) : (tensor<5xi1>) -> tensor<2xf64>
 
 # CHECK: func.func private @one_shot_wrapper(%arg0: tensor<5xi1>) -> tensor<f64>
 # CHECK: catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<f64>
