From 48887f8bc568d199814f6033eb3ed4d158119310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Wed, 10 Jan 2024 15:45:01 -0500 Subject: [PATCH] refactor: Integrate zip_with macro (#292) - Created a new file `macros.rs` with the implementation of `zip_with` and `zip_with_for_each` macros providing syntactic sugar for zipWith patterns. - zipWith patterns implemented through the use of the zip_with! macros now resolve to the use of zip_eq, a variant of zip that panics when the iterator arguments are of different length, - the zip_eq implementation is the native rayon one for parallel iterators, and the one from itertools (see below) for the sequential ones, - Optimized and refactored functions like `batch_eval_prove` and `batch_eval_verify` in `snark.rs`, methods inside `PolyEvalWitness` and `PolyEvalInstance` in `mod.rs`, and multiple functions in `multilinear.rs` through the use of implemented macros. - Introduced the use of itertools::Itertools in various files to import the use of zip_eq on sequential iterators. - Made use of the Itertools library for refactoring and optimizing computation in `sumcheck.rs` and `eq.rs` files. This backports (among others) content from the following Arecibo PRS: - https://github.com/lurk-lab/arecibo/pull/149 - https://github.com/lurk-lab/arecibo/pull/158 - https://github.com/lurk-lab/arecibo/pull/169 Co-authored-by: porcuquine --- Cargo.toml | 1 + src/spartan/macros.rs | 103 +++++++++++++++++++++++++++++++ src/spartan/mod.rs | 39 +++++------- src/spartan/polys/eq.rs | 11 ++-- src/spartan/polys/multilinear.rs | 29 ++++----- src/spartan/ppsnark.rs | 31 +++------- src/spartan/snark.rs | 49 +++++---------- src/spartan/sumcheck.rs | 23 +++---- 8 files changed, 170 insertions(+), 116 deletions(-) create mode 100644 src/spartan/macros.rs diff --git a/Cargo.toml b/Cargo.toml index 785fc94f..e0a37576 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ byteorder = "1.4.3" thiserror = "1.0" group = "0.13.0" once_cell = "1.18.0" +itertools = "0.12.0" [target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies] pasta-msm = { version = "0.1.4" } diff --git a/src/spartan/macros.rs b/src/spartan/macros.rs new file mode 100644 index 00000000..bab75e95 --- /dev/null +++ b/src/spartan/macros.rs @@ -0,0 +1,103 @@ +/// Macros to give syntactic sugar for zipWith pattern and variants. +/// +/// ```ignore +/// use crate::spartan::zip_with; +/// use itertools::Itertools as _; // we use zip_eq to zip! +/// let v = vec![0, 1, 2]; +/// let w = vec![2, 3, 4]; +/// let y = vec![4, 5, 6]; +/// +/// // Using the `zip_with!` macro to zip three iterators together and apply a closure +/// // that sums the elements of each iterator. +/// let res = zip_with!((v.iter(), w.iter(), y.iter()), |a, b, c| a + b + c) +/// .collect::>(); +/// +/// println!("{:?}", res); // Output: [6, 9, 12] +/// ``` + +#[macro_export] +macro_rules! zip_with { + // no iterator projection specified: the macro assumes the arguments *are* iterators + // ```ignore + // zip_with!((iter1, iter2, iter3), |a, b, c| a + b + c) -> + // iter1.zip_eq(iter2.zip_eq(iter3)).map(|(a, (b, c))| a + b + c) + // ``` + // + // iterator projection specified: use it on each argument + // ```ignore + // zip_with!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) -> + // vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).map(|(a, (b, c))| a + b + c) + // ```` + ($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{ + $crate::zip_with!($($f,)? ($e $(, $rest)*), map, $($move)? |$($i),+| $($work)*) + }}; + // no iterator projection specified: the macro assumes the arguments *are* iterators + // optional zipping function specified as well: use it instead of map + // ```ignore + // zip_with!((iter1, iter2, iter3), for_each, |a, b, c| a + b + c) -> + // iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c) + // ``` + // + // + // iterator projection specified: use it on each argument + // optional zipping function specified as well: use it instead of map + // ```ignore + // zip_with!(par_iter, (vec1, vec2, vec3), for_each, |a, b, c| a + b + c) -> + // vec1.part_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c) + // ``` + ($($f:ident,)? ($e:expr $(, $rest:expr)*), $worker:ident, $($move:ident,)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{ + $crate::zip_all!($($f,)? ($e $(, $rest)*)) + .$worker($($move)? |$crate::nested_idents!($($i),+)| { + $($work)* + }) + }}; +} + +/// Like `zip_with` but use `for_each` instead of `map`. +#[macro_export] +macro_rules! zip_with_for_each { + // no iterator projection specified: the macro assumes the arguments *are* iterators + // ```ignore + // zip_with_for_each!((iter1, iter2, iter3), |a, b, c| a + b + c) -> + // iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c) + // ``` + // + // iterator projection specified: use it on each argument + // ```ignore + // zip_with_for_each!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) -> + // vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c) + // ```` + ($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{ + $crate::zip_with!($($f,)? ($e $(, $rest)*), for_each, $($move)? |$($i),+| $($work)*) + }}; +} + +// Foldright-like nesting for idents (a, b, c) -> (a, (b, c)) +#[doc(hidden)] +#[macro_export] +macro_rules! nested_idents { + ($a:ident, $b:ident) => { + ($a, $b) + }; + ($first:ident, $($rest:ident),+) => { + ($first, $crate::nested_idents!($($rest),+)) + }; +} + +// Fold-right like zipping, with an optional function `f` to apply to each argument +#[doc(hidden)] +#[macro_export] +macro_rules! zip_all { + (($e:expr,)) => { + $e + }; + ($f:ident, ($e:expr,)) => { + $e.$f() + }; + ($f:ident, ($first:expr, $second:expr $(, $rest:expr)*)) => { + ($first.$f().zip_eq($crate::zip_all!($f, ($second, $( $rest),*)))) + }; + (($first:expr, $second:expr $(, $rest:expr)*)) => { + ($first.zip_eq($crate::zip_all!(($second, $( $rest),*)))) + }; +} diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index 3e7c514e..18479d19 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -5,6 +5,8 @@ //! We also provide direct.rs that allows proving a step circuit directly with either of the two SNARKs. //! //! In polynomial.rs we also provide foundational types and functions for manipulating multilinear polynomials. +#[macro_use] +mod macros; pub mod direct; pub(crate) mod math; pub mod polys; @@ -14,6 +16,7 @@ mod sumcheck; use crate::{traits::Engine, Commitment}; use ff::Field; +use itertools::Itertools as _; use polys::multilinear::SparsePolynomial; use rayon::{iter::IntoParallelRefIterator, prelude::*}; @@ -64,20 +67,17 @@ impl PolyEvalWitness { let powers_of_s = powers::(s, p_vec.len()); - let p = p_vec - .par_iter() - .zip(powers_of_s.par_iter()) - .map(|(v, &weight)| { - // compute the weighted sum for each vector - v.iter().map(|&x| x * weight).collect::>() - }) - .reduce( - || vec![E::Scalar::ZERO; p_vec[0].len()], - |acc, v| { - // perform vector addition to combine the weighted vectors - acc.into_iter().zip(v).map(|(x, y)| x + y).collect() - }, - ); + let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| { + // compute the weighted sum for each vector + v.iter().map(|&x| x * weight).collect::>() + }) + .reduce( + || vec![E::Scalar::ZERO; p_vec[0].len()], + |acc, v| { + // perform vector addition to combine the weighted vectors + zip_with!((acc.into_iter(), v), |x, y| x + y).collect() + }, + ); PolyEvalWitness { p } } @@ -113,15 +113,8 @@ impl PolyEvalInstance { s: &E::Scalar, ) -> PolyEvalInstance { let powers_of_s = powers::(s, c_vec.len()); - let e = e_vec - .par_iter() - .zip(powers_of_s.par_iter()) - .map(|(e, p)| *e * p) - .sum(); - let c = c_vec - .par_iter() - .zip(powers_of_s.par_iter()) - .map(|(c, p)| *c * *p) + let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum(); + let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p) .reduce(Commitment::::default, |acc, item| acc + item); PolyEvalInstance { diff --git a/src/spartan/polys/eq.rs b/src/spartan/polys/eq.rs index 22ef1a13..5e2dcd9e 100644 --- a/src/spartan/polys/eq.rs +++ b/src/spartan/polys/eq.rs @@ -52,13 +52,10 @@ impl EqPolynomial { let (evals_left, evals_right) = evals.split_at_mut(size); let (evals_right, _) = evals_right.split_at_mut(size); - evals_left - .par_iter_mut() - .zip(evals_right.par_iter_mut()) - .for_each(|(x, y)| { - *y = *x * r; - *x -= &*y; - }); + zip_with_for_each!(par_iter_mut, (evals_left, evals_right), |x, y| { + *y = *x * r; + *x -= &*y; + }); size *= 2; } diff --git a/src/spartan/polys/multilinear.rs b/src/spartan/polys/multilinear.rs index 385d8a34..6610fc2f 100644 --- a/src/spartan/polys/multilinear.rs +++ b/src/spartan/polys/multilinear.rs @@ -5,6 +5,7 @@ use std::ops::{Add, Index}; use ff::PrimeField; +use itertools::Itertools as _; use rayon::prelude::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, @@ -65,12 +66,9 @@ impl MultilinearPolynomial { let (left, right) = self.Z.split_at_mut(n); - left - .par_iter_mut() - .zip(right.par_iter()) - .for_each(|(a, b)| { - *a += *r * (*b - *a); - }); + zip_with_for_each!((left.par_iter_mut(), right.par_iter()), |a, b| { + *a += *r * (*b - *a); + }); self.Z.resize(n, Scalar::ZERO); self.num_vars -= 1; @@ -94,12 +92,12 @@ impl MultilinearPolynomial { /// Evaluates the polynomial with the given evaluations and point. pub fn evaluate_with(Z: &[Scalar], r: &[Scalar]) -> Scalar { - EqPolynomial::new(r.to_vec()) - .evals() - .into_par_iter() - .zip(Z.into_par_iter()) - .map(|(a, b)| a * b) - .sum() + zip_with!( + into_par_iter, + (EqPolynomial::new(r.to_vec()).evals(), Z), + |a, b| a * b + ) + .sum() } } @@ -167,12 +165,7 @@ impl Add for MultilinearPolynomial { return Err("The two polynomials must have the same number of variables"); } - let sum: Vec = self - .Z - .iter() - .zip(other.Z.iter()) - .map(|(a, b)| *a + *b) - .collect(); + let sum: Vec = zip_with!(iter, (self.Z, other.Z), |a, b| *a + *b).collect(); Ok(MultilinearPolynomial::new(sum)) } diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index cb023fd4..1fbe22cf 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -27,10 +27,11 @@ use crate::{ snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait}, Engine, TranscriptEngineTrait, TranscriptReprTrait, }, - Commitment, CommitmentKey, CompressedCommitment, + zip_with, Commitment, CommitmentKey, CompressedCommitment, }; use core::cmp::max; use ff::Field; +use itertools::Itertools as _; use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -339,13 +340,7 @@ impl MemorySumcheckInstance { let inv = batch_invert(&T.par_iter().map(|e| *e + *r).collect::>())?; // compute inv[i] * TS[i] in parallel - Ok( - inv - .par_iter() - .zip(TS.par_iter()) - .map(|(e1, e2)| *e1 * *e2) - .collect::>(), - ) + Ok(zip_with!(par_iter, (inv, TS), |e1, e2| *e1 * e2).collect::>()) }, || batch_invert(&W.par_iter().map(|e| *e + *r).collect::>()), ) @@ -853,11 +848,7 @@ impl> RelaxedR1CSSNARK { let coeffs = powers::(&s, claims.len()); // compute the joint claim - let claim = claims - .iter() - .zip(coeffs.iter()) - .map(|(c_1, c_2)| *c_1 * c_2) - .sum(); + let claim = zip_with!(iter, (claims, coeffs), |c_1, c_2| *c_1 * c_2).sum(); let mut e = claim; let mut r: Vec = Vec::new(); @@ -1086,14 +1077,12 @@ impl> RelaxedR1CSSNARKTrait for Relax ); // a sum-check instance to prove the second claim - let val = pk - .S_repr - .val_A - .par_iter() - .zip(pk.S_repr.val_B.par_iter()) - .zip(pk.S_repr.val_C.par_iter()) - .map(|((v_a, v_b), v_c)| *v_a + c * *v_b + c * c * *v_c) - .collect::>(); + let val = zip_with!( + par_iter, + (pk.S_repr.val_A, pk.S_repr.val_B, pk.S_repr.val_C), + |v_a, v_b, v_c| *v_a + c * *v_b + c * c * *v_c + ) + .collect::>(); let inner_sc_inst = InnerSumcheckInstance { claim: eval_Az_at_tau + c * eval_Bz_at_tau + c * c * eval_Cz_at_tau, poly_L_row: MultilinearPolynomial::new(L_row.clone()), diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index 33d739e6..5e0e54d2 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -22,6 +22,7 @@ use crate::{ Commitment, CommitmentKey, }; use ff::Field; +use itertools::Itertools as _; use once_cell::sync::OnceCell; use rayon::prelude::*; @@ -469,11 +470,7 @@ fn batch_eval_prove( let rho = transcript.squeeze(b"r")?; let num_claims = w_vec_padded.len(); let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = u_vec_padded - .iter() - .zip(powers_of_rho.iter()) - .map(|(u, p)| u.e * p) - .sum(); + let claim_batch_joint = zip_with!(iter, (u_vec_padded, powers_of_rho), |u, p| u.e * p).sum(); let mut polys_left: Vec> = w_vec_padded .iter() @@ -504,17 +501,12 @@ fn batch_eval_prove( // we now combine evaluation claims at the same point rz into one let gamma = transcript.squeeze(b"g")?; let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = u_vec_padded - .iter() - .zip(powers_of_gamma.iter()) - .map(|(u, g_i)| u.c * *g_i) + let comm_joint = zip_with!(iter, (u_vec_padded, powers_of_gamma), |u, g_i| u.c * *g_i) .fold(Commitment::::default(), |acc, item| acc + item); let poly_joint = PolyEvalWitness::weighted_sum(&w_vec_padded, &powers_of_gamma); - let eval_joint = claims_batch_left - .iter() - .zip(powers_of_gamma.iter()) - .map(|(e, g_i)| *e * *g_i) - .sum(); + let eval_joint = zip_with!(iter, (claims_batch_left, powers_of_gamma), |e, g_i| *e + * *g_i) + .sum(); Ok(( PolyEvalInstance:: { @@ -544,11 +536,7 @@ fn batch_eval_verify( let rho = transcript.squeeze(b"r")?; let num_claims: usize = u_vec_padded.len(); let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = u_vec_padded - .iter() - .zip(powers_of_rho.iter()) - .map(|(u, p)| u.e * p) - .sum(); + let claim_batch_joint = zip_with!(iter, (u_vec_padded, powers_of_rho), |u, p| u.e * p).sum(); let num_rounds_z = u_vec_padded[0].x.len(); @@ -562,12 +550,12 @@ fn batch_eval_verify( .map(|u| poly_rz.evaluate(&u.x)) .collect::>(); - evals - .iter() - .zip(evals_batch.iter()) - .zip(powers_of_rho.iter()) - .map(|((e_i, p_i), rho_i)| *e_i * *p_i * rho_i) - .sum() + zip_with!( + iter, + (evals, evals_batch, powers_of_rho), + |e_i, p_i, rho_i| *e_i * *p_i * rho_i + ) + .sum() }; if claim_batch_final != claim_batch_final_expected { @@ -579,16 +567,9 @@ fn batch_eval_verify( // we now combine evaluation claims at the same point rz into one let gamma = transcript.squeeze(b"g")?; let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = u_vec_padded - .iter() - .zip(powers_of_gamma.iter()) - .map(|(u, g_i)| u.c * *g_i) + let comm_joint = zip_with!(iter, (u_vec_padded, powers_of_gamma), |u, g_i| u.c * *g_i) .fold(Commitment::::default(), |acc, item| acc + item); - let eval_joint = evals_batch - .iter() - .zip(powers_of_gamma.iter()) - .map(|(e, g_i)| *e * *g_i) - .sum(); + let eval_joint = zip_with!(iter, (evals_batch, powers_of_gamma), |e, g_i| *e * *g_i).sum(); Ok(PolyEvalInstance:: { c: comm_joint, diff --git a/src/spartan/sumcheck.rs b/src/spartan/sumcheck.rs index fdf168be..f9a6ecb2 100644 --- a/src/spartan/sumcheck.rs +++ b/src/spartan/sumcheck.rs @@ -156,10 +156,10 @@ impl SumcheckProof { let mut quad_polys: Vec> = Vec::new(); for _ in 0..num_rounds { - let evals: Vec<(E::Scalar, E::Scalar)> = poly_A_vec - .par_iter() - .zip(poly_B_vec.par_iter()) - .map(|(poly_A, poly_B)| Self::compute_eval_points_quad(poly_A, poly_B, &comb_func)) + let evals: Vec<(E::Scalar, E::Scalar)> = + zip_with!(par_iter, (poly_A_vec, poly_B_vec), |poly_A, poly_B| { + Self::compute_eval_points_quad(poly_A, poly_B, &comb_func) + }) .collect(); let evals_combined_0 = (0..evals.len()).map(|i| evals[i].0 * coeffs[i]).sum(); @@ -176,15 +176,12 @@ impl SumcheckProof { r.push(r_i); // bound all tables to the verifier's challenge - poly_A_vec - .par_iter_mut() - .zip(poly_B_vec.par_iter_mut()) - .for_each(|(poly_A, poly_B)| { - let _ = rayon::join( - || poly_A.bind_poly_var_top(&r_i), - || poly_B.bind_poly_var_top(&r_i), - ); - }); + zip_with_for_each!(par_iter_mut, (poly_A_vec, poly_B_vec), |poly_A, poly_B| { + let _ = rayon::join( + || poly_A.bind_poly_var_top(&r_i), + || poly_B.bind_poly_var_top(&r_i), + ); + }); e = poly.evaluate(&r_i); quad_polys.push(poly.compress());