Skip to content

Commit

Permalink
Add lookups for x_rev and q_rev
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Dec 18, 2024
1 parent 93e30d1 commit 5f17d48
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
21 changes: 16 additions & 5 deletions spartan_parallel/src/r1csinstance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ use super::sparse_mlpoly::{
use super::timer::Timer;
use flate2::{write::ZlibEncoder, Compression};
use merlin::Transcript;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::iter::zip;
use std::sync::{Arc, Mutex};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct R1CSInstance<S: SpartanExtensionField> {
Expand Down Expand Up @@ -249,39 +247,52 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSInstance<S> {
Bz.push(Vec::new());
Cz.push(Vec::new());

// Map x and y to x_rev and y_rev so we don't have to do it everytime
let x_step = max_num_cons / num_cons[p];
let x_rev_map = (0..num_cons[p]).map(|x|
rev_bits(x, max_num_cons) / x_step
).collect();
let y_step = max_num_inputs / num_inputs[p];
let y_rev_map = (0..num_inputs[p]).map(|y|
rev_bits(y, max_num_inputs) / y_step
).collect();

Az[p] = (0..num_proofs[p])
.into_par_iter()
.map(|q| {
vec![self.A_list[p_inst].multiply_vec_disjoint_rounds(
max_num_cons,
num_cons[p_inst].clone(),
max_num_inputs,
num_inputs[p],
&z_list[q],
&x_rev_map,
&y_rev_map,
)]
})
.collect();
Bz[p] = (0..num_proofs[p])
.into_par_iter()
.map(|q| {
vec![self.B_list[p_inst].multiply_vec_disjoint_rounds(
max_num_cons,
num_cons[p_inst].clone(),
max_num_inputs,
num_inputs[p],
&z_list[q],
&x_rev_map,
&y_rev_map,
)]
})
.collect();
Cz[p] = (0..num_proofs[p])
.into_par_iter()
.map(|q| {
vec![self.C_list[p_inst].multiply_vec_disjoint_rounds(
max_num_cons,
num_cons[p_inst].clone(),
max_num_inputs,
num_inputs[p],
&z_list[q],
&x_rev_map,
&y_rev_map,
)]
})
.collect();
Expand Down
9 changes: 6 additions & 3 deletions spartan_parallel/src/r1csproof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,17 +184,20 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
for p in 0..num_instances {
z_mat.push(vec![vec![vec![S::field_zero(); num_inputs[p]]; num_witness_secs]; num_proofs[p]]);
let q_step = max_num_proofs / num_proofs[p];

let y_step = max_num_inputs / num_inputs[p];
let y_rev_map: Vec<usize> = (0..num_inputs[p]).map(|y|
rev_bits(y, max_num_inputs) / y_step
).collect();
for q in 0..num_proofs[p] {
let q_rev = rev_bits(q, max_num_proofs) / q_step;
for w in 0..witness_secs.len() {
let ws = witness_secs[w];
let p_w = if ws.w_mat.len() == 1 { 0 } else { p };
let q_w = if ws.w_mat[p_w].len() == 1 { 0 } else { q };
let y_step = max_num_inputs / num_inputs[p];
// Only append the first num_inputs_entries of w_mat[p][q]
for i in 0..min(ws.num_inputs[p_w], num_inputs[p]) {
let y_rev = rev_bits(i, max_num_inputs) / y_step;
z_mat[p][q_rev][w][y_rev] = ws.w_mat[p_w][q_w][i];
z_mat[p][q_rev][w][y_rev_map[i]] = ws.w_mat[p_w][q_w][i];
}
}
}
Expand Down
19 changes: 6 additions & 13 deletions spartan_parallel/src/sparse_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,31 +405,24 @@ impl<S: SpartanExtensionField> SparseMatPolynomial<S> {
// Z[i] contains entries i * max_num_cols ~ i * max_num_cols + num_cols
pub fn multiply_vec_disjoint_rounds(
&self,
max_num_rows: usize,
num_rows: usize,
max_num_cols: usize,
num_cols: usize,
z: &Vec<Vec<S>>
_num_cols: usize,
z: &Vec<Vec<S>>,
x_rev_map: &Vec<usize>,
y_rev_map: &Vec<usize>,
) -> Vec<S> {
let step_r = max_num_rows / num_rows;
(0..self.M.len())
.map(|i| {
let row = self.M[i].row;
let col = self.M[i].col;
let val = self.M[i].val.clone();
let w = col / max_num_cols;
let y = col % max_num_cols;
// Z expresses y in reverse bits order, so have to find the correct y
let y_step = max_num_cols / num_cols;
let y_rev = rev_bits(y, max_num_cols) / y_step;
(row, val * z[w][y_rev])
(row, val * z[w][y_rev_map[y]])
})
.fold(vec![S::field_zero(); num_rows], |mut Mz, (r, v)| {
// Reverse the bits of r. r_rev is a multiple of step_r
let r_rev = rev_bits(r, max_num_rows);
// Now r_rev is between 0 to num_inputs[p]
let r_rev = r_rev / step_r;
Mz[r_rev] += v;
Mz[x_rev_map[r]] += v;
Mz
})
}
Expand Down

0 comments on commit 5f17d48

Please sign in to comment.