Skip to content

Commit

Permalink
Refactor MSM instances in the IPA (microsoft#290)
Browse files Browse the repository at this point in the history
* refactor: remove unneeded MSM from IPA

- Added new trait usage and extended `AffineExt` with more bounds in `src/provider/traits.rs` to enable greater functionality.
- new bounds on AffineExt are trivially satisfied and parrot existing ones, they are repeated to work around the lack of associated_type_bounds
- Optimized vector processing in the `fold` method of `src/provider/pedersen.rs`. Element-wise vector addition replaced the previous multiscalar multiplication method.

* chore: remove unneeded trait bounds

* chore: some simplifications

* refactor: Optimize commitment key handling and IPA

- Transitioned variable ck and ck_c to mutable to facilitate in-place operations.
- Optimized the `split_at` function in `Pedersen.rs` using the `split_off` method.
- Reconfigured the `combine` function to clone and chain iterables for expedited vector creation.
- Reassessed the `scale` function to in-place operations, eliminating the need for new struct instances.
- Streamlined code in `scale` and `fold` functions by removing superfluous variable assignments.

* refactor: test going back to MSM in fold
  • Loading branch information
huitseeker authored Feb 2, 2024
1 parent b168f27 commit 6fde742
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 90 deletions.
5 changes: 4 additions & 1 deletion src/provider/hyperkzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ where
let B = kzg_compute_batch_polynomial(f, q);

// Now open B at u0, ..., u_{t-1}
let w = u.par_iter().map(|ui| kzg_open(&B, *ui)).collect::<Vec<_>>();
let w = u
.into_par_iter()
.map(|ui| kzg_open(&B, *ui))
.collect::<Vec<_>>();

// The prover computes the challenge to keep the transcript in the same
// state as that of the verifier
Expand Down
40 changes: 15 additions & 25 deletions src/provider/ipa_pc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
evaluation::EvaluationEngineTrait,
Engine, TranscriptEngineTrait, TranscriptReprTrait,
},
Commitment, CommitmentKey, CompressedCommitment, CE,
zip_with, Commitment, CommitmentKey, CompressedCommitment, CE,
};
use core::iter;
use ff::Field;
Expand All @@ -19,13 +19,13 @@ use std::marker::PhantomData;
use std::sync::Arc;

/// Provides an implementation of the prover key
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct ProverKey<E: Engine> {
ck_s: CommitmentKey<E>,
}

/// Provides an implementation of the verifier key
#[derive(Clone, Debug, Serialize)]
#[derive(Debug, Serialize)]
#[serde(bound = "")]
pub struct VerifierKey<E: Engine> {
ck_v: Arc<CommitmentKey<E>>,
Expand Down Expand Up @@ -76,7 +76,7 @@ where
let u = InnerProductInstance::new(comm, &EqPolynomial::evals_from_points(point), eval);
let w = InnerProductWitness::new(poly);

InnerProductArgument::prove(ck, &pk.ck_s, &u, &w, transcript)
InnerProductArgument::prove(ck.clone(), pk.ck_s.clone(), &u, &w, transcript)
}

/// A method to verify purported evaluations of a batch of polynomials
Expand All @@ -90,24 +90,14 @@ where
) -> Result<(), NovaError> {
let u = InnerProductInstance::new(comm, &EqPolynomial::evals_from_points(point), eval);

arg.verify(
&vk.ck_v,
&vk.ck_s,
(2_usize).pow(point.len() as u32),
&u,
transcript,
)?;
arg.verify(&vk.ck_v, vk.ck_s.clone(), 1 << point.len(), &u, transcript)?;

Ok(())
}
}

fn inner_product<T: Field + Send + Sync>(a: &[T], b: &[T]) -> T {
assert_eq!(a.len(), b.len());
(0..a.len())
.into_par_iter()
.map(|i| a[i] * b[i])
.reduce(|| T::ZERO, |x, y| x + y)
zip_with!(par_iter, (a, b), |x, y| *x * y).sum()
}

/// An inner product instance consists of a commitment to a vector `a` and another vector `b`
Expand Down Expand Up @@ -175,8 +165,8 @@ where
}

