Skip to content

Commit

Permalink
Add efficient linear combination for Montgomery forms (#666)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <[email protected]>
  • Loading branch information
andrewwhitehead authored Sep 19, 2024
1 parent 20059a7 commit 21cba95
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 2 deletions.
21 changes: 21 additions & 0 deletions benches/boxed_monty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,27 @@ fn bench_montgomery_ops<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
BatchSize::SmallInput,
)
});

group.bench_function(
"lincomb_vartime, BoxedUint*BoxedUint+BoxedUint*BoxedUint",
|b| {
b.iter_batched(
|| {
BoxedMontyForm::new(
BoxedUint::random_mod(&mut OsRng, params.modulus().as_nz_ref()),
params.clone(),
)
},
|a| {
BoxedMontyForm::lincomb_vartime(&[
(black_box(&a), black_box(&a)),
(black_box(&a), black_box(&a)),
])
},
BatchSize::SmallInput,
)
},
);
}

fn bench_montgomery_conversion<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
Expand Down
15 changes: 14 additions & 1 deletion benches/const_monty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crypto_bigint::MultiExponentiate;
impl_modulus!(
Modulus,
U256,
"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
"7fffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
);

type ConstMontyForm = crypto_bigint::modular::ConstMontyForm<Modulus, { U256::LIMBS }>;
Expand Down Expand Up @@ -129,6 +129,19 @@ fn bench_montgomery_ops<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
)
});

group.bench_function("lincomb_vartime, U256*U256+U256*U256", |b| {
b.iter_batched(
|| ConstMontyForm::random(&mut OsRng),
|a| {
ConstMontyForm::lincomb_vartime(&[
(black_box(a), black_box(a)),
(black_box(a), black_box(a)),
])
},
BatchSize::SmallInput,
)
});

#[cfg(feature = "alloc")]
for i in [1, 2, 3, 4, 10, 100] {
group.bench_function(
Expand Down
2 changes: 1 addition & 1 deletion src/limb/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Ord for Limb {
let mut ret = Ordering::Less;
ret.conditional_assign(&Ordering::Equal, self.ct_eq(other));
ret.conditional_assign(&Ordering::Greater, self.ct_gt(other));
debug_assert_eq!(ret == Ordering::Less, self.ct_lt(other).into());
debug_assert_eq!(ret == Ordering::Less, bool::from(self.ct_lt(other)));
ret
}
}
Expand Down
1 change: 1 addition & 0 deletions src/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//! the modulus can vary at runtime.

mod const_monty_form;
mod lincomb;
mod monty_form;
mod reduction;

Expand Down
7 changes: 7 additions & 0 deletions src/modular/boxed_monty_form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

mod add;
mod inv;
mod lincomb;
mod mul;
mod neg;
mod pow;
Expand Down Expand Up @@ -34,6 +35,8 @@ pub struct BoxedMontyParams {
/// The lowest limbs of -(MODULUS^-1) mod R
/// We only need the LSB because during reduction this value is multiplied modulo 2**Limb::BITS.
mod_neg_inv: Limb,
/// Leading zeros in the modulus, used to choose optimized algorithms
mod_leading_zeros: u32,
}

impl BoxedMontyParams {
Expand Down Expand Up @@ -93,6 +96,9 @@ impl BoxedMontyParams {
debug_assert!(bool::from(modulus_is_odd));

let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod_limb.limbs[0].0));

let mod_leading_zeros = modulus.as_ref().leading_zeros().max(Word::BITS - 1);

let r3 = montgomery_reduction_boxed(&mut r2.square(), &modulus, mod_neg_inv);

Self {
Expand All @@ -101,6 +107,7 @@ impl BoxedMontyParams {
r2,
r3,
mod_neg_inv,
mod_leading_zeros,
}
}

Expand Down
75 changes: 75 additions & 0 deletions src/modular/boxed_monty_form/lincomb.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//! Linear combinations of integers in Montgomery form with a modulus set at runtime.

use super::BoxedMontyForm;
use crate::modular::lincomb::lincomb_boxed_monty_form;

impl BoxedMontyForm {
/// Calculate the sum of products of pairs `(a, b)` in `products`.
///
/// This method is variable time only with the value of the modulus.
/// For a modulus with leading zeros, this method is more efficient than a naive sum of products.
///
/// This method will panic if `products` is empty. All terms must be associated
/// with equivalent `MontyParams`.
pub fn lincomb_vartime(products: &[(&Self, &Self)]) -> Self {
assert!(!products.is_empty(), "empty products");
let params = &products[0].0.params;
Self {
montgomery_form: lincomb_boxed_monty_form(
products,
&params.modulus,
params.mod_neg_inv,
params.mod_leading_zeros,
),
params: products[0].0.params.clone(),
}
}
}

#[cfg(test)]
mod tests {

#[cfg(feature = "rand")]
#[test]
fn lincomb_expected() {
use crate::modular::{BoxedMontyForm, BoxedMontyParams};
use crate::{BoxedUint, Odd, RandomMod};
use rand_core::SeedableRng;

const SIZE: u32 = 511;

let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);
for n in 0..100 {
let modulus = Odd::<BoxedUint>::random(&mut rng, SIZE);
let params = BoxedMontyParams::new(modulus.clone());
let a = BoxedUint::random_mod(&mut rng, modulus.as_nz_ref());
let b = BoxedUint::random_mod(&mut rng, modulus.as_nz_ref());
let c = BoxedUint::random_mod(&mut rng, modulus.as_nz_ref());
let d = BoxedUint::random_mod(&mut rng, modulus.as_nz_ref());
let e = BoxedUint::random_mod(&mut rng, modulus.as_nz_ref());
let f = BoxedUint::random_mod(&mut rng, modulus.as_nz_ref());

let std = a
.mul_mod(&b, &modulus)
.add_mod(&c.mul_mod(&d, &modulus), &modulus)
.add_mod(&e.mul_mod(&f, &modulus), &modulus);

let lincomb = BoxedMontyForm::lincomb_vartime(&[
(
&BoxedMontyForm::new(a, params.clone()),
&BoxedMontyForm::new(b, params.clone()),
),
(
&BoxedMontyForm::new(c, params.clone()),
&BoxedMontyForm::new(d, params.clone()),
),
(
&BoxedMontyForm::new(e, params.clone()),
&BoxedMontyForm::new(f, params.clone()),
),
]);

assert_eq!(std, lincomb.retrieve(), "n={n}");
}
}
}
3 changes: 3 additions & 0 deletions src/modular/const_monty_form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

