Skip to content

Commit

Permalink
Fixed multicore for z_bind
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Dec 30, 2024
1 parent c7508d3 commit 3209b9f
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 73 deletions.
116 changes: 73 additions & 43 deletions spartan_parallel/src/custom_dense_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
use std::cmp::min;

use crate::dense_mlpoly::DensePolynomial;
use crate::math::Math;
use crate::scalar::SpartanExtensionField;
use rayon::prelude::*;

// Selector tags, presumably one per index dimension of the (p, q, w, x) quadruple.
// NOTE(review): the code that consumes these modes is outside this view; confirm their meaning there.
const MODE_P: usize = 1;
const MODE_Q: usize = 2;
const MODE_W: usize = 3;
const MODE_X: usize = 4;
// Cap on how many threads the parallel q-binding spreads the proofs of one instance over:
// each thread receives `num_proofs / min(num_proofs, NUM_MULTI_THREAD_CORES)` proofs.
const NUM_MULTI_THREAD_CORES: usize = 32;

// Customized Dense ML Polynomials for Data-Parallelism
// These dense ML polynomials aim for space efficiency by omitting the zeros of invalid (p, q, w, x) quadruples
Expand All @@ -30,6 +30,40 @@ pub struct DensePolynomialPqx<S: SpartanExtensionField> {
// The same applies to X
}

/// Folds a strided sequence of proof matrices in place, one round per challenge in `r_q`.
///
/// `proofs[i]` is a single proof matrix indexed as `proofs[i][w][x]`. For every challenge `r`,
/// taken in slice order, the number of live entries `q` is halved (rounding up) and, for each
/// `i` in `0..q`, the entry at `i * step` is overwritten with
///
/// `(1 - r) * proofs[2 * i * step] + r * proofs[(2 * i + 1) * step]`
///
/// over the first `w` rows and `x` columns. When `q == 2^r_q.len()` the fully folded matrix
/// ends up in `proofs[0]`.
///
/// * `proofs` - proof matrices to fold; modified in place.
/// * `r_q`    - challenges, applied first to last.
/// * `step`   - distance in `proofs` between two consecutive live entries.
/// * `q`      - number of live entries before the first round.
/// * `w`, `x` - number of rows / columns of each matrix that take part in the fold.
///
/// # Panics
///
/// Panics if any accessed index is out of bounds, i.e. if `proofs` holds fewer than
/// `(2 * i + 1) * step + 1` matrices for some pair `i` of a round, or if a matrix is smaller
/// than `w` x `x`. This includes a round that starts with an odd number of live entries whose
/// missing partner is not present in `proofs`.
fn fold_rq<S: SpartanExtensionField>(
    proofs: &mut [Vec<Vec<S>>],
    r_q: &[S],
    step: usize,
    mut q: usize,
    w: usize,
    x: usize,
) {
    for &r in r_q {
        let one_minus_r = S::field_one() - r;

        q = q.div_ceil(2);
        for pair in 0..q {
            let dst = pair * step;
            let lo = 2 * pair * step;
            let hi = (2 * pair + 1) * step;
            // Writing `dst` is safe while walking pairs upwards: `dst <= lo < hi`, so no
            // entry is overwritten before the pair that still needs it has been read.
            for wi in 0..w {
                for xi in 0..x {
                    proofs[dst][wi][xi] = one_minus_r * proofs[lo][wi][xi] + r * proofs[hi][wi][xi];
                }
            }
        }
    }
}

impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
// Assume z_mat is of form (p, q_rev, x_rev), construct DensePoly
pub fn new(
Expand Down Expand Up @@ -207,7 +241,7 @@ impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
}

