diff --git a/qualtran/_infra/gate_with_registers.py b/qualtran/_infra/gate_with_registers.py index e9056cdd3..219f3e995 100644 --- a/qualtran/_infra/gate_with_registers.py +++ b/qualtran/_infra/gate_with_registers.py @@ -373,8 +373,9 @@ def __pow__(self, power: int) -> 'GateWithRegisters': return Power(bloq, abs(power)) raise NotImplementedError(f"{self} does not implemented __pow__ for {power=}.") + @classmethod def _get_ctrl_spec( - self, + cls, num_controls: Union[Optional[int], 'CtrlSpec'] = None, control_values=None, control_qid_shape: Optional[Tuple[int, ...]] = None, @@ -498,7 +499,7 @@ def controlled( Returns: A controlled version of the bloq. """ - ctrl_spec = self._get_ctrl_spec( + ctrl_spec = GateWithRegisters._get_ctrl_spec( num_controls, control_values, control_qid_shape, ctrl_spec=ctrl_spec ) controlled_bloq, _ = self.get_ctrl_system(ctrl_spec=ctrl_spec) diff --git a/qualtran/bloqs/mcmt/controlled_via_and.ipynb b/qualtran/bloqs/mcmt/controlled_via_and.ipynb index 1785c5512..825c66f74 100644 --- a/qualtran/bloqs/mcmt/controlled_via_and.ipynb +++ b/qualtran/bloqs/mcmt/controlled_via_and.ipynb @@ -153,6 +153,36 @@ "show_call_graph(controlled_via_and_ints_g)\n", "show_counts_sigma(controlled_via_and_ints_sigma)" ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## Nested Controls\n", + "Calling `controlled` on a `ControlledViaAnd` returns another `ControlledViaAnd` by combining the existing and new controls into a single control specification." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "nested_ctrl_bloq = controlled_via_and_qbits.controlled(CtrlSpec(cvs=[1, 1]))\n", + "show_bloqs([nested_ctrl_bloq])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "show_call_graph(nested_ctrl_bloq)" + ] } ], "metadata": { diff --git a/qualtran/bloqs/mcmt/controlled_via_and.py b/qualtran/bloqs/mcmt/controlled_via_and.py index a6778483d..b9f7389a0 100644 --- a/qualtran/bloqs/mcmt/controlled_via_and.py +++ b/qualtran/bloqs/mcmt/controlled_via_and.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import Counter from functools import cached_property -from typing import TYPE_CHECKING +from typing import Iterable, Sequence, TYPE_CHECKING import numpy as np from attrs import frozen @@ -23,7 +23,7 @@ from qualtran.bloqs.mcmt.ctrl_spec_and import CtrlSpecAnd if TYPE_CHECKING: - from qualtran import BloqBuilder, SoquetT + from qualtran import AddControlledT, BloqBuilder, SoquetT from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -126,6 +126,28 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return counts + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: + ctrl_spec_combined = CtrlSpec( + qdtypes=ctrl_spec.qdtypes + self.ctrl_spec.qdtypes, + cvs=ctrl_spec.cvs + self.ctrl_spec.cvs, + ) + ctrl_bloq = ControlledViaAnd(subbloq=self.subbloq, ctrl_spec=ctrl_spec_combined) + + def _adder( + bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT'] + ) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]: + rhs_ctrl_soqs_t = tuple(in_soqs.pop(name) for name in self.ctrl_reg_names) + all_ctrl_soqs_t = tuple([*ctrl_soqs, *rhs_ctrl_soqs_t]) + + all_ctrl_soqs_d = dict(zip(ctrl_bloq.ctrl_reg_names, all_ctrl_soqs_t)) + all_soqs = all_ctrl_soqs_d | in_soqs + all_soqs = bb.add_t(ctrl_bloq, **all_soqs) + + n_ctrl_lhs = ctrl_spec.num_ctrl_reg + return all_soqs[:n_ctrl_lhs], all_soqs[n_ctrl_lhs:] + + return ctrl_bloq, _adder + @bloq_example def _controlled_via_and_qbits() -> ControlledViaAnd: diff --git a/qualtran/bloqs/mcmt/controlled_via_and_test.py b/qualtran/bloqs/mcmt/controlled_via_and_test.py index 981d7396f..9e60e3ecc 100644 --- a/qualtran/bloqs/mcmt/controlled_via_and_test.py +++ b/qualtran/bloqs/mcmt/controlled_via_and_test.py @@ -15,12 +15,14 @@ import pytest from qualtran import Controlled, CtrlSpec, QInt, QUInt +from qualtran.bloqs.basic_gates import XGate from qualtran.bloqs.for_testing.matrix_gate import MatrixGate from qualtran.bloqs.mcmt.controlled_via_and import ( _controlled_via_and_ints, _controlled_via_and_qbits, ControlledViaAnd, ) +from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost def test_examples(bloq_autotester): @@ -40,10 +42,39 @@ def test_tensor_against_naive_controlled(ctrl_spec: CtrlSpec): rs = np.random.RandomState(42) subbloq = MatrixGate.random(2, random_state=rs) - cbloq = ControlledViaAnd(subbloq, ctrl_spec) - naive_cbloq = Controlled(subbloq, ctrl_spec) + ctrl_bloq = ControlledViaAnd(subbloq, ctrl_spec) + naive_ctrl_bloq = Controlled(subbloq, ctrl_spec) - expected_tensor = naive_cbloq.tensor_contract() - actual_tensor = cbloq.tensor_contract() + expected_tensor = naive_ctrl_bloq.tensor_contract() + actual_tensor = ctrl_bloq.tensor_contract() np.testing.assert_allclose(expected_tensor, actual_tensor) + + +def test_nested_controls(): + spec1 = CtrlSpec(QUInt(4), [2, 3]) + spec2 = CtrlSpec(QInt(4), [1, 2]) + spec = CtrlSpec((QInt(4), QUInt(4)), ([1, 2], [2, 3])) + + rs = np.random.RandomState(42) + bloq = MatrixGate.random(2, random_state=rs) + + ctrl_bloq = ControlledViaAnd(bloq, spec1).controlled(ctrl_spec=spec2) + assert ctrl_bloq == ControlledViaAnd(bloq, spec) + + +def test_nested_controlled_x(): + bloq = XGate() + + ctrl_bloq = ControlledViaAnd(bloq, CtrlSpec(cvs=[1, 1])).controlled( + ctrl_spec=CtrlSpec(cvs=[1, 1]) + ) + cost = get_cost_value(ctrl_bloq, QECGatesCost()) + + n_ands = 3 + assert cost == GateCounts(and_bloq=n_ands, clifford=n_ands + 1, measurement=n_ands) + + np.testing.assert_allclose( + ctrl_bloq.tensor_contract(), + XGate().controlled(CtrlSpec(cvs=[1, 1, 1, 1])).tensor_contract(), + ) diff --git a/qualtran/bloqs/mcmt/ctrl_spec_and.py b/qualtran/bloqs/mcmt/ctrl_spec_and.py index d22b47f4e..7bcb195c8 100644 --- a/qualtran/bloqs/mcmt/ctrl_spec_and.py +++ b/qualtran/bloqs/mcmt/ctrl_spec_and.py @@ -80,7 +80,9 @@ def __attrs_post_init__(self): if not is_symbolic(self.n_ctrl_qubits) and self.n_ctrl_qubits <= 1: raise ValueError(f"Expected at least 2 controls, got {self.n_ctrl_qubits}") for qdtype in self.ctrl_spec.qdtypes: - if not isinstance(qdtype, (QBit, QInt, QUInt, BQUInt, QIntOnesComp, QMontgomeryUInt)): + if not isinstance( + qdtype, (QBit, QAny, QInt, QUInt, BQUInt, QIntOnesComp, QMontgomeryUInt) + ): raise NotImplementedError("CtrlSpecAnd currently only supports integer types") @cached_property