Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fpapa250 committed Oct 31, 2024
1 parent 4092a5b commit 4daddba
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 48 deletions.
24 changes: 12 additions & 12 deletions qualtran/bloqs/factoring/ecc/ec_add_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class ECWindowAddR(Bloq):
Args:
n: The bitsize of the two registers storing the elliptic curve point
R: The elliptic curve point to add (NOT in montgomery form).
ec_window_size: The number of bits in the ECAdd window.
add_window_size: The number of bits in the ECAdd window.
mul_window_size: The number of bits in the modular multiplication window.
Registers:
Expand All @@ -146,40 +146,40 @@ class ECWindowAddR(Bloq):

n: int
R: ECPoint
ec_window_size: int
add_window_size: int
mul_window_size: int = 1

@cached_property
def signature(self) -> 'Signature':
return Signature(
[
Register('ctrl', QBit(), shape=(self.ec_window_size,)),
Register('ctrl', QBit(), shape=(self.add_window_size,)),
Register('x', QUInt(self.n)),
Register('y', QUInt(self.n)),
]
)

@cached_property
def qrom(self) -> QROAMClean:
if is_symbolic(self.n) or is_symbolic(self.ec_window_size):
if is_symbolic(self.n) or is_symbolic(self.add_window_size):
log_block_sizes = None
if is_symbolic(self.n) and not is_symbolic(self.ec_window_size):
if is_symbolic(self.n) and not is_symbolic(self.add_window_size):
# We assume that bitsize is much larger than window_size
log_block_sizes = (0,)
return QROAMClean(
[
Shaped((2**self.ec_window_size,)),
Shaped((2**self.ec_window_size,)),
Shaped((2**self.ec_window_size,)),
Shaped((2**self.add_window_size,)),
Shaped((2**self.add_window_size,)),
Shaped((2**self.add_window_size,)),
],
selection_bitsizes=(self.ec_window_size,),
selection_bitsizes=(self.add_window_size,),
target_bitsizes=(self.n, self.n, self.n),
log_block_sizes=log_block_sizes,
)

cR = self.R
data_a, data_b, data_lam = [0], [0], [0]
for _ in range(1, 2**self.ec_window_size):
for _ in range(1, 2**self.add_window_size):
data_a.append(QMontgomeryUInt(self.n).uint_to_montgomery(int(cR.x), int(self.R.mod)))
data_b.append(QMontgomeryUInt(self.n).uint_to_montgomery(int(cR.y), int(self.R.mod)))
lam_num = (3 * cR.x**2 + cR.curve_a) % cR.mod
Expand All @@ -193,7 +193,7 @@ def qrom(self) -> QROAMClean:

return QROAMClean(
[data_a, data_b, data_lam],
selection_bitsizes=(self.ec_window_size,),
selection_bitsizes=(self.add_window_size,),
target_bitsizes=(self.n, self.n, self.n),
)

Expand Down Expand Up @@ -275,7 +275,7 @@ def wire_symbol(
def _ec_window_add_r_small() -> ECWindowAddR:
n = 16
P = ECPoint(2, 2, mod=7, curve_a=3)
ec_window_add_r_small = ECWindowAddR(n=n, R=P, ec_window_size=4)
ec_window_add_r_small = ECWindowAddR(n=n, R=P, add_window_size=4)
return ec_window_add_r_small


Expand Down
26 changes: 23 additions & 3 deletions qualtran/bloqs/factoring/ecc/ec_add_r_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,43 @@ def test_ec_add_r(bloq_autotester, bloq):
def test_ec_window_add_r_bloq_counts(n, window_size, a, b):
p = 17
R = ECPoint(a, b, mod=p)
bloq = ECWindowAddR(n=n, R=R, ec_window_size=window_size)
bloq = ECWindowAddR(n=n, R=R, add_window_size=window_size)
qlt_testing.assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join])


@pytest.mark.parametrize(
['n', 'm'], [(n, m) for n in range(7, 8) for m in range(1, n + 1) if n % m == 0]
)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize('x,y', [(15, 13), (5, 8)])
@pytest.mark.parametrize('ctrl', [0, 1, 5])
def test_ec_window_add_r_classical(n, m, ctrl, x, y, a, b):
p = 17
R = ECPoint(a, b, mod=p)
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)
ret2 = bloq.decompose_bloq().call_classically(ctrl=ctrl, x=x, y=y)
for i, ret1_i in enumerate(ret1):
np.testing.assert_array_equal(ret1_i, ret2[i])