mod add;
pub(super) mod inv;
mod lincomb;
mod mul;
mod neg;
mod pow;
Expand Down Expand Up @@ -50,6 +51,8 @@ pub trait ConstMontyParams<const LIMBS: usize>:
/// The lowest limbs of -(MODULUS^-1) mod R
// We only need the LSB because during reduction this value is multiplied modulo 2**Limb::BITS.
const MOD_NEG_INV: Limb;
/// Leading zeros in the modulus, used to choose optimized algorithms
const MOD_LEADING_ZEROS: u32;

/// Precompute a Bernstein-Yang inverter for this modulus.
///
Expand Down
60 changes: 60 additions & 0 deletions src/modular/const_monty_form/lincomb.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//! Linear combinations of integers n Montgomery form with a constant modulus.

use core::marker::PhantomData;

use super::{ConstMontyForm, ConstMontyParams};
use crate::modular::lincomb::lincomb_const_monty_form;

impl<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> ConstMontyForm<MOD, LIMBS> {
/// Calculate the sum of products of pairs `(a, b)` in `products`.
///
/// This method is variable time only with the value of the modulus.
/// For a modulus with leading zeros, this method is more efficient than a naive sum of products.
pub const fn lincomb_vartime(products: &[(Self, Self)]) -> Self {
Self {
montgomery_form: lincomb_const_monty_form(products, &MOD::MODULUS, MOD::MOD_NEG_INV),
phantom: PhantomData,
}
}
}

#[cfg(test)]
mod tests {

#[cfg(feature = "rand")]
#[test]
fn lincomb_expected() {
use super::{ConstMontyForm, ConstMontyParams};
use crate::{impl_modulus, RandomMod, U256};
use rand_core::SeedableRng;
impl_modulus!(
P,
U256,
"7fffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
);
let modulus = P::MODULUS.as_nz_ref();

let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);
for n in 0..1000 {
let a = U256::random_mod(&mut rng, modulus);
let b = U256::random_mod(&mut rng, modulus);
let c = U256::random_mod(&mut rng, modulus);
let d = U256::random_mod(&mut rng, modulus);
let e = U256::random_mod(&mut rng, modulus);
let f = U256::random_mod(&mut rng, modulus);

assert_eq!(
a.mul_mod(&b, modulus)
.add_mod(&c.mul_mod(&d, modulus), modulus)
.add_mod(&e.mul_mod(&f, modulus), modulus),
ConstMontyForm::<P, { P::LIMBS }>::lincomb_vartime(&[
(ConstMontyForm::new(&a), ConstMontyForm::new(&b)),
(ConstMontyForm::new(&c), ConstMontyForm::new(&d)),
(ConstMontyForm::new(&e), ConstMontyForm::new(&f)),
])
.retrieve(),
"n={n}"
)
}
}
}
10 changes: 10 additions & 0 deletions src/modular/const_monty_form/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ macro_rules! impl_modulus {
),
);

// Leading zeros in the modulus, used to choose optimized algorithms.
const MOD_LEADING_ZEROS: u32 = {
let z = Self::MODULUS.as_ref().leading_zeros();
if z >= $crate::Word::BITS {
$crate::Word::BITS - 1
} else {
z
}
};

const R3: $uint_type = $crate::modular::montgomery_reduction(
&Self::R2.square_wide(),
&Self::MODULUS,
Expand Down
Loading

0 comments on commit 21cba95

Please sign in to comment.