Skip to content

Commit

Permalink
Fix bug in KaliskiStep3 and add tests for all steps (#1496)
Browse files Browse the repository at this point in the history
* Fix bug in KaliskiStep3 and add tests for all steps

* cost
  • Loading branch information
NoureldinYosri authored Nov 13, 2024
1 parent 07c98b6 commit e10560c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 27 deletions.
19 changes: 8 additions & 11 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,20 +1462,17 @@ def on_classical_vals(
c: Optional['ClassicalValT'] = None,
target: Optional['ClassicalValT'] = None,
) -> Dict[str, 'ClassicalValT']:
if self._op_symbol in ('>', '<='):
c_val = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
else:
c_val = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
if self.uncompute:
assert c == add_ints(
int(a),
int(b),
num_bits=int(self.dtype.bitsize),
is_signed=isinstance(self.dtype, QInt),
)
assert c == c_val
assert target == self._classical_comparison(a, b)
return {'a': a, 'b': b}
if self._op_symbol in ('>', '<='):
c = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
else:
c = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
return {'a': a, 'b': b, 'c': c, 'target': int(self._classical_comparison(a, b))}
assert c is None
assert target is None
return {'a': a, 'b': b, 'c': c_val, 'target': int(self._classical_comparison(a, b))}

def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']:
if self._op_symbol in ('>', '<='):
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/factoring/ecc/ec_add_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def test_ec_add_symbolic_cost():
# toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n
# 3-controlled toffolis in step 2. The expression is written with rationals because sympy
# comparison fails with floats.
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(391, 2) * n - 31
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(407, 2) * n - 31


def test_ec_add(bloq_autotester):
Expand Down
23 changes: 11 additions & 12 deletions qualtran/bloqs/mod_arithmetic/mod_division.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ def signature(self) -> 'Signature':
def on_classical_vals(
self, v: int, m: int, f: int, is_terminal: int
) -> Dict[str, 'ClassicalValT']:
print('here')
assert False
m ^= f & (v == 0)
assert is_terminal == 0
is_terminal ^= m
Expand Down Expand Up @@ -101,10 +99,10 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if is_symbolic(self.bitsize):
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize)
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize + 1)
else:
cvs = [0] * int(self.bitsize)
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 2}
cvs = [0] * int(self.bitsize) + [1]
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 3}


@frozen
Expand Down Expand Up @@ -197,25 +195,25 @@ def on_classical_vals(
def build_composite_bloq(
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet
) -> Dict[str, 'SoquetT']:
u, v, junk, greater_than = bb.add(
u, v, junk_c, greater_than = bb.add(
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)), a=u, b=v
)