@pytest.mark.slow
@pytest.mark.parametrize(
['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0]
)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize('x,y', [(15, 13), (5, 8)])
@pytest.mark.parametrize('ctrl', [0, 1, 5, 8])
def test_ec_window_add_r_classical(n, m, ctrl, x, y, a, b):
def test_ec_window_add_r_classical_slow(n, m, ctrl, x, y, a, b):
p = 17
R = ECPoint(a, b, mod=p)
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
bloq = ECWindowAddR(n=n, R=R, ec_window_size=m, mul_window_size=m)
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)
ret2 = bloq.decompose_bloq().call_classically(ctrl=ctrl, x=x, y=y)
for i, ret1_i in enumerate(ret1):
Expand Down
28 changes: 18 additions & 10 deletions qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,20 @@ class ECPhaseEstimateR(Bloq):
This is used as a subroutine in `FindECCPrivateKey`. First, we phase-estimate the
addition of the base point $P$, then of the public key $Q$.
When the ellptic curve point addition window size is 1 we use the ECAddR bloq which has it's
own bespoke circuit; when it is greater than 1 we use the windowed circuit which uses
pre-computed classical additions loaded into the circuit.
Args:
n: The bitsize of the elliptic curve points' x and y registers.
point: The elliptic curve point to phase estimate against.
ec_window_size: The number of bits in the ECAdd window.
add_window_size: The number of bits in the ECAdd window.
mul_window_size: The number of bits in the modular multiplication window.
"""

n: int
point: ECPoint
ec_window_size: int = 1
add_window_size: int = 1
mul_window_size: int = 1

@cached_property
Expand All @@ -65,28 +69,32 @@ def signature(self) -> 'Signature':

@property
def ec_add(self) -> Union[functools.partial[ECAddR], functools.partial[ECWindowAddR]]:
if self.ec_window_size == 1:
if self.add_window_size == 1:
return functools.partial(ECAddR, n=self.n)
return functools.partial(
ECWindowAddR,
n=self.n,
ec_window_size=self.ec_window_size,
add_window_size=self.add_window_size,
mul_window_size=self.mul_window_size,
)

@property
def num_windows(self) -> int:
return self.n // self.add_window_size

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if isinstance(self.n, sympy.Expr):
raise DecomposeTypeError("Cannot decompose symbolic `n`.")
ctrl = [bb.add(PlusState()) for _ in range(self.n)]

if self.ec_window_size == 1:
if self.add_window_size == 1:
for i in range(self.n):
ctrl[i], x, y = bb.add(self.ec_add(R=2**i * self.point), ctrl=ctrl[i], x=x, y=y)
else:
ctrls = np.split(np.array(ctrl), self.n // self.ec_window_size)
for i in range(self.n // self.ec_window_size):
ctrls = np.split(np.array(ctrl), self.num_windows)
for i in range(self.num_windows):
ctrls[i], x, y = bb.add(
self.ec_add(R=2 ** (self.ec_window_size * i) * self.point),
self.ec_add(R=2 ** (self.add_window_size * i) * self.point),
ctrl=ctrls[i],
x=x,
y=y,
Expand All @@ -97,15 +105,15 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[
return {'x': x, 'y': y}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {self.ec_add(R=self.point): self.n // self.ec_window_size, MeasureQFT(n=self.n): 1}
return {self.ec_add(R=self.point): self.num_windows, MeasureQFT(n=self.n): 1}

def __str__(self) -> str:
return f'PE${self.point}$'


@bloq_example
def _ec_pe() -> ECPhaseEstimateR:
n, p = sympy.symbols('n p ')
n, p = sympy.symbols('n p')
Rx, Ry = sympy.symbols('R_x R_y')
ec_pe = ECPhaseEstimateR(n=n, point=ECPoint(Rx, Ry, mod=p))
return ec_pe
Expand Down
4 changes: 3 additions & 1 deletion qualtran/bloqs/factoring/ecc/ec_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sympy
from attrs import frozen

from qualtran.symbolics import is_symbolic, SymbolicInt
Expand Down Expand Up @@ -50,7 +51,8 @@ def __add__(self, other):
raise ValueError('Use consistent mod and curve')

if is_symbolic(self.x, self.y, other.x, other.y, self.mod, self.curve_a):
return self
x, y, p = sympy.symbols('x y p')
return ECPoint(x=x, y=y, mod=p)
if self == -other:
return ECPoint.inf(mod=self.mod, curve_a=self.curve_a)
if self == ECPoint.inf(mod=self.mod, curve_a=self.curve_a):
Expand Down
28 changes: 9 additions & 19 deletions qualtran/bloqs/factoring/ecc/ecc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
" - `n`: The bitsize of the elliptic curve points' x and y registers.\n",
" - `base_point`: The base point $P$ with unknown order $r$ such that $P = [r] P$.\n",
" - `public_key`: The public key $Q$ such that $Q = [k] P$ for private key $k$.\n",
" - `ec_window_size`: The number of bits in the ECAdd window.\n",
" - `add_window_size`: The number of bits in the ECAdd window.\n",
" - `mul_window_size`: The number of bits in the modular multiplication window. \n",
"\n",
"#### References\n",
Expand Down Expand Up @@ -217,10 +217,14 @@
"This is used as a subroutine in `FindECCPrivateKey`. First, we phase-estimate the\n",
"addition of the base point $P$, then of the public key $Q$.\n",
"\n",
"When the ellptic curve point addition window size is 1 we use the ECAddR bloq which has it's\n",
"own bespoke circuit; when it is greater than 1 we use the windowed circuit which uses\n",
"pre-computed classical additions loaded into the circuit.\n",
"\n",
"#### Parameters\n",
" - `n`: The bitsize of the elliptic curve points' x and y registers.\n",
" - `point`: The elliptic curve point to phase estimate against.\n",
" - `ec_window_size`: The number of bits in the ECAdd window.\n",
" - `add_window_size`: The number of bits in the ECAdd window.\n",
" - `mul_window_size`: The number of bits in the modular multiplication window.\n"
]
},
Expand Down Expand Up @@ -255,7 +259,7 @@
},
"outputs": [],
"source": [
"n, p = sympy.symbols('n p ')\n",
"n, p = sympy.symbols('n p')\n",
"Rx, Ry = sympy.symbols('R_x R_y')\n",
"ec_pe = ECPhaseEstimateR(n=n, point=ECPoint(Rx, Ry, mod=p))"
]
Expand Down Expand Up @@ -482,7 +486,7 @@
"#### Parameters\n",
" - `n`: The bitsize of the two registers storing the elliptic curve point\n",
" - `R`: The elliptic curve point to add (NOT in montgomery form).\n",
" - `ec_window_size`: The number of bits in the ECAdd window.\n",
" - `add_window_size`: The number of bits in the ECAdd window.\n",
" - `mul_window_size`: The number of bits in the modular multiplication window. \n",
"\n",
"#### Registers\n",
Expand Down Expand Up @@ -516,20 +520,6 @@
"### Example Instances"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "e1a3a397",
"metadata": {
"cq.autogen": "ECWindowAddR.ec_window_add"
},
"outputs": [],
"source": [
"n, p, w = sympy.symbols('n p w')\n",
"Rx, Ry = sympy.symbols('Rx Ry')\n",
"ec_window_add_r = ECWindowAddR(n=n, ec_window_size=w, R=ECPoint(Rx, Ry, mod=p))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -541,7 +531,7 @@
"source": [
"n = 16\n",
"P = ECPoint(2, 2, mod=7, curve_a=3)\n",
"ec_window_add_r_small = ECWindowAddR(n=n, R=P, ec_window_size=4)"
"ec_window_add_r_small = ECWindowAddR(n=n, R=P, add_window_size=4)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions qualtran/bloqs/factoring/ecc/find_ecc_private_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class FindECCPrivateKey(Bloq):
n: The bitsize of the elliptic curve points' x and y registers.
base_point: The base point $P$ with unknown order $r$ such that $P = [r] P$.
public_key: The public key $Q$ such that $Q = [k] P$ for private key $k$.
ec_window_size: The number of bits in the ECAdd window.
add_window_size: The number of bits in the ECAdd window.
mul_window_size: The number of bits in the modular multiplication window.
References:
Expand All @@ -78,7 +78,7 @@ class FindECCPrivateKey(Bloq):
n: int
base_point: ECPoint
public_key: ECPoint
ec_window_size: int = 1
add_window_size: int = 1
mul_window_size: int = 1

@cached_property
Expand All @@ -102,7 +102,7 @@ def ec_pe_r(self) -> functools.partial[ECPhaseEstimateR]:
return functools.partial(
ECPhaseEstimateR,
n=self.n,
ec_window_size=self.ec_window_size,
add_window_size=self.add_window_size,
mul_window_size=self.mul_window_size,
)

Expand Down

0 comments on commit 4daddba

Please sign in to comment.