diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs
index d74133cf..64edd643 100644
--- a/src/spartan/mod.rs
+++ b/src/spartan/mod.rs
@@ -51,33 +51,57 @@ impl<E: Engine> PolyEvalWitness<E> {
     let size_max = W.iter().map(|w| w.p.len()).max().unwrap();
 
     // Scale the input polynomials by the power of s
-    let p = W
-      .into_par_iter()
-      .zip_eq(powers.par_iter())
-      .map(|(mut w, s)| {
-        if *s != E::Scalar::ONE {
-          w.p.par_iter_mut().for_each(|e| *e *= s);
-        }
-        w.p
-      })
-      .reduce(
-        || vec![E::Scalar::ZERO; size_max],
-        |left, right| {
-          // Sum into the largest polynomial
-          let (mut big, small) = if left.len() > right.len() {
-            (left, right)
-          } else {
-            (right, left)
-          };
-
-          big
-            .par_iter_mut()
-            .zip(small.par_iter())
-            .for_each(|(b, s)| *b += s);
-
-          big
-        },
-      );
+    let num_chunks = rayon::current_num_threads().next_power_of_two();
+    let chunk_size = size_max / num_chunks;
+
+    let p = if chunk_size > 0 {
+      (0..num_chunks)
+        .into_par_iter()
+        .flat_map_iter(|chunk_index| {
+          let mut chunk = vec![E::Scalar::ZERO; chunk_size];
+          for (coeff, poly) in powers.iter().zip(W.iter()) {
+            for (rlc, poly_eval) in chunk
+              .iter_mut()
+              .zip(poly.p[chunk_index * chunk_size..].iter())
+            {
+              if *coeff == E::Scalar::ONE {
+                *rlc += *poly_eval;
+              } else {
+                *rlc += *coeff * poly_eval;
+              };
+            }
+          }
+          chunk
+        })
+        .collect::<Vec<_>>()
+    } else {
+      W.into_par_iter()
+        .zip_eq(powers.par_iter())
+        .map(|(mut w, s)| {
+          if *s != E::Scalar::ONE {
+            w.p.par_iter_mut().for_each(|e| *e *= s);
+          }
+          w.p
+        })
+        .reduce(
+          || vec![E::Scalar::ZERO; size_max],
+          |left, right| {
+            // Sum into the largest polynomial
+            let (mut big, small) = if left.len() > right.len() {
+              (left, right)
+            } else {
+              (right, left)
+            };
+
+            big
+              .par_iter_mut()
+              .zip(small.par_iter())
+              .for_each(|(b, s)| *b += s);
+
+            big
+          },
+        )
+    };
 
     PolyEvalWitness { p }
   }
@@ -96,17 +120,42 @@ impl<E: Engine> PolyEvalWitness<E> {
 
     let powers_of_s = powers::<E>(s, p_vec.len());
 
-    let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
-      // compute the weighted sum for each vector
-      v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
-    })
-    .reduce(
-      || vec![E::Scalar::ZERO; p_vec[0].len()],
-      |acc, v| {
-        // perform vector addition to combine the weighted vectors
-        zip_with!((acc.into_iter(), v), |x, y| x + y).collect()
-      },
-    );
+    let num_chunks = rayon::current_num_threads().next_power_of_two();
+    let chunk_size = p_vec[0].len() / num_chunks;
+
+    let p = if chunk_size > 0 {
+      (0..num_chunks)
+        .into_par_iter()
+        .flat_map_iter(|chunk_index| {
+          let mut chunk = vec![E::Scalar::ZERO; chunk_size];
+          for (coeff, poly) in powers_of_s.iter().zip(p_vec.iter()) {
+            for (rlc, poly_eval) in chunk
+              .iter_mut()
+              .zip(poly[chunk_index * chunk_size..].iter())
+            {
+              if *coeff == E::Scalar::ONE {
+                *rlc += *poly_eval;
+              } else {
+                *rlc += *coeff * poly_eval;
+              };
+            }
+          }
+          chunk
+        })
+        .collect::<Vec<_>>()
+    } else {
+      zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
+        // compute the weighted sum for each vector
+        v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
+      })
+      .reduce(
+        || vec![E::Scalar::ZERO; p_vec[0].len()],
+        |acc, v| {
+          // perform vector addition to combine the weighted vectors
+          zip_with!((acc.into_iter(), v), |x, y| x + y).collect()
+        },
+      )
+    };
 
     PolyEvalWitness { p }
   }
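
For reference, below is a minimal standalone sketch of the chunking strategy both hunks apply. The old path scaled a temporary copy of every polynomial and then merged the copies pairwise in a `reduce`; the new path splits the output into one chunk per thread, and each thread accumulates the random linear combination `sum_i coeffs[i] * polys[i]` directly into its own slice, with no intermediate allocations and no reduction step. In this sketch, `batch_rlc` is a hypothetical helper name, `u64` with wrapping arithmetic stands in for `E::Scalar`, and (as in the diff) all lengths are assumed to be powers of two so that `num_chunks` divides `size_max` evenly; the `chunk_size == 0` sequential fallback is elided.

```rust
use rayon::prelude::*;

/// Hypothetical helper illustrating the chunked RLC: computes
/// sum_i coeffs[i] * polys[i], where polynomials may differ in length
/// and the result has the length of the longest one.
fn batch_rlc(polys: &[Vec<u64>], coeffs: &[u64]) -> Vec<u64> {
  let size_max = polys.iter().map(|p| p.len()).max().unwrap();
  // One output chunk per thread; a power of two keeps chunk boundaries
  // aligned with the power-of-two polynomial sizes used in Spartan.
  let num_chunks = rayon::current_num_threads().next_power_of_two();
  let chunk_size = size_max / num_chunks;
  assert!(chunk_size > 0, "tiny inputs take the sequential fallback");

  (0..num_chunks)
    .into_par_iter()
    .flat_map_iter(|chunk_index| {
      // Each thread owns exactly one slice of the output, so threads
      // never contend and no reduce step is needed at the end.
      let mut chunk = vec![0u64; chunk_size];
      let start = chunk_index * chunk_size;
      for (coeff, poly) in coeffs.iter().zip(polys.iter()) {
        if start < poly.len() {
          // `zip` stops at the end of a shorter polynomial, so it
          // contributes only to the chunks it actually overlaps.
          for (acc, v) in chunk.iter_mut().zip(poly[start..].iter()) {
            *acc = acc.wrapping_add(coeff.wrapping_mul(*v));
          }
        }
      }
      chunk
    })
    .collect()
}
```

The design point is memory traffic: each output coefficient is written in place once per polynomial, instead of materializing a scaled `Vec` per polynomial and then streaming those vectors through pairwise additions.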