From 6e9d944bcd481f756586f91cc647c852e16394d5 Mon Sep 17 00:00:00 2001 From: Marquess Valdez Date: Tue, 16 Apr 2024 11:27:45 -0700 Subject: [PATCH] fix: Expression arithmetic does not error when a numpy type is on the left hand side. (#1769) * fix: Expression arithmetic does not error when a numpy type is on the left hand side. * add comment --- pyquil/quilatom.py | 8 +++++--- test/unit/test_quilatom.py | 10 +++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pyquil/quilatom.py b/pyquil/quilatom.py index d18c34ec2..4b93cff1b 100644 --- a/pyquil/quilatom.py +++ b/pyquil/quilatom.py @@ -578,9 +578,11 @@ def __array__(self, dtype: Optional[np.dtype] = None) -> np.ndarray: return np.asarray(self._evaluate(), dtype=dtype) raise ValueError except ValueError: - # Note: The `None` here is a placeholder for the expression in the numpy array. - # The expression instance will still be accessible in the array. - return np.array(None, dtype=object) + # np.asarray(self, ...) would cause an infinite recursion error, so we build the array with a + # placeholder value, then replace it with self after. + array = np.asarray(None, dtype=object) + array.flat[0] = self + return array ParameterSubstitutionsMapDesignator = Mapping[Union["Parameter", "MemoryReference"], ExpressionValueDesignator] diff --git a/test/unit/test_quilatom.py b/test/unit/test_quilatom.py index 82883681a..231ac729d 100644 --- a/test/unit/test_quilatom.py +++ b/test/unit/test_quilatom.py @@ -1,8 +1,10 @@ from typing import Sequence, Union + import pytest from syrupy.assertion import SnapshotAssertion +import numpy as np -from pyquil.quilatom import FormalArgument, Frame, Qubit, Label, LabelPlaceholder, QubitPlaceholder +from pyquil.quilatom import Add, FormalArgument, Frame, Qubit, Label, LabelPlaceholder, QubitPlaceholder, Parameter @pytest.mark.parametrize( @@ -69,3 +71,9 @@ def test_qubit_placeholder(): register[0].out() assert register[0] != register[1] + + +def test_arithmetic_with_numpy(): + x = Parameter("x") + expression = np.float_(1.0) + x + assert expression == Add(np.float_(1.0), x)