Skip to content

Commit

Permalink
chore: more Maintenance (microsoft#319)
Browse files Browse the repository at this point in the history
* refactor: Refactor public param digest tests in lib.rs

- Refactored test functions within the tests module for greater efficiency and readability.
- Removed `DlogGroup` trait usage within the `test_pp_digest_with` function.
- Simplified `test_pp_digest` function by removing redundant circuit instantiation.

* chore: Upgrade Rust Toolchain Version to 1.76.0

- Upgraded the rust toolchain version from `1.75` to `1.76.0` in `rust-toolchain.toml` file.
- Closes microsoft#307

* chore: Refactor pprof dependency

- Shifted `pprof` from `dev-dependencies` to `dependencies` for non-wasm32 targets to optimize benchmarking builds.
- Introduced "flamegraph" feature to include `pprof` during benchmarking.
- Closes microsoft#309

* refactor: Refactor computation in sumcheck module

- Streamlined calculation of evaluation points in `compute_eval_points_quad` and `compute_eval_points_cubic_with_additive_term` functions within `src/spartan/sumcheck/mod.rs`.
  • Loading branch information
huitseeker authored Feb 12, 2024
1 parent 26fd303 commit 9064bfe
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ getrandom = { version = "0.2.0", default-features = false, features = ["js"] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
proptest = "1.2.0"
pprof = { version = "0.13", optional = true } # in benches under feature "flamegraph"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
pprof = { version = "0.13" }
criterion = { version = "0.5", features = ["html_reports"] }

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[toolchain]
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
profile = "default"
channel = "1.75"
channel = "1.76.0"
targets = [ "wasm32-unknown-unknown" ]

58 changes: 17 additions & 41 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1012,11 +1012,13 @@ type CE<E> = <E as Engine>::CE;

#[cfg(test)]
mod tests {
use self::traits::CurveCycleEquipped;

use super::*;
use crate::{
provider::{
non_hiding_zeromorph::ZMPCS, traits::DlogGroup, Bn256Engine, Bn256EngineKZG, Bn256EngineZM,
GrumpkinEngine, PallasEngine, Secp256k1Engine, Secq256k1Engine, VestaEngine,
non_hiding_zeromorph::ZMPCS, Bn256Engine, Bn256EngineKZG, Bn256EngineZM, PallasEngine,
Secp256k1Engine,
},
traits::{evaluation::EvaluationEngineTrait, snark::default_ck_hint},
};
Expand Down Expand Up @@ -1082,8 +1084,6 @@ mod tests {
fn test_pp_digest_with<E1, T1, T2, EE1, EE2>(circuit1: &T1, circuit2: &T2, expected: &Expect)
where
E1: CurveCycleEquipped,
E1::GE: DlogGroup,
<Dual<E1> as Engine>::GE: DlogGroup,
T1: StepCircuit<E1::Scalar>,
T2: StepCircuit<<Dual<E1> as Engine>::Scalar>,
EE1: EvaluationEngineTrait<E1>,
Expand Down Expand Up @@ -1112,62 +1112,38 @@ mod tests {

#[test]
fn test_pp_digest() {
let trivial_circuit1 = TrivialCircuit::<<PallasEngine as Engine>::Scalar>::default();
let trivial_circuit2 = TrivialCircuit::<<VestaEngine as Engine>::Scalar>::default();
let cubic_circuit1 = CubicCircuit::<<PallasEngine as Engine>::Scalar>::default();

test_pp_digest_with::<PallasEngine, _, _, EE<_>, EE<_>>(
&trivial_circuit1,
&trivial_circuit2,
&TrivialCircuit::default(),
&TrivialCircuit::default(),
&expect!["492fd902cd7174159bc9a6f827d92eb54ff25efa9d0673dffdb0efd02995df01"],
);

test_pp_digest_with::<PallasEngine, _, _, EE<_>, EE<_>>(
&cubic_circuit1,
&trivial_circuit2,
&CubicCircuit::default(),
&TrivialCircuit::default(),
&expect!["9b0701d9422658e3f74a85ab3e485c06f3ecca9c2b1800aab80004034d754f01"],
);

let trivial_circuit1_grumpkin = TrivialCircuit::<<Bn256Engine as Engine>::Scalar>::default();
let trivial_circuit2_grumpkin = TrivialCircuit::<<GrumpkinEngine as Engine>::Scalar>::default();
let cubic_circuit1_grumpkin = CubicCircuit::<<Bn256Engine as Engine>::Scalar>::default();

// These tests should not need be different on the "asm" feature for bn256.
// See https://github.com/privacy-scaling-explorations/halo2curves/issues/100 for why they are - closing the issue there
// should eliminate the discrepancy here.
test_pp_digest_with::<Bn256Engine, _, _, EE<_>, EE<_>>(
&trivial_circuit1_grumpkin,
&trivial_circuit2_grumpkin,
&TrivialCircuit::default(),
&TrivialCircuit::default(),
&expect!["1267235eb3d139e466dd9c814eaf73f01b063ccb4cad04848c0eb62f079a9601"],
);

test_pp_digest_with::<Bn256Engine, _, _, EE<_>, EE<_>>(
&cubic_circuit1_grumpkin,
&trivial_circuit2_grumpkin,
&CubicCircuit::default(),
&TrivialCircuit::default(),
&expect!["57afac2edd20d39b202151906e41154ba186c9dde497448d1332dc6de2f76302"],
);
test_pp_digest_with::<Bn256EngineZM, _, _, ZMPCS<Bn256, _>, EE<_>>(
&trivial_circuit1_grumpkin,
&trivial_circuit2_grumpkin,
&expect!["070d247d83e17411d65c12260980ebcc59df88d3882d84eb62e6ab466e381503"],
);
test_pp_digest_with::<Bn256EngineZM, _, _, ZMPCS<Bn256, _>, EE<_>>(
&cubic_circuit1_grumpkin,
&trivial_circuit2_grumpkin,
&expect!["47c2caa008323b588b47ab8b6c0e94f980599188abe117c4d21ffff81494f303"],
);

let trivial_circuit1_secp = TrivialCircuit::<<Secp256k1Engine as Engine>::Scalar>::default();
let trivial_circuit2_secp = TrivialCircuit::<<Secq256k1Engine as Engine>::Scalar>::default();
let cubic_circuit1_secp = CubicCircuit::<<Secp256k1Engine as Engine>::Scalar>::default();

test_pp_digest_with::<Secp256k1Engine, _, _, EE<_>, EE<_>>(
&trivial_circuit1_secp,
&trivial_circuit2_secp,
&TrivialCircuit::default(),
&TrivialCircuit::default(),
&expect!["04b5d1798be6d74b3701390b87078e70ebf3ddaad80c375319f320cedf8bca00"],
);
test_pp_digest_with::<Secp256k1Engine, _, _, EE<_>, EE<_>>(
&cubic_circuit1_secp,
&trivial_circuit2_secp,
&CubicCircuit::default(),
&TrivialCircuit::default(),
&expect!["346b5f27cf24c79386f4de7a8bfb58970181ae7f0de7d2e3f10ad5dfd8fc2302"],
);
}
Expand Down
37 changes: 23 additions & 14 deletions src/spartan/sumcheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,20 +316,24 @@ impl<E: Engine> SumcheckProof<E> {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i]);

let poly_A_right_term = poly_A[len + i] - poly_A[i];
let poly_B_right_term = poly_B[len + i] - poly_B[i];
let poly_C_right_term = poly_C[len + i] - poly_C[i];

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
let poly_A_bound_point = poly_A[len + i] + poly_A_right_term;
let poly_B_bound_point = poly_B[len + i] + poly_B_right_term;
let poly_C_bound_point = poly_C[len + i] + poly_C_right_term;
let eval_point_2 = comb_func(
&poly_A_bound_point,
&poly_B_bound_point,
&poly_C_bound_point,
);

// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
let poly_A_bound_point = poly_A_bound_point + poly_A_right_term;
let poly_B_bound_point = poly_B_bound_point + poly_B_right_term;
let poly_C_bound_point = poly_C_bound_point + poly_C_right_term;
let eval_point_3 = comb_func(
&poly_A_bound_point,
&poly_B_bound_point,
Expand Down Expand Up @@ -361,11 +365,16 @@ impl<E: Engine> SumcheckProof<E> {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]);

let poly_A_right_term = poly_A[len + i] - poly_A[i];
let poly_B_right_term = poly_B[len + i] - poly_B[i];
let poly_C_right_term = poly_C[len + i] - poly_C[i];
let poly_D_right_term = poly_D[len + i] - poly_D[i];

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
let poly_D_bound_point = poly_D[len + i] + poly_D[len + i] - poly_D[i];
let poly_A_bound_point = poly_A[len + i] + poly_A_right_term;
let poly_B_bound_point = poly_B[len + i] + poly_B_right_term;
let poly_C_bound_point = poly_C[len + i] + poly_C_right_term;
let poly_D_bound_point = poly_D[len + i] + poly_D_right_term;
let eval_point_2 = comb_func(
&poly_A_bound_point,
&poly_B_bound_point,
Expand All @@ -374,10 +383,10 @@ impl<E: Engine> SumcheckProof<E> {
);

// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
let poly_D_bound_point = poly_D_bound_point + poly_D[len + i] - poly_D[i];
let poly_A_bound_point = poly_A_bound_point + poly_A_right_term;
let poly_B_bound_point = poly_B_bound_point + poly_B_right_term;
let poly_C_bound_point = poly_C_bound_point + poly_C_right_term;
let poly_D_bound_point = poly_D_bound_point + poly_D_right_term;
let eval_point_3 = comb_func(
&poly_A_bound_point,
&poly_B_bound_point,
Expand Down

0 comments on commit 9064bfe

Please sign in to comment.