(greater_than, f, b), junk, ctrl = bb.add(
(greater_than, f, b), junk_m, ctrl = bb.add(
MultiAnd(cvs=(1, 1, 0)), ctrl=(greater_than, f, b)
)

ctrl, a = bb.add(CNOT(), ctrl=ctrl, target=a)
ctrl, m = bb.add(CNOT(), ctrl=ctrl, target=m)

greater_than, f, b = bb.add(
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk, target=ctrl
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk_m, target=ctrl
)
u, v = bb.add(
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)).adjoint(),
a=u,
b=v,
c=junk,
c=junk_c,
target=greater_than,
)
return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f}
Expand Down Expand Up @@ -391,7 +389,7 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
CNOT(): 4,
CNOT(): 3,
XGate(): 2,
ModDbl(QMontgomeryUInt(self.bitsize), self.mod): 1,
CSwapApprox(self.bitsize): 2,
Expand Down Expand Up @@ -475,7 +473,7 @@ def on_classical_vals(
of `f` and `m`.
"""
assert m == 0
is_terminal = f == 1 and v == 0
is_terminal = int(f == 1 and v == 0)
if f == 0:
# When `f = 0` this means that the algorithm is nearly over and that we just need to
# double the value of `r`.
Expand All @@ -489,7 +487,8 @@ def on_classical_vals(
f = 0
r = (r << 1) % self.mod
else:
m = (u % 2 == 1) & (v % 2 == 0)
m = ((u % 2 == 1) & (v % 2 == 0)) or (u % 2 == 1 and v % 2 == 1 and u > v)
m = int(m)
# Kaliski iteration as described in Fig7 of https://arxiv.org/pdf/2001.09580.
swap = (u % 2 == 0 and v % 2 == 1) or (u % 2 == 1 and v % 2 == 1 and u > v)
if swap:
Expand Down
82 changes: 79 additions & 3 deletions qualtran/bloqs/mod_arithmetic/mod_division_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import qualtran.testing as qlt_testing
from qualtran import QMontgomeryUInt
from qualtran.bloqs.mod_arithmetic import mod_division
from qualtran.bloqs.mod_arithmetic.mod_division import _kaliskimodinverse_example, KaliskiModInverse
from qualtran.resource_counting import get_cost_value, QECGatesCost
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join
Expand All @@ -36,7 +37,7 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod):
continue
x_montgomery = dtype.uint_to_montgomery(x, mod)
res = blq.call_classically(x=x_montgomery)
print(x, x_montgomery)

assert res == cblq.call_classically(x=x_montgomery)
assert len(res) == 2
assert res[0] == dtype.montgomery_inverse(x_montgomery, mod)
Expand Down Expand Up @@ -85,11 +86,11 @@ def test_kaliski_symbolic_cost():
# construction this is just $n-1$ (BitwiseNot -> Add(p+1)).
# - The cost of an iteration in Litinski $13n$ since they ignore constants.
# Our construction is exactly the same but we also count the constants
# which amout to $3$. for a total cost of $13n + 3$.
# which amout to $3$. for a total cost of $13n + 4$.
# For example the cost of ModDbl is 2n+1. In their figure 8, they report
# it as just $2n$. ModDbl gets executed within the 2n loop so its contribution
# to the overal cost should be 4n^2 + 2n instead of just 4n^2.
assert total_toff == 26 * n**2 + 7 * n - 1
assert total_toff == 26 * n**2 + 9 * n - 1


def test_kaliskimodinverse_example(bloq_autotester):
Expand All @@ -99,3 +100,78 @@ def test_kaliskimodinverse_example(bloq_autotester):
@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('mod_division')


def test_kaliski_iteration_decomposition():
mod = 7
bitsize = 5
b = mod_division._KaliskiIteration(bitsize, mod)
cb = b.decompose_bloq()
for x in range(mod):
u = mod
v = x
r = 0
s = 1
f = 1

for _ in range(2 * bitsize):
inputs = {'u': u, 'v': v, 'r': r, 's': s, 'm': 0, 'f': f, 'is_terminal': 0}
res = b.call_classically(**inputs)
assert res == cb.call_classically(**inputs), f'{inputs=}'
u, v, r, s, _, f, _ = res # type: ignore

qlt_testing.assert_valid_bloq_decomposition(b)
qlt_testing.assert_equivalent_bloq_counts(b, generalizer=(ignore_alloc_free, ignore_split_join))


def test_kaliski_steps():
bitsize = 5
mod = 7
steps = [
mod_division._KaliskiIterationStep1(bitsize),
mod_division._KaliskiIterationStep2(bitsize),
mod_division._KaliskiIterationStep3(bitsize),
mod_division._KaliskiIterationStep4(bitsize),
mod_division._KaliskiIterationStep5(bitsize),
mod_division._KaliskiIterationStep6(bitsize, mod),
]
csteps = [b.decompose_bloq() for b in steps]

# check decomposition is valid.
for step in steps:
qlt_testing.assert_valid_bloq_decomposition(step)
qlt_testing.assert_equivalent_bloq_counts(
step, generalizer=(ignore_alloc_free, ignore_split_join)
)

# check that for all inputs all 2n iteration work when excuted directly on the 6 steps
# and their decompositions.
for x in range(mod):
u, v, r, s, f = mod, x, 0, 1, 1

for _ in range(2 * bitsize):
a = b = m = is_terminal = 0

res = steps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
assert res == csteps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
v, m, f, is_terminal = res # type: ignore

res = steps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
assert res == csteps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
u, v, b, a, m, f = res # type: ignore

res = steps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
assert res == csteps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
u, v, b, a, m, f = res # type: ignore

res = steps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
assert res == csteps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
u, v, r, s, a = res # type: ignore

res = steps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
assert res == csteps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
u, v, r, s, b, f = res # type: ignore

res = steps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
assert res == csteps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
u, v, r, s, b, a, m, f = res # type: ignore

0 comments on commit e10560c

Please sign in to comment.