Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sparse state prep: allow user to pick target bitsize if needed #1430

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions qualtran/bloqs/arithmetic/permutation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
"\n",
"#### Parameters\n",
" - `N`: the total size the permutation acts on.\n",
" - `cycles`: a sequence of permutation cycles that form the permutation. \n",
" - `cycles`: a sequence of permutation cycles that form the permutation.\n",
" - `bitsize`: number of bits to store the indices, defaults to $\\ceil(\\log_2(N))$. \n",
"\n",
"#### Registers\n",
" - `x`: integer register storing a value in [0, ..., N - 1] \n",
Expand Down Expand Up @@ -235,7 +236,8 @@
"\n",
"#### Parameters\n",
" - `N`: the total size the permutation acts on.\n",
" - `cycle`: the permutation cycle to apply. \n",
" - `cycle`: the permutation cycle to apply.\n",
" - `bitsize`: number of bits to store the indices, defaults to $\\ceil(\\log_2(N))$. \n",
"\n",
"#### Registers\n",
" - `x`: integer register storing a value in [0, ..., N - 1] \n",
Expand Down
16 changes: 10 additions & 6 deletions qualtran/bloqs/arithmetic/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class PermutationCycle(Bloq):
Args:
N: the total size the permutation acts on.
cycle: the permutation cycle to apply.
bitsize: number of bits to store the indices, defaults to $\ceil(\log_2(N))$.

Registers:
x: integer register storing a value in [0, ..., N - 1]
Expand All @@ -95,13 +96,14 @@ class PermutationCycle(Bloq):

N: SymbolicInt
cycle: Union[tuple[int, ...], Shaped] = field(converter=_convert_cycle)
bitsize: SymbolicInt = field()

@cached_property
def signature(self) -> Signature:
return Signature.build_from_dtypes(x=BQUInt(self.bitsize, self.N))

@cached_property
def bitsize(self):
@bitsize.default
def _default_bitsize(self):
return bit_length(self.N - 1)

def build_composite_bloq(self, bb: 'BloqBuilder', x: 'SoquetT') -> dict[str, 'SoquetT']:
Expand Down Expand Up @@ -194,6 +196,7 @@ class Permutation(Bloq):
Args:
N: the total size the permutation acts on.
cycles: a sequence of permutation cycles that form the permutation.
bitsize: number of bits to store the indices, defaults to $\ceil(\log_2(N))$.

Registers:
x: integer register storing a value in [0, ..., N - 1]
Expand All @@ -205,13 +208,14 @@ class Permutation(Bloq):

N: SymbolicInt
cycles: Union[tuple[SymbolicCycleT, ...], Shaped] = field(converter=_convert_cycles)
bitsize: SymbolicInt = field()

@cached_property
def signature(self) -> Signature:
return Signature.build_from_dtypes(x=BQUInt(self.bitsize, self.N))

@cached_property
def bitsize(self):
@bitsize.default
def _default_bitsize(self):
return bit_length(self.N - 1)

def is_symbolic(self):
Expand Down Expand Up @@ -265,7 +269,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'Soq
raise DecomposeTypeError(f"cannot decompose symbolic {self}")

for cycle in self.cycles:
x = bb.add(PermutationCycle(self.N, cycle), x=x)
x = bb.add(PermutationCycle(self.N, cycle, self.bitsize), x=x)

return {'x': x}

Expand All @@ -275,7 +279,7 @@ def build_call_graph(
if is_symbolic(self.cycles):
# worst case cost: single cycle of length N
cycle = Shaped((self.N,))
return {PermutationCycle(self.N, cycle): 1}
return {PermutationCycle(self.N, cycle, self.bitsize): 1}

return super().build_call_graph(ssa)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Sequence, TYPE_CHECKING, Union

import attrs
import numpy as np
import sympy
from attrs import field, frozen
Expand All @@ -25,6 +26,7 @@
_to_tuple_or_has_length,
StatePreparationViaRotations,
)
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.symbolics import bit_length, HasLength, is_symbolic, slen, SymbolicInt

if TYPE_CHECKING:
Expand Down Expand Up @@ -58,21 +60,24 @@ class SparseStatePreparationViaRotations(Bloq):
nonzero_coeffs: Union[tuple[complex, ...], HasLength] = field(converter=_to_tuple_or_has_length)
N: SymbolicInt
phase_bitsize: SymbolicInt
target_bitsize: SymbolicInt = field()
anurudhp marked this conversation as resolved.
Show resolved Hide resolved

def __attrs_post_init__(self):
n_idx = slen(self.sparse_indices)
n_coeff = slen(self.nonzero_coeffs)
if not is_symbolic(n_idx, n_coeff) and n_idx != n_coeff:
raise ValueError(f"Number of indices {n_idx} must equal number of coeffs {n_coeff}")
if not is_symbolic(self.target_bitsize, self.N):
assert 2**self.target_bitsize >= self.N

@property
def signature(self) -> Signature:
return Signature.build_from_dtypes(
target_state=QUInt(self.target_bitsize), phase_gradient=QAny(self.phase_bitsize)
)

@property
def target_bitsize(self) -> SymbolicInt:
@target_bitsize.default
def _default_target_bitsize(self) -> SymbolicInt:
return bit_length(self.N - 1)

