From 6036ad755eef9fc849b0c0934cc464fd8893233d Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 4 Oct 2024 19:31:52 +0200 Subject: [PATCH] Add QGF data type for Galois Fields (#1433) * Add QGF data type for Galois Fields * Update requirements * More docstrings * debugging ci * debugging ci * debugging ci - no parallel nb * roll back debugging changes * Just set it in an env variable * More tests, change str * Fix test * Address nits --------- Co-authored-by: Matthew Harrigan --- .github/workflows/ci.yaml | 2 + dev_tools/requirements/deps/runtime.txt | 1 + dev_tools/requirements/envs/dev.env.txt | 33 +++-- dev_tools/requirements/envs/docs.env.txt | 17 ++- dev_tools/requirements/envs/format.env.txt | 17 ++- dev_tools/requirements/envs/pip-tools.env.txt | 4 +- dev_tools/requirements/envs/pylint.env.txt | 19 ++- dev_tools/requirements/envs/pytest.env.txt | 17 ++- dev_tools/requirements/envs/runtime.env.txt | 17 ++- qualtran/__init__.py | 1 + qualtran/_infra/data_types.py | 133 +++++++++++++++++- qualtran/_infra/data_types_test.py | 38 ++++- qualtran/simulation/classical_sim.py | 15 +- qualtran/simulation/classical_sim_test.py | 12 ++ 14 files changed, 275 insertions(+), 51 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7aa99955f..9740e2532 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -65,6 +65,8 @@ jobs: pip install --no-deps -e . - run: | python dev_tools/execute-notebooks.py + env: + NUMBA_NUM_THREADS: 4 format: runs-on: ubuntu-latest diff --git a/dev_tools/requirements/deps/runtime.txt b/dev_tools/requirements/deps/runtime.txt index e1311b35b..f5bb0da86 100644 --- a/dev_tools/requirements/deps/runtime.txt +++ b/dev_tools/requirements/deps/runtime.txt @@ -6,6 +6,7 @@ numpy sympy cirq-core==1.4 fxpmath +galois # qualtran/testing.py nbconvert diff --git a/dev_tools/requirements/envs/dev.env.txt b/dev_tools/requirements/envs/dev.env.txt index 109d11252..60b9216e4 100644 --- a/dev_tools/requirements/envs/dev.env.txt +++ b/dev_tools/requirements/envs/dev.env.txt @@ -123,7 +123,7 @@ defusedxml==0.7.1 # via nbconvert deprecation==2.1.0 # via openfermion -dill==0.3.8 +dill==0.3.9 # via pylint distlib==0.3.8 # via virtualenv @@ -161,13 +161,15 @@ fqdn==1.5.1 # via jsonschema fxpmath==0.4.9 # via -r deps/runtime.txt +galois==0.4.2 + # via -r deps/runtime.txt graphviz==0.20.3 # via qref greenlet==3.1.1 # via sqlalchemy -grpcio==1.66.1 +grpcio==1.66.2 # via grpcio-tools -grpcio-tools==1.66.1 +grpcio-tools==1.66.2 # via -r deps/packaging.txt h11==0.14.0 # via httpcore @@ -175,7 +177,7 @@ h5py==3.12.1 # via # openfermion # pyscf -httpcore==1.0.5 +httpcore==1.0.6 # via httpx httpx==0.27.2 # via jupyterlab @@ -201,7 +203,7 @@ ipykernel==6.29.5 # -r deps/pytest.txt # jupyterlab # myst-nb -ipython==8.27.0 +ipython==8.28.0 # via # -r deps/runtime.txt # ipykernel @@ -224,7 +226,7 @@ jaraco-classes==3.4.0 # via keyring jaraco-context==6.0.1 # via keyring -jaraco-functools==4.0.2 +jaraco-functools==4.1.0 # via keyring jax==0.4.33 # via openfermion @@ -395,7 +397,9 @@ notebook-shim==0.2.4 # jupyterlab # notebook numba==0.60.0 - # via quimb + # via + # galois + # quimb numpy==1.26.4 # via # -r deps/runtime.txt @@ -403,6 +407,7 @@ numpy==1.26.4 # cirq-core # contourpy # fxpmath + # galois # h5py # jax # jaxlib @@ -517,7 +522,7 @@ pyparsing==3.1.4 # bartiq # matplotlib # pydot -pyproject-hooks==1.1.0 +pyproject-hooks==1.2.0 # via # build # pip-tools @@ -562,9 +567,9 @@ qref==0.7.0 # via # -r deps/runtime.txt # bartiq -qsharp==1.8.0 +qsharp==1.9.0 # via -r deps/runtime.txt -qsharp-widgets==1.8.0 +qsharp-widgets==1.9.0 # via -r deps/runtime.txt quimb==1.8.4 # via -r deps/runtime.txt @@ -597,7 +602,7 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rich==13.8.1 +rich==13.9.1 # via twine rpds-py==0.20.0 # via @@ -677,7 +682,7 @@ terminado==0.18.1 # jupyter-server-terminals tinycss2==1.3.0 # via nbconvert -tomli==2.0.1 +tomli==2.0.2 # via # black # build @@ -736,12 +741,14 @@ typing-extensions==4.12.2 # black # cirq-core # dash + # galois # ipython # mypy # myst-nb # pydantic # pydantic-core # pydata-sphinx-theme + # rich # sqlalchemy tzdata==2024.2 # via pandas @@ -751,7 +758,7 @@ urllib3==2.2.3 # via # requests # twine -virtualenv==20.26.5 +virtualenv==20.26.6 # via -r deps/packaging.txt wcwidth==0.2.13 # via prompt-toolkit diff --git a/dev_tools/requirements/envs/docs.env.txt b/dev_tools/requirements/envs/docs.env.txt index 5296c6f19..2468fd901 100644 --- a/dev_tools/requirements/envs/docs.env.txt +++ b/dev_tools/requirements/envs/docs.env.txt @@ -204,6 +204,10 @@ fxpmath==0.4.9 # via # -c envs/dev.env.txt # -r deps/runtime.txt +galois==0.4.2 + # via + # -c envs/dev.env.txt + # -r deps/runtime.txt graphviz==0.20.3 # via # -c envs/dev.env.txt @@ -216,7 +220,7 @@ h11==0.14.0 # via # -c envs/dev.env.txt # httpcore -httpcore==1.0.5 +httpcore==1.0.6 # via # -c envs/dev.env.txt # httpx @@ -246,7 +250,7 @@ ipykernel==6.29.5 # -c envs/dev.env.txt # jupyterlab # myst-nb -ipython==8.27.0 +ipython==8.28.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -455,6 +459,7 @@ notebook-shim==0.2.4 numba==0.60.0 # via # -c envs/dev.env.txt + # galois # quimb numpy==1.26.4 # via @@ -463,6 +468,7 @@ numpy==1.26.4 # cirq-core # contourpy # fxpmath + # galois # matplotlib # numba # pandas @@ -613,11 +619,11 @@ qref==0.7.0 # -c envs/dev.env.txt # -r deps/runtime.txt # bartiq -qsharp==1.8.0 +qsharp==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt -qsharp-widgets==1.8.0 +qsharp-widgets==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -756,7 +762,7 @@ tinycss2==1.3.0 # via # -c envs/dev.env.txt # nbconvert -tomli==2.0.1 +tomli==2.0.2 # via # -c envs/dev.env.txt # jupyterlab @@ -808,6 +814,7 @@ typing-extensions==4.12.2 # async-lru # cirq-core # dash + # galois # ipython # myst-nb # pydantic diff --git a/dev_tools/requirements/envs/format.env.txt b/dev_tools/requirements/envs/format.env.txt index f8bb281d1..14fdefc6a 100644 --- a/dev_tools/requirements/envs/format.env.txt +++ b/dev_tools/requirements/envs/format.env.txt @@ -190,6 +190,10 @@ fxpmath==0.4.9 # via # -c envs/dev.env.txt # -r deps/runtime.txt +galois==0.4.2 + # via + # -c envs/dev.env.txt + # -r deps/runtime.txt graphviz==0.20.3 # via # -c envs/dev.env.txt @@ -198,7 +202,7 @@ h11==0.14.0 # via # -c envs/dev.env.txt # httpcore -httpcore==1.0.5 +httpcore==1.0.6 # via # -c envs/dev.env.txt # httpx @@ -221,7 +225,7 @@ ipykernel==6.29.5 # via # -c envs/dev.env.txt # jupyterlab -ipython==8.27.0 +ipython==8.28.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -401,6 +405,7 @@ notebook-shim==0.2.4 numba==0.60.0 # via # -c envs/dev.env.txt + # galois # quimb numpy==1.26.4 # via @@ -409,6 +414,7 @@ numpy==1.26.4 # cirq-core # contourpy # fxpmath + # galois # matplotlib # numba # pandas @@ -551,11 +557,11 @@ qref==0.7.0 # -c envs/dev.env.txt # -r deps/runtime.txt # bartiq -qsharp==1.8.0 +qsharp==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt -qsharp-widgets==1.8.0 +qsharp-widgets==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -646,7 +652,7 @@ tinycss2==1.3.0 # via # -c envs/dev.env.txt # nbconvert -tomli==2.0.1 +tomli==2.0.2 # via # -c envs/dev.env.txt # black @@ -700,6 +706,7 @@ typing-extensions==4.12.2 # black # cirq-core # dash + # galois # ipython # pydantic # pydantic-core diff --git a/dev_tools/requirements/envs/pip-tools.env.txt b/dev_tools/requirements/envs/pip-tools.env.txt index 986d32203..1503cbbc1 100644 --- a/dev_tools/requirements/envs/pip-tools.env.txt +++ b/dev_tools/requirements/envs/pip-tools.env.txt @@ -20,12 +20,12 @@ pip-tools==7.4.1 # via # -c envs/dev.env.txt # -r deps/pip-tools.txt -pyproject-hooks==1.1.0 +pyproject-hooks==1.2.0 # via # -c envs/dev.env.txt # build # pip-tools -tomli==2.0.1 +tomli==2.0.2 # via # -c envs/dev.env.txt # build diff --git a/dev_tools/requirements/envs/pylint.env.txt b/dev_tools/requirements/envs/pylint.env.txt index d8f95d15c..d705b6bd9 100644 --- a/dev_tools/requirements/envs/pylint.env.txt +++ b/dev_tools/requirements/envs/pylint.env.txt @@ -170,7 +170,7 @@ deprecation==2.1.0 # via # -c envs/dev.env.txt # openfermion -dill==0.3.8 +dill==0.3.9 # via # -c envs/dev.env.txt # pylint @@ -216,6 +216,10 @@ fxpmath==0.4.9 # via # -c envs/dev.env.txt # -r deps/runtime.txt +galois==0.4.2 + # via + # -c envs/dev.env.txt + # -r deps/runtime.txt graphviz==0.20.3 # via # -c envs/dev.env.txt @@ -229,7 +233,7 @@ h5py==3.12.1 # -c envs/dev.env.txt # openfermion # pyscf -httpcore==1.0.5 +httpcore==1.0.6 # via # -c envs/dev.env.txt # httpx @@ -260,7 +264,7 @@ ipykernel==6.29.5 # via # -c envs/dev.env.txt # jupyterlab -ipython==8.27.0 +ipython==8.28.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -463,6 +467,7 @@ notebook-shim==0.2.4 numba==0.60.0 # via # -c envs/dev.env.txt + # galois # quimb numpy==1.26.4 # via @@ -472,6 +477,7 @@ numpy==1.26.4 # cirq-core # contourpy # fxpmath + # galois # h5py # jax # jaxlib @@ -649,11 +655,11 @@ qref==0.7.0 # -c envs/dev.env.txt # -r deps/runtime.txt # bartiq -qsharp==1.8.0 +qsharp==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt -qsharp-widgets==1.8.0 +qsharp-widgets==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -788,7 +794,7 @@ tinycss2==1.3.0 # via # -c envs/dev.env.txt # nbconvert -tomli==2.0.1 +tomli==2.0.2 # via # -c envs/dev.env.txt # jupyterlab @@ -847,6 +853,7 @@ typing-extensions==4.12.2 # async-lru # cirq-core # dash + # galois # ipython # pydantic # pydantic-core diff --git a/dev_tools/requirements/envs/pytest.env.txt b/dev_tools/requirements/envs/pytest.env.txt index c64afb2e8..c740dd068 100644 --- a/dev_tools/requirements/envs/pytest.env.txt +++ b/dev_tools/requirements/envs/pytest.env.txt @@ -199,6 +199,10 @@ fxpmath==0.4.9 # via # -c envs/dev.env.txt # -r deps/runtime.txt +galois==0.4.2 + # via + # -c envs/dev.env.txt + # -r deps/runtime.txt graphviz==0.20.3 # via # -c envs/dev.env.txt @@ -212,7 +216,7 @@ h5py==3.12.1 # -c envs/dev.env.txt # openfermion # pyscf -httpcore==1.0.5 +httpcore==1.0.6 # via # -c envs/dev.env.txt # httpx @@ -240,7 +244,7 @@ ipykernel==6.29.5 # -c envs/dev.env.txt # -r deps/pytest.txt # jupyterlab -ipython==8.27.0 +ipython==8.28.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -428,6 +432,7 @@ notebook-shim==0.2.4 numba==0.60.0 # via # -c envs/dev.env.txt + # galois # quimb numpy==1.26.4 # via @@ -437,6 +442,7 @@ numpy==1.26.4 # cirq-core # contourpy # fxpmath + # galois # h5py # jax # jaxlib @@ -620,11 +626,11 @@ qref==0.7.0 # -c envs/dev.env.txt # -r deps/runtime.txt # bartiq -qsharp==1.8.0 +qsharp==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt -qsharp-widgets==1.8.0 +qsharp-widgets==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -722,7 +728,7 @@ tinycss2==1.3.0 # via # -c envs/dev.env.txt # nbconvert -tomli==2.0.1 +tomli==2.0.2 # via # -c envs/dev.env.txt # coverage @@ -775,6 +781,7 @@ typing-extensions==4.12.2 # async-lru # cirq-core # dash + # galois # ipython # pydantic # pydantic-core diff --git a/dev_tools/requirements/envs/runtime.env.txt b/dev_tools/requirements/envs/runtime.env.txt index 01a1dec16..e8ccda755 100644 --- a/dev_tools/requirements/envs/runtime.env.txt +++ b/dev_tools/requirements/envs/runtime.env.txt @@ -177,6 +177,10 @@ fxpmath==0.4.9 # via # -c envs/dev.env.txt # -r deps/runtime.txt +galois==0.4.2 + # via + # -c envs/dev.env.txt + # -r deps/runtime.txt graphviz==0.20.3 # via # -c envs/dev.env.txt @@ -185,7 +189,7 @@ h11==0.14.0 # via # -c envs/dev.env.txt # httpcore -httpcore==1.0.5 +httpcore==1.0.6 # via # -c envs/dev.env.txt # httpx @@ -208,7 +212,7 @@ ipykernel==6.29.5 # via # -c envs/dev.env.txt # jupyterlab -ipython==8.27.0 +ipython==8.28.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -380,6 +384,7 @@ notebook-shim==0.2.4 numba==0.60.0 # via # -c envs/dev.env.txt + # galois # quimb numpy==1.26.4 # via @@ -388,6 +393,7 @@ numpy==1.26.4 # cirq-core # contourpy # fxpmath + # galois # matplotlib # numba # pandas @@ -524,11 +530,11 @@ qref==0.7.0 # -c envs/dev.env.txt # -r deps/runtime.txt # bartiq -qsharp==1.8.0 +qsharp==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt -qsharp-widgets==1.8.0 +qsharp-widgets==1.9.0 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -619,7 +625,7 @@ tinycss2==1.3.0 # via # -c envs/dev.env.txt # nbconvert -tomli==2.0.1 +tomli==2.0.2 # via # -c envs/dev.env.txt # jupyterlab @@ -670,6 +676,7 @@ typing-extensions==4.12.2 # async-lru # cirq-core # dash + # galois # ipython # pydantic # pydantic-core diff --git a/qualtran/__init__.py b/qualtran/__init__.py index 6f04267a3..244af6828 100644 --- a/qualtran/__init__.py +++ b/qualtran/__init__.py @@ -53,6 +53,7 @@ QUInt, BQUInt, QMontgomeryUInt, + QGF, ) # Internal imports: none diff --git a/qualtran/_infra/data_types.py b/qualtran/_infra/data_types.py index 5ea2ac1f8..e21eb209f 100644 --- a/qualtran/_infra/data_types.py +++ b/qualtran/_infra/data_types.py @@ -50,6 +50,7 @@ import abc from enum import Enum +from functools import cached_property from typing import Any, Iterable, List, Sequence, Union import attrs @@ -57,7 +58,7 @@ from fxpmath import Fxp from numpy.typing import NDArray -from qualtran.symbolics import is_symbolic, SymbolicInt +from qualtran.symbolics import bit_length, is_symbolic, SymbolicInt class QDType(metaclass=abc.ABCMeta): @@ -810,8 +811,134 @@ def assert_valid_classical_val_array( raise ValueError(f"Too-large classical values encountered in {debug_str}") +@attrs.frozen +class QGF(QDType): + r"""Galois Field type to represent elements of a finite field. + + A Finite Field or Galois Field is a field that contains finite number of elements. The order + of a finite field is the number of elements in the field, which is either a prime number or + a prime power. For every prime number $p$ and every positive integer $m$ there are fields of + order $p^m$, all of which are isomorphic. When m=1, the finite field of order p can be + constructed via integers modulo p. + + Elements of a Galois Field $GF(p^m)$ may be conveniently viewed as polynomials + $a_{0} + a_{1}x + ... + a_{m−1}x_{m−1}$, where $a_0, a_1, ..., a_{m−1} \in F(p)$. + $GF(p^m)$ addition is defined as the component-wise (polynomial) addition over F(p) and + multiplication is defined as polynomial multiplication modulo an irreducible polynomial of + degree $m$. The selection of the specific irreducible polynomial affects the representation + of the given field, but all fields of a fixed size are isomorphic. + + The data type uses the [Galois library](https://mhostetter.github.io/galois/latest/) to + perform arithmetic over Galois Fields. By default, the Conway polynomial $C_{p, m}$ is used + as the irreducible polynomial. + + Attributes: + characteristic: The characteristic $p$ of the field $GF(p^m)$. + The characteristic must be prime. + degree: The degree $m$ of the field $GF(p^{m})$. The degree must be a positive integer. + + References + [Finite Field](https://en.wikipedia.org/wiki/Finite_field) + + [Intro to Prime Fields](https://mhostetter.github.io/galois/latest/tutorials/intro-to-prime-fields/) + + [Intro to Extension Fields](https://mhostetter.github.io/galois/latest/tutorials/intro-to-extension-fields/) + """ + + characteristic: SymbolicInt + degree: SymbolicInt + + @cached_property + def order(self) -> SymbolicInt: + return self.characteristic**self.degree + + @cached_property + def bitsize(self) -> SymbolicInt: + """Bitsize of qubit register required to represent a single instance of this data type.""" + return bit_length(self.order - 1) + + @cached_property + def num_qubits(self) -> SymbolicInt: + """Number of qubits required to represent a single instance of this data type.""" + return self.bitsize + + def get_classical_domain(self) -> Iterable[Any]: + """Yields all possible classical (computational basis state) values representable + by this type.""" + yield from self.gf_type.elements + + @cached_property + def _quint_equivalent(self) -> QUInt: + return QUInt(self.num_qubits) + + @cached_property + def gf_type(self): + from galois import GF + + return GF(int(self.characteristic), int(self.degree), compile='python-calculate') + + def to_bits(self, x) -> List[int]: + """Yields individual bits corresponding to binary representation of x""" + self.assert_valid_classical_val(x) + return self._quint_equivalent.to_bits(int(x)) + + def from_bits(self, bits: Sequence[int]): + """Combine individual bits to form x""" + return self.gf_type(self._quint_equivalent.from_bits(bits)) + + def from_bits_array(self, bits_array: NDArray[np.uint8]): + """Combine individual bits to form classical values. + + Often, converting an array can be performed faster than converting each element individually. + This operation accepts any NDArray of bits such that the last dimension equals `self.bitsize`, + and the output array satisfies `output_shape = input_shape[:-1]`. + """ + return self.gf_type(self._quint_equivalent.from_bits_array(bits_array)) + + def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'): + """Raises an exception if `val` is not a valid classical value for this type. + + Args: + val: A classical value that should be in the domain of this QDType. + debug_str: Optional debugging information to use in exception messages. + """ + if not isinstance(val, self.gf_type): + raise ValueError(f"{debug_str} should be a {self.gf_type}, not {val!r}") + + def assert_valid_classical_val_array(self, val_array: NDArray[Any], debug_str: str = 'val'): + """Raises an exception if `val_array` is not a valid array of classical values + for this type. + + Often, validation on an array can be performed faster than validating each element + individually. + + Args: + val_array: A numpy array of classical values. Each value should be in the domain + of this QDType. + debug_str: Optional debugging information to use in exception messages. + """ + if np.any(val_array < 0): + raise ValueError(f"Negative classical values encountered in {debug_str}") + if np.any(val_array >= self.order): + raise ValueError(f"Too-large classical values encountered in {debug_str}") + + def is_symbolic(self) -> bool: + """Returns True if this qdtype is parameterized with symbolic objects.""" + return is_symbolic(self.characteristic, self.order) + + def iteration_length_or_zero(self) -> SymbolicInt: + """Safe version of iteration length. + + Returns the iteration_length if the type has it or else zero. + """ + return self.order + + def __str__(self): + return f'QGF({self.characteristic}**{self.degree})' + + QAnyInt = (QInt, QUInt, BQUInt, QMontgomeryUInt) -QAnyUInt = (QUInt, BQUInt, QMontgomeryUInt) +QAnyUInt = (QUInt, BQUInt, QMontgomeryUInt, QGF) class QDTypeCheckingSeverity(Enum): @@ -827,7 +954,7 @@ class QDTypeCheckingSeverity(Enum): """Strictly enforce type checking between registers. Only single bit conversions are allowed.""" -def _check_uint_fxp_consistent(a: Union[QUInt, BQUInt, QMontgomeryUInt], b: QFxp) -> bool: +def _check_uint_fxp_consistent(a: Union[QUInt, BQUInt, QMontgomeryUInt, QGF], b: QFxp) -> bool: """A uint / qfxp is consistent with a whole or totally fractional unsigned QFxp.""" if b.signed: return False diff --git a/qualtran/_infra/data_types_test.py b/qualtran/_infra/data_types_test.py index 1dc0ced33..227c8de7b 100644 --- a/qualtran/_infra/data_types_test.py +++ b/qualtran/_infra/data_types_test.py @@ -21,7 +21,7 @@ import sympy from numpy.typing import NDArray -from qualtran.symbolics import is_symbolic +from qualtran.symbolics import ceil, is_symbolic, log2 from .data_types import ( BQUInt, @@ -31,6 +31,7 @@ QBit, QDType, QFxp, + QGF, QInt, QIntOnesComp, QMontgomeryUInt, @@ -135,13 +136,23 @@ def test_qmontgomeryuint(): assert is_symbolic(QMontgomeryUInt(sympy.Symbol('x'))) +def test_qgf(): + qgf_256 = QGF(characteristic=2, degree=8) + assert str(qgf_256) == 'QGF(2**8)' + assert qgf_256.num_qubits == 8 + p, m = sympy.symbols('p, m', integer=True, positive=True) + qgf_pm = QGF(characteristic=p, degree=m) + assert qgf_pm.num_qubits == ceil(log2(p**m)) + assert is_symbolic(qgf_pm) + + @pytest.mark.parametrize('qdtype', [QBit(), QInt(4), QUInt(4), BQUInt(3, 5)]) def test_domain_and_validation(qdtype: QDType): for v in qdtype.get_classical_domain(): qdtype.assert_valid_classical_val(v) -@pytest.mark.parametrize('qdtype', [QBit(), QInt(4), QUInt(4), BQUInt(3, 5)]) +@pytest.mark.parametrize('qdtype', [QBit(), QInt(4), QUInt(4), BQUInt(3, 5), QGF(2, 8)]) def test_domain_and_validation_arr(qdtype: QDType): arr = np.array(list(qdtype.get_classical_domain())) qdtype.assert_valid_classical_val_array(arr) @@ -172,6 +183,9 @@ def test_validation_errs(): with pytest.raises(ValueError): QUInt(3).assert_valid_classical_val(-1) + with pytest.raises(ValueError): + QGF(2, 8).assert_valid_classical_val(2**8) + def test_validate_arrays(): rs = np.random.RandomState(52) @@ -233,10 +247,11 @@ def test_single_qubit_consistency(): assert check_dtypes_consistent(QAny(1), QBit()) assert check_dtypes_consistent(BQUInt(1), QBit()) assert check_dtypes_consistent(QFxp(1, 1), QBit()) + assert check_dtypes_consistent(QGF(characteristic=2, degree=1), QBit()) def assert_to_and_from_bits_array_consistent(qdtype: QDType, values: Union[Sequence[Any], NDArray]): - values = np.asarray(values) + values = np.asanyarray(values) bits_array = qdtype.to_bits_array(values) # individual values @@ -263,6 +278,23 @@ def test_qint_to_and_from_bits(): assert_to_and_from_bits_array_consistent(qint4, range(-8, 8)) +def test_qgf_to_and_from_bits(): + from galois import GF + + qgf_256 = QGF(2, 8) + gf256 = GF(2**8) + assert [*qgf_256.get_classical_domain()] == [*range(256)] + a, b = qgf_256.to_bits(gf256(21)), qgf_256.to_bits(gf256(22)) + c = qgf_256.from_bits(list(np.bitwise_xor(a, b))) + assert c == gf256(21) + gf256(22) + for x in gf256.elements: + assert x == gf256.Vector(qgf_256.to_bits(x)) + + with pytest.raises(ValueError): + qgf_256.to_bits(21) + assert_to_and_from_bits_array_consistent(qgf_256, gf256([*range(256)])) + + def test_quint_to_and_from_bits(): quint4 = QUInt(4) assert [*quint4.get_classical_domain()] == [*range(0, 16)] diff --git a/qualtran/simulation/classical_sim.py b/qualtran/simulation/classical_sim.py index 48a326461..8af0af038 100644 --- a/qualtran/simulation/classical_sim.py +++ b/qualtran/simulation/classical_sim.py @@ -81,6 +81,15 @@ def _numpy_dtype_from_qdtype(dtype: 'QDType') -> Type: return object +def _empty_ndarray_from_reg(reg: Register) -> np.ndarray: + from qualtran._infra.data_types import QGF + + if isinstance(reg.dtype, QGF): + return reg.dtype.gf_type.Zeros(reg.shape) + + return np.empty(reg.shape, dtype=_numpy_dtype_from_qdtype(reg.dtype)) + + def _get_in_vals( binst: Union[DanglingT, BloqInstance], reg: Register, soq_assign: Dict[Soquet, ClassicalValT] ) -> ClassicalValT: @@ -88,9 +97,7 @@ def _get_in_vals( if not reg.shape: return soq_assign[Soquet(binst, reg)] - dtype: Type = _numpy_dtype_from_qdtype(reg.dtype) - - arg = np.empty(reg.shape, dtype=dtype) + arg = _empty_ndarray_from_reg(reg) for idx in reg.all_idxs(): soq = Soquet(binst, reg, idx=idx) arg[idx] = soq_assign[soq] @@ -121,7 +128,7 @@ def _update_assign_from_vals( if reg.shape: # `val` is an array - val = np.asarray(val) + val = np.asanyarray(val) if val.shape != reg.shape: raise ValueError( f"Incorrect shape {val.shape} received for {debug_str}. " f"Want {reg.shape}." diff --git a/qualtran/simulation/classical_sim_test.py b/qualtran/simulation/classical_sim_test.py index 947bc7887..6073a50e4 100644 --- a/qualtran/simulation/classical_sim_test.py +++ b/qualtran/simulation/classical_sim_test.py @@ -28,6 +28,7 @@ QBit, QDType, QFxp, + QGF, QInt, QIntOnesComp, QUInt, @@ -169,6 +170,7 @@ def test_notebook(): class TestMultiDimensionalReg(Bloq): dtype: QDType n: int + dtypes_to_assert: tuple[type, ...] = (int, np.integer) @property def signature(self): @@ -180,6 +182,7 @@ def signature(self): ) def on_classical_vals(self, x): + assert all(isinstance(y, self.dtypes_to_assert) for y in x.reshape(-1)) return {'y': x} @@ -197,3 +200,12 @@ def test_multidimensional_classical_sim_for_large_int(): x = [2**88 - 1, 2**12 - 1, 2**54 - 1, 1 - 2**72, 1 - 2**62] bloq = TestMultiDimensionalReg(dtype, len(x)) np.testing.assert_equal(bloq.call_classically(x=np.array(x))[0], x) + + +def test_multidimensional_classical_sim_for_gqf(): + dtype = QGF(2, 2) + x = dtype.gf_type.elements + bloq = TestMultiDimensionalReg(dtype, len(x), (dtype.gf_type,)) + y = bloq.call_classically(x=x)[0] + assert isinstance(y, dtype.gf_type) + np.testing.assert_equal(y, x)