fn prove(
ck: &CommitmentKey<E>,
ck_c: &CommitmentKey<E>,
ck: CommitmentKey<E>,
mut ck_c: CommitmentKey<E>,
U: &InnerProductInstance<E>,
W: &InnerProductWitness<E>,
transcript: &mut E::TE,
Expand All @@ -194,12 +184,12 @@ where

// sample a random base for commiting to the inner product
let r = transcript.squeeze(b"r")?;
let ck_c = ck_c.scale(&r);
ck_c.scale(&r);

// a closure that executes a step of the recursive inner product argument
let prove_inner = |a_vec: &[E::Scalar],
b_vec: &[E::Scalar],
ck: &CommitmentKey<E>,
ck: CommitmentKey<E>,
transcript: &mut E::TE|
-> Result<
(
Expand Down Expand Up @@ -255,7 +245,7 @@ where
.map(|(b_L, b_R)| *b_L * r_inverse + r * *b_R)
.collect::<Vec<E::Scalar>>();

let ck_folded = ck.fold(&r_inverse, &r);
let ck_folded = CommitmentKeyExtTrait::fold(&ck_L, &ck_R, &r_inverse, &r);

Ok((L, R, a_vec_folded, b_vec_folded, ck_folded))
};
Expand All @@ -270,7 +260,7 @@ where
let mut ck = ck;
for _i in 0..usize::try_from(U.b_vec.len().ilog2()).unwrap() {
let (L, R, a_vec_folded, b_vec_folded, ck_folded) =
prove_inner(&a_vec, &b_vec, &ck, transcript)?;
prove_inner(&a_vec, &b_vec, ck, transcript)?;
L_vec.push(L);
R_vec.push(R);

Expand All @@ -289,12 +279,12 @@ where
fn verify(
&self,
ck: &CommitmentKey<E>,
ck_c: &CommitmentKey<E>,
mut ck_c: CommitmentKey<E>,
n: usize,
U: &InnerProductInstance<E>,
transcript: &mut E::TE,
) -> Result<(), NovaError> {
let (ck, _) = ck.split_at(U.b_vec.len());
let (ck, _) = ck.clone().split_at(U.b_vec.len());

transcript.dom_sep(Self::protocol_name());
if U.b_vec.len() != n
Expand All @@ -310,7 +300,7 @@ where

// sample a random base for commiting to the inner product
let r = transcript.squeeze(b"r")?;
let ck_c = ck_c.scale(&r);
ck_c.scale(&r);

let P = U.comm_a_vec + CE::<E>::commit(&ck_c, &[U.c]);

Expand Down
97 changes: 37 additions & 60 deletions src/provider/pedersen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
commitment::{CommitmentEngineTrait, CommitmentTrait, Len},
AbsorbInROTrait, Engine, ROTrait, TranscriptReprTrait,
},
zip_with,
};
use abomonation_derive::Abomonation;
use core::{
Expand All @@ -14,12 +15,15 @@ use core::{
ops::{Add, Mul, MulAssign},
};
use ff::Field;
use group::{prime::PrimeCurve, Curve, Group, GroupEncoding};
use group::{
prime::{PrimeCurve, PrimeCurveAffine},
Curve, Group, GroupEncoding,
};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

/// A type that holds commitment generators
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Abomonation)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Abomonation)]
#[abomonation_omit_bounds]
pub struct CommitmentKey<E>
where
Expand All @@ -30,19 +34,6 @@ where
ck: Vec<<E::GE as PrimeCurve>::Affine>,
}

/// [`CommitmentKey`]s are often large, and this helps with cloning bottlenecks
impl<E> Clone for CommitmentKey<E>
where
E: Engine,
E::GE: DlogGroup<ScalarExt = E::Scalar>,
{
fn clone(&self) -> Self {
Self {
ck: self.ck[..].par_iter().cloned().collect(),
}
}
}

