From e10560cebb9b3c3005a3e560a40b712db207f301 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Wed, 13 Nov 2024 09:16:20 -0800 Subject: [PATCH] Fix bug in KaliskiStep3 and add tests for all steps (#1496) * Fix bug in KaliskiStep3 and add tests for all steps * cost --- qualtran/bloqs/arithmetic/comparison.py | 19 ++--- qualtran/bloqs/factoring/ecc/ec_add_test.py | 2 +- qualtran/bloqs/mod_arithmetic/mod_division.py | 23 +++--- .../bloqs/mod_arithmetic/mod_division_test.py | 82 ++++++++++++++++++- 4 files changed, 99 insertions(+), 27 deletions(-) diff --git a/qualtran/bloqs/arithmetic/comparison.py b/qualtran/bloqs/arithmetic/comparison.py index b28269ae4..fff30d5fc 100644 --- a/qualtran/bloqs/arithmetic/comparison.py +++ b/qualtran/bloqs/arithmetic/comparison.py @@ -1462,20 +1462,17 @@ def on_classical_vals( c: Optional['ClassicalValT'] = None, target: Optional['ClassicalValT'] = None, ) -> Dict[str, 'ClassicalValT']: + if self._op_symbol in ('>', '<='): + c_val = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) + else: + c_val = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) if self.uncompute: - assert c == add_ints( - int(a), - int(b), - num_bits=int(self.dtype.bitsize), - is_signed=isinstance(self.dtype, QInt), - ) + assert c == c_val assert target == self._classical_comparison(a, b) return {'a': a, 'b': b} - if self._op_symbol in ('>', '<='): - c = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) - else: - c = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) - return {'a': a, 'b': b, 'c': c, 'target': int(self._classical_comparison(a, b))} + assert c is None + assert target is None + return {'a': a, 'b': b, 'c': c_val, 'target': int(self._classical_comparison(a, b))} def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']: if self._op_symbol in ('>', '<='): diff --git a/qualtran/bloqs/factoring/ecc/ec_add_test.py b/qualtran/bloqs/factoring/ecc/ec_add_test.py index 37c397707..7f7439d2d 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_add_test.py @@ -418,7 +418,7 @@ def test_ec_add_symbolic_cost(): # toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n # 3-controlled toffolis in step 2. The expression is written with rationals because sympy # comparison fails with floats. - assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(391, 2) * n - 31 + assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(407, 2) * n - 31 def test_ec_add(bloq_autotester): diff --git a/qualtran/bloqs/mod_arithmetic/mod_division.py b/qualtran/bloqs/mod_arithmetic/mod_division.py index c099c7562..06f643525 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_division.py +++ b/qualtran/bloqs/mod_arithmetic/mod_division.py @@ -72,8 +72,6 @@ def signature(self) -> 'Signature': def on_classical_vals( self, v: int, m: int, f: int, is_terminal: int ) -> Dict[str, 'ClassicalValT']: - print('here') - assert False m ^= f & (v == 0) assert is_terminal == 0 is_terminal ^= m @@ -101,10 +99,10 @@ def build_composite_bloq( def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': if is_symbolic(self.bitsize): - cvs: Union[HasLength, List[int]] = HasLength(self.bitsize) + cvs: Union[HasLength, List[int]] = HasLength(self.bitsize + 1) else: - cvs = [0] * int(self.bitsize) - return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 2} + cvs = [0] * int(self.bitsize) + [1] + return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 3} @frozen @@ -197,11 +195,11 @@ def on_classical_vals( def build_composite_bloq( self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet ) -> Dict[str, 'SoquetT']: - u, v, junk, greater_than = bb.add( + u, v, junk_c, greater_than = bb.add( LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)), a=u, b=v ) - (greater_than, f, b), junk, ctrl = bb.add( + (greater_than, f, b), junk_m, ctrl = bb.add( MultiAnd(cvs=(1, 1, 0)), ctrl=(greater_than, f, b) ) @@ -209,13 +207,13 @@ def build_composite_bloq( ctrl, m = bb.add(CNOT(), ctrl=ctrl, target=m) greater_than, f, b = bb.add( - MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk, target=ctrl + MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk_m, target=ctrl ) u, v = bb.add( LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)).adjoint(), a=u, b=v, - c=junk, + c=junk_c, target=greater_than, ) return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f} @@ -391,7 +389,7 @@ def build_composite_bloq( def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { - CNOT(): 4, + CNOT(): 3, XGate(): 2, ModDbl(QMontgomeryUInt(self.bitsize), self.mod): 1, CSwapApprox(self.bitsize): 2, @@ -475,7 +473,7 @@ def on_classical_vals( of `f` and `m`. """ assert m == 0 - is_terminal = f == 1 and v == 0 + is_terminal = int(f == 1 and v == 0) if f == 0: # When `f = 0` this means that the algorithm is nearly over and that we just need to # double the value of `r`. @@ -489,7 +487,8 @@ def on_classical_vals( f = 0 r = (r << 1) % self.mod else: - m = (u % 2 == 1) & (v % 2 == 0) + m = ((u % 2 == 1) & (v % 2 == 0)) or (u % 2 == 1 and v % 2 == 1 and u > v) + m = int(m) # Kaliski iteration as described in Fig7 of https://arxiv.org/pdf/2001.09580. swap = (u % 2 == 0 and v % 2 == 1) or (u % 2 == 1 and v % 2 == 1 and u > v) if swap: diff --git a/qualtran/bloqs/mod_arithmetic/mod_division_test.py b/qualtran/bloqs/mod_arithmetic/mod_division_test.py index 093f0908f..934a26967 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_division_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_division_test.py @@ -19,6 +19,7 @@ import qualtran.testing as qlt_testing from qualtran import QMontgomeryUInt +from qualtran.bloqs.mod_arithmetic import mod_division from qualtran.bloqs.mod_arithmetic.mod_division import _kaliskimodinverse_example, KaliskiModInverse from qualtran.resource_counting import get_cost_value, QECGatesCost from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join @@ -36,7 +37,7 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod): continue x_montgomery = dtype.uint_to_montgomery(x, mod) res = blq.call_classically(x=x_montgomery) - print(x, x_montgomery) + assert res == cblq.call_classically(x=x_montgomery) assert len(res) == 2 assert res[0] == dtype.montgomery_inverse(x_montgomery, mod) @@ -85,11 +86,11 @@ def test_kaliski_symbolic_cost(): # construction this is just $n-1$ (BitwiseNot -> Add(p+1)). # - The cost of an iteration in Litinski $13n$ since they ignore constants. # Our construction is exactly the same but we also count the constants - # which amout to $3$. for a total cost of $13n + 3$. + # which amout to $3$. for a total cost of $13n + 4$. # For example the cost of ModDbl is 2n+1. In their figure 8, they report # it as just $2n$. ModDbl gets executed within the 2n loop so its contribution # to the overal cost should be 4n^2 + 2n instead of just 4n^2. - assert total_toff == 26 * n**2 + 7 * n - 1 + assert total_toff == 26 * n**2 + 9 * n - 1 def test_kaliskimodinverse_example(bloq_autotester): @@ -99,3 +100,78 @@ def test_kaliskimodinverse_example(bloq_autotester): @pytest.mark.notebook def test_notebook(): qlt_testing.execute_notebook('mod_division') + + +def test_kaliski_iteration_decomposition(): + mod = 7 + bitsize = 5 + b = mod_division._KaliskiIteration(bitsize, mod) + cb = b.decompose_bloq() + for x in range(mod): + u = mod + v = x + r = 0 + s = 1 + f = 1 + + for _ in range(2 * bitsize): + inputs = {'u': u, 'v': v, 'r': r, 's': s, 'm': 0, 'f': f, 'is_terminal': 0} + res = b.call_classically(**inputs) + assert res == cb.call_classically(**inputs), f'{inputs=}' + u, v, r, s, _, f, _ = res # type: ignore + + qlt_testing.assert_valid_bloq_decomposition(b) + qlt_testing.assert_equivalent_bloq_counts(b, generalizer=(ignore_alloc_free, ignore_split_join)) + + +def test_kaliski_steps(): + bitsize = 5 + mod = 7 + steps = [ + mod_division._KaliskiIterationStep1(bitsize), + mod_division._KaliskiIterationStep2(bitsize), + mod_division._KaliskiIterationStep3(bitsize), + mod_division._KaliskiIterationStep4(bitsize), + mod_division._KaliskiIterationStep5(bitsize), + mod_division._KaliskiIterationStep6(bitsize, mod), + ] + csteps = [b.decompose_bloq() for b in steps] + + # check decomposition is valid. + for step in steps: + qlt_testing.assert_valid_bloq_decomposition(step) + qlt_testing.assert_equivalent_bloq_counts( + step, generalizer=(ignore_alloc_free, ignore_split_join) + ) + + # check that for all inputs all 2n iteration work when excuted directly on the 6 steps + # and their decompositions. + for x in range(mod): + u, v, r, s, f = mod, x, 0, 1, 1 + + for _ in range(2 * bitsize): + a = b = m = is_terminal = 0 + + res = steps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal) + assert res == csteps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal) + v, m, f, is_terminal = res # type: ignore + + res = steps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f) + assert res == csteps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f) + u, v, b, a, m, f = res # type: ignore + + res = steps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f) + assert res == csteps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f) + u, v, b, a, m, f = res # type: ignore + + res = steps[3].call_classically(u=u, v=v, r=r, s=s, a=a) + assert res == csteps[3].call_classically(u=u, v=v, r=r, s=s, a=a) + u, v, r, s, a = res # type: ignore + + res = steps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f) + assert res == csteps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f) + u, v, r, s, b, f = res # type: ignore + + res = steps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f) + assert res == csteps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f) + u, v, r, s, b, a, m, f = res # type: ignore