Skip to content

Commit

Permalink
fix: Expression arithmetic does not error when a numpy type is on the…
Browse files Browse the repository at this point in the history
… left hand side. (#1769)

* fix: Expression arithmetic does not error when a numpy type is on the
left hand side.

* add comment
  • Loading branch information
MarquessV authored Apr 16, 2024
1 parent 9b492bc commit 6e9d944
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
8 changes: 5 additions & 3 deletions pyquil/quilatom.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,11 @@ def __array__(self, dtype: Optional[np.dtype] = None) -> np.ndarray:
return np.asarray(self._evaluate(), dtype=dtype)
raise ValueError
except ValueError:
# Note: The `None` here is a placeholder for the expression in the numpy array.
# The expression instance will still be accessible in the array.
return np.array(None, dtype=object)
# np.asarray(self, ...) would cause an infinite recursion error, so we build the array with a
# placeholder value, then replace it with self after.
array = np.asarray(None, dtype=object)
array.flat[0] = self
return array


ParameterSubstitutionsMapDesignator = Mapping[Union["Parameter", "MemoryReference"], ExpressionValueDesignator]
Expand Down
10 changes: 9 additions & 1 deletion test/unit/test_quilatom.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Sequence, Union

import pytest
from syrupy.assertion import SnapshotAssertion
import numpy as np

from pyquil.quilatom import FormalArgument, Frame, Qubit, Label, LabelPlaceholder, QubitPlaceholder
from pyquil.quilatom import Add, FormalArgument, Frame, Qubit, Label, LabelPlaceholder, QubitPlaceholder, Parameter


@pytest.mark.parametrize(
Expand Down Expand Up @@ -69,3 +71,9 @@ def test_qubit_placeholder():
register[0].out()

assert register[0] != register[1]


def test_arithmetic_with_numpy():
x = Parameter("x")
expression = np.float_(1.0) + x
assert expression == Add(np.float_(1.0), x)

0 comments on commit 6e9d944

Please sign in to comment.