impl<E> Len for CommitmentKey<E>
where
E: Engine,
Expand Down Expand Up @@ -243,18 +234,18 @@ where
E::GE: DlogGroup,
{
/// Splits the commitment key into two pieces at a specified point
fn split_at(&self, n: usize) -> (Self, Self)
fn split_at(self, n: usize) -> (Self, Self)
where
Self: Sized;

/// Combines two commitment keys into one
fn combine(&self, other: &Self) -> Self;

/// Folds the two commitment keys into one using the provided weights
fn fold(&self, w1: &E::Scalar, w2: &E::Scalar) -> Self;
fn fold(L: &Self, R: &Self, w1: &E::Scalar, w2: &E::Scalar) -> Self;

/// Scales the commitment key using the provided scalar
fn scale(&self, r: &E::Scalar) -> Self;
fn scale(&mut self, r: &E::Scalar);

/// Reinterprets commitments as commitment keys
fn reinterpret_commitments_as_ck(
Expand All @@ -269,64 +260,50 @@ where
E: Engine<CE = CommitmentEngine<E>>,
E::GE: DlogGroup<ScalarExt = E::Scalar>,
{
fn split_at(&self, n: usize) -> (Self, Self) {
(
Self {
ck: self.ck[0..n].to_vec(),
},
Self {
ck: self.ck[n..].to_vec(),
},
)
fn split_at(mut self, n: usize) -> (Self, Self) {
let right = self.ck.split_off(n);
(self, Self { ck: right })
}

fn combine(&self, other: &Self) -> Self {
let ck = {
let mut c = self.ck.clone();
c.extend(other.ck.clone());
c
self
.ck
.iter()
.cloned()
.chain(other.ck.iter().cloned())
.collect::<Vec<_>>()
};
Self { ck }
}

// combines the left and right halves of `self` using `w1` and `w2` as the weights
fn fold(&self, w1: &E::Scalar, w2: &E::Scalar) -> Self {
let w = vec![*w1, *w2];
let (L, R) = self.split_at(self.ck.len() / 2);

let ck = (0..self.ck.len() / 2)
.into_par_iter()
.map(|i| {
let bases = [L.ck[i].clone(), R.ck[i].clone()].to_vec();
E::GE::vartime_multiscalar_mul(&w, &bases).to_affine()
})
.collect();

Self { ck }
fn fold(L: &Self, R: &Self, w1: &E::Scalar, w2: &E::Scalar) -> Self {
debug_assert!(L.ck.len() == R.ck.len());
let ck_curve: Vec<E::GE> = zip_with!(par_iter, (L.ck, R.ck), |l, r| {
E::GE::vartime_multiscalar_mul(&[*w1, *w2], &[*l, *r])
})
.collect();
let mut ck_affine = vec![<E::GE as PrimeCurve>::Affine::identity(); L.ck.len()];
E::GE::batch_normalize(&ck_curve, &mut ck_affine);

Self { ck: ck_affine }
}

/// Scales each element in `self` by `r`
fn scale(&self, r: &E::Scalar) -> Self {
let ck_scaled = self
.ck
.clone()
.into_par_iter()
.map(|g| E::GE::vartime_multiscalar_mul(&[*r], &[g]).to_affine())
.collect();

Self { ck: ck_scaled }
fn scale(&mut self, r: &E::Scalar) {
let ck_scaled: Vec<E::GE> = self.ck.par_iter().map(|g| *g * r).collect();
E::GE::batch_normalize(&ck_scaled, &mut self.ck);
}

/// reinterprets a vector of commitments as a set of generators
fn reinterpret_commitments_as_ck(c: &[CompressedCommitment<E>]) -> Result<Self, NovaError> {
let d = (0..c.len())
.into_par_iter()
.map(|i| Commitment::<E>::decompress(&c[i]))
.collect::<Result<Vec<Commitment<E>>, NovaError>>()?;
let ck = (0..d.len())
.into_par_iter()
.map(|i| d[i].comm.to_affine())
.collect();
let d = c
.par_iter()
.map(|c| Commitment::<E>::decompress(c).map(|c| c.comm))
.collect::<Result<Vec<E::GE>, NovaError>>()?;
let mut ck = vec![<E::GE as PrimeCurve>::Affine::identity(); d.len()];
E::GE::batch_normalize(&d, &mut ck);
Ok(Self { ck })
}
}
13 changes: 12 additions & 1 deletion src/provider/traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::traits::{Group, TranscriptReprTrait};
use group::prime::PrimeCurveAffine;
use group::{prime::PrimeCurve, GroupEncoding};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::ops::Mul;

/// A trait that defines extensions to the Group trait
pub trait DlogGroup:
Expand All @@ -11,7 +13,16 @@ pub trait DlogGroup:
+ PrimeCurve<Scalar = <Self as DlogGroup>::ScalarExt, Affine = <Self as DlogGroup>::AffineExt>
{
type ScalarExt;
type AffineExt: Clone + Debug + Eq + Serialize + for<'de> Deserialize<'de> + Sync + Send;
type AffineExt: Clone
+ Debug
+ Eq
+ Serialize
+ for<'de> Deserialize<'de>
+ Sync
+ Send
// technical bounds, should disappear when associated_type_bounds stabilizes
+ Mul<Self::ScalarExt, Output = Self>
+ PrimeCurveAffine<Curve = Self, Scalar = Self::ScalarExt>;
type Compressed: Clone
+ Debug
+ Eq
Expand Down
5 changes: 2 additions & 3 deletions src/traits/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ use serde::{Deserialize, Serialize};
/// A trait that ties different pieces of the commitment evaluation together
pub trait EvaluationEngineTrait<E: Engine>: Clone + Send + Sync {
/// A type that holds the prover key
type ProverKey: Clone + Send + Sync;
type ProverKey: Send + Sync;

/// A type that holds the verifier key
type VerifierKey: Clone
+ Send
type VerifierKey: Send
+ Sync
// required for easy Digest computation purposes, could be relaxed to
// [`crate::digest::Digestible`]
Expand Down

0 comments on commit 6fde742

Please sign in to comment.