// Bound the entire "q" section to r_q in reverse
pub fn bound_poly_vars_rq(
pub fn bound_poly_vars_rq_parallel(
&mut self,
r_q: &[S],
) {
Expand All @@ -218,50 +252,47 @@ impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
.enumerate()
.map(|(p, mut inst)| {
let num_proofs = self.num_proofs[p];
let dist_size = num_proofs / min(num_proofs, NUM_MULTI_THREAD_CORES); // distributed number of proofs on each thread
let dist_size = num_proofs / min(num_proofs, rayon::current_num_threads().next_power_of_two()); // distributed number of proofs on each thread
let num_threads = num_proofs / dist_size;

// To perform rigorous parallelism, both num_proofs and # threads must be powers of 2
// # threads must fully divide num_proofs for even distribution
assert!(num_proofs & (num_proofs - 1) == 0);
assert!(num_threads & (num_threads - 1) == 0);
assert_eq!(num_proofs, num_proofs.next_power_of_two());
assert_eq!(num_threads, num_threads.next_power_of_two());

// Determine parallelism levels
let levels = num_proofs.trailing_zeros() as usize; // total layers
let sub_levels = dist_size.trailing_zeros() as usize; // parallelism layers
let final_levels = num_threads.trailing_zeros() as usize; // single core final layers
let levels = num_proofs.log_2(); // total layers
let sub_levels = dist_size.log_2(); // parallel layers
let final_levels = num_threads.log_2(); // single core final layers
let left_over_q_len = r_q.len() - levels; // if r_q.len() > log2(num_proofs)

// single proof matrix dimension W x X
let num_witness_secs = min(self.num_witness_secs, inst[0].len());
let num_inputs = self.num_inputs[p];


// Divide rq into sub, final, and left_over
let sub_rq = &r_q[0..sub_levels];
let final_rq = &r_q[sub_levels..levels];
let left_over_rq = &r_q[(r_q.len() - left_over_q_len)..r_q.len()];

if sub_levels > 0 {
let thread_split_inst = (0..num_threads)
.map(|_| {
inst.split_off(inst.len() - dist_size)
inst = inst
.par_chunks_mut(dist_size)
.map(|chunk| {
fold_rq(chunk, sub_rq, 1, dist_size, num_witness_secs, num_inputs);
chunk.to_vec()
})
.rev()
.collect::<Vec<Vec<Vec<Vec<S>>>>>();

inst = thread_split_inst
.into_par_iter()
.map(|mut chunk| {
fold(&mut chunk, r_q, 0, 1, sub_levels, 0, num_witness_secs, num_inputs);
chunk
})
.collect::<Vec<Vec<Vec<Vec<S>>>>>()
.into_iter().flatten().collect()
.flatten().collect()
}

if final_levels > 0 {
// aggregate the final result from sub-threads outputs using a single core
fold(&mut inst, r_q, 0, dist_size, final_levels + sub_levels, sub_levels, num_witness_secs, num_inputs);
fold_rq(&mut inst, final_rq, dist_size, num_threads, num_witness_secs, num_inputs);
}

if left_over_q_len > 0 {
// the series of random challenges exceeds the total number of variables
let c = r_q[(r_q.len() - left_over_q_len)..r_q.len()].iter().fold(S::field_one(), |acc, n| acc * (S::field_one() - *n));
let c = left_over_rq.into_iter().fold(S::field_one(), |acc, n| acc * (S::field_one() - *n));
for w in 0..inst[0].len() {
for x in 0..inst[0][0].len() {
inst[0][w][x] *= c;
Expand All @@ -275,6 +306,23 @@ impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
self.max_num_proofs /= 2usize.pow(r_q.len() as u32);
}

/// Binds the entire "q" section to `r_q` in reverse, one challenge at a time.
///
/// Each element of `r_q` is handed to `bound_poly_q` in slice order, so the polynomial is
/// bound once per challenge; an empty `r_q` leaves the polynomial untouched.
///
/// * `r_q` - challenges to bind the "q" variables to.
pub fn bound_poly_vars_rq(&mut self, r_q: &[S]) {
    for r in r_q {
        self.bound_poly_q(r);
    }
}

// Bound the entire "w" section to r_w in reverse
pub fn bound_poly_vars_rw(&mut self,
r_w: &[S],
Expand Down Expand Up @@ -327,22 +375,4 @@ impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
}
DensePolynomial::new(Z_poly)
}
}

/// Recursively folds proof matrices in place, from level `final_lvl` up to level `lvl`.
///
/// A call at level `lvl > final_lvl` first folds its two children, rooted at `2 * idx` and
/// `2 * idx + step`, down from level `lvl - 1`, and then combines them into `proofs[idx]` as
/// `(1 - r) * left + r * right` with `r = r_q[lvl - 1]`, over the first `w` rows and `x`
/// columns. A call at `lvl <= final_lvl` is the base case and does nothing.
fn fold<S: SpartanExtensionField>(proofs: &mut Vec<Vec<Vec<S>>>, r_q: &[S], idx: usize, step: usize, lvl: usize, final_lvl: usize, w: usize, x: usize) {
    // Base level: nothing left to combine.
    if lvl <= final_lvl {
        return;
    }

    let left = 2 * idx;
    let right = 2 * idx + step;

    // Children must be fully folded before they are merged into `idx`.
    fold(proofs, r_q, left, step, lvl - 1, final_lvl, w, x);
    fold(proofs, r_q, right, step, lvl - 1, final_lvl, w, x);

    let challenge = r_q[lvl - 1];
    let complement = S::field_one() - challenge;

    for row in 0..w {
        for col in 0..x {
            proofs[idx][row][col] = complement * proofs[left][row][col] + challenge * proofs[right][row][col];
        }
    }
}
51 changes: 22 additions & 29 deletions spartan_parallel/src/r1csproof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,34 +184,25 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
// append input to variables to create a single vector z
let timer_tmp = Timer::new("prove_z_mat_gen");

let z_mat = (0..num_instances)
.into_par_iter()
.map(|p| {
(0..num_proofs[p])
.into_par_iter()
.map(|q| {
(0..witness_secs.len())
.map(|w| {
let ws = witness_secs[w];
let p_w = if ws.w_mat.len() == 1 { 0 } else { p };
let q_w = if ws.w_mat[p_w].len() == 1 { 0 } else { q };

let r_w = if ws.num_inputs[p_w] < num_inputs[p] {
let padding = std::iter::repeat(S::field_zero()).take(num_inputs[p] - ws.num_inputs[p_w]).collect::<Vec<S>>();
let mut r = ws.w_mat[p_w][q_w].clone();
r.extend(padding);
r
} else {
ws.w_mat[p_w][q_w].iter().take(num_inputs[p]).cloned().collect::<Vec<S>>()
};

r_w
})
.collect::<Vec<Vec<S>>>()
})
.collect::<Vec<Vec<Vec<S>>>>()
})
.collect::<Vec<Vec<Vec<Vec<S>>>>>();
let z_mat = (0..num_instances).map(|p| {
(0..num_proofs[p]).into_par_iter().map(|q| {
(0..witness_secs.len()).map(|w| {
let ws = witness_secs[w];
let p_w = if ws.w_mat.len() == 1 { 0 } else { p };
let q_w = if ws.w_mat[p_w].len() == 1 { 0 } else { q };

let r_w = if ws.num_inputs[p_w] < num_inputs[p] {
let padding = std::iter::repeat(S::field_zero()).take(num_inputs[p] - ws.num_inputs[p_w]).collect::<Vec<S>>();
let mut r = ws.w_mat[p_w][q_w].clone();
r.extend(padding);
r
} else {
ws.w_mat[p_w][q_w].iter().take(num_inputs[p]).cloned().collect::<Vec<S>>()
};
r_w
}).collect::<Vec<Vec<S>>>()
}).collect::<Vec<Vec<Vec<S>>>>()
}).collect::<Vec<Vec<Vec<Vec<S>>>>>();
timer_tmp.stop();

// derive the verifier's challenge \tau
Expand Down Expand Up @@ -346,13 +337,14 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
);
timer_tmp.stop();
let timer_tmp = Timer::new("prove_z_bind");
Z_poly.bound_poly_vars_rq(&rq_rev);
Z_poly.bound_poly_vars_rq_parallel(&rq_rev);
timer_tmp.stop();

// An Eq function to match p with rp
let mut eq_p_rp_poly = DensePolynomial::new(EqPolynomial::new(rp).evals());

// Sumcheck 2: (rA + rB + rC) * Z * eq(p) = e
let timer_tmp = Timer::new("prove_sum_check");
let (sc_proof_phase2, ry_rev, _claims_phase2) = R1CSProof::prove_phase_two(
num_rounds_y + num_rounds_w + num_rounds_p,
num_rounds_y,
Expand All @@ -367,6 +359,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
&mut Z_poly,
transcript,
);
timer_tmp.stop();
timer_sc_proof_phase2.stop();

// Separate ry into rp, rw, and ry
Expand Down
2 changes: 1 addition & 1 deletion zok_tests/benchmarks/poseidon_test/poseidon_const.zok
Original file line number Diff line number Diff line change
@@ -1 +1 @@
const u32 REPETITION = 1000
const u32 REPETITION = 10000

0 comments on commit 3209b9f

Please sign in to comment.