@property
Expand Down Expand Up @@ -160,6 +165,7 @@ def _dense_stateprep_bloq(self) -> StatePreparationViaRotations:
dense_coeffs_padded = np.pad(
list(self.nonzero_coeffs), (0, 2**self.dense_bitsize - len(self.nonzero_coeffs))
)
dense_coeffs_padded = dense_coeffs_padded / np.linalg.norm(dense_coeffs_padded)
return StatePreparationViaRotations(tuple(dense_coeffs_padded.tolist()), self.phase_bitsize)

@property
Expand All @@ -170,9 +176,10 @@ def _basis_permutation_bloq(self) -> Permutation:

assert isinstance(self.sparse_indices, tuple)

return Permutation.from_partial_permutation_map(
permute_bloq = Permutation.from_partial_permutation_map(
self.N, dict(enumerate(self.sparse_indices))
)
return attrs.evolve(permute_bloq, bitsize=self.target_bitsize)

def build_composite_bloq(
self, bb: 'BloqBuilder', target_state: 'SoquetT', phase_gradient: 'SoquetT'
Expand All @@ -198,10 +205,24 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {self._dense_stateprep_bloq: 1, self._basis_permutation_bloq: 1}


@bloq_example
@bloq_example(generalizer=ignore_split_join)
def _sparse_state_prep_via_rotations() -> SparseStatePreparationViaRotations:
sparse_state_prep_via_rotations = SparseStatePreparationViaRotations.from_sparse_array(
[0.70914953, 0, 0, 0, 0.46943701, 0, 0.2297245, 0, 0, 0.32960471, 0, 0, 0.33959273, 0, 0],
phase_bitsize=2,
)
return sparse_state_prep_via_rotations


@bloq_example(generalizer=ignore_split_join)
def _sparse_state_prep_via_rotations_with_large_target_bitsize() -> (
SparseStatePreparationViaRotations
):
sparse_state_prep_via_rotations = SparseStatePreparationViaRotations.from_sparse_array(
[0.70914953, 0, 0, 0, 0.46943701, 0, 0.2297245, 0, 0, 0.32960471, 0, 0, 0.33959273, 0, 0],
phase_bitsize=2,
)
sparse_state_prep_via_rotations_with_large_target_bitsize = attrs.evolve(
sparse_state_prep_via_rotations, target_bitsize=6
)
return sparse_state_prep_via_rotations_with_large_target_bitsize
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import attrs
import numpy as np
import pytest
from numpy.typing import NDArray
Expand All @@ -20,12 +23,17 @@
from qualtran.bloqs.rotations import PhaseGradientState
from qualtran.bloqs.state_preparation.sparse_state_preparation_via_rotations import (
_sparse_state_prep_via_rotations,
_sparse_state_prep_via_rotations_with_large_target_bitsize,
SparseStatePreparationViaRotations,
)


def test_examples(bloq_autotester):
bloq_autotester(_sparse_state_prep_via_rotations)
@pytest.mark.parametrize(
"bloq_ex",
[_sparse_state_prep_via_rotations, _sparse_state_prep_via_rotations_with_large_target_bitsize],
)
def test_examples(bloq_autotester, bloq_ex):
bloq_autotester(bloq_ex)
Comment on lines +31 to +36
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a slow test that asserts the simulation is correct? The bloq example auto tests do not cover correctness of decomposition, AFAIR ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added simulation tests (at the end)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks



def get_prepared_state_vector(bloq: SparseStatePreparationViaRotations) -> NDArray[np.complex128]:
Expand All @@ -40,7 +48,8 @@ def get_prepared_state_vector(bloq: SparseStatePreparationViaRotations) -> NDArr


@pytest.mark.slow
def test_prepared_state():
@pytest.mark.parametrize("target_bitsize", [None, 4, 6])
def test_prepared_state(target_bitsize: Optional[int]):
expected_state = np.array(
[
(-0.42677669529663675 - 0.1767766952966366j),
Expand All @@ -63,6 +72,9 @@ def test_prepared_state():
N = len(expected_state)

bloq = SparseStatePreparationViaRotations.from_sparse_array(expected_state, phase_bitsize=3)
if target_bitsize is not None:
bloq = attrs.evolve(bloq, target_bitsize=target_bitsize)

actual_state = get_prepared_state_vector(bloq)
np.testing.assert_allclose(np.linalg.norm(actual_state), 1)
np.testing.assert_allclose(actual_state[:N], expected_state)
Expand Down
1 change: 1 addition & 0 deletions qualtran/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def assert_bloq_example_serializes_for_pytest(bloq_ex: BloqExample):
'state_prep_via_rotation_symb', # cannot serialize HasLength
'state_prep_via_rotation_symb_phasegrad', # cannot serialize Shaped
'sparse_state_prep_via_rotations', # cannot serialize Permutation
'sparse_state_prep_via_rotations_with_large_target_bitsize', # setting an array element with a sequence.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going on here? Is it because we can't serialize permutation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I think so

'explicit_matrix_block_encoding', # cannot serialize AutoPartition
'symmetric_banded_matrix_block_encoding', # cannot serialize AutoPartition
'chebyshev_poly_even',
Expand Down
Loading