From 9064bfeac73b8b5aec6a21a255e044700786eba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Mon, 12 Feb 2024 17:03:59 -0500 Subject: [PATCH] chore: more Maintenance (#319) * refactor: Refactor public param digest tests in lib.rs - Refactored test functions within the tests module for greater efficiency and readability. - Removed `DlogGroup` trait usage within the `test_pp_digest_with` function. - Simplified `test_pp_digest` function by removing redundant circuit instantiation. * chore: Upgrade Rust Toolchain Version to 1.76.0 - Upgraded the rust toolchain version from `1.75` to `1.76.0` in `rust-toolchain.toml` file. - Closes #307 * chore: Refactor pprof dependency - Shifted `pprof` from `dev-dependencies` to `dependencies` for non-wasm32 targets to optimize benchmarking builds. - Introduced "flamegraph" feature to include `pprof` during benchmarking. - Closes #309 * refactor: Refactor computation in sumcheck module - Streamlined calculation of evaluation points in `compute_eval_points_quad` and `compute_eval_points_cubic_with_additive_term` functions within `src/spartan/sumcheck/mod.rs`. --- Cargo.toml | 2 +- rust-toolchain.toml | 2 +- src/lib.rs | 58 +++++++++++-------------------------- src/spartan/sumcheck/mod.rs | 37 ++++++++++++++--------- 4 files changed, 42 insertions(+), 57 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c2520c2c..d3e8cf70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,9 +57,9 @@ getrandom = { version = "0.2.0", default-features = false, features = ["js"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] proptest = "1.2.0" +pprof = { version = "0.13", optional = true } # in benches under feature "flamegraph" [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] -pprof = { version = "0.13" } criterion = { version = "0.5", features = ["html_reports"] } [dev-dependencies] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 7d67641f..a58a147f 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,6 +1,6 @@ [toolchain] # The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy. profile = "default" -channel = "1.75" +channel = "1.76.0" targets = [ "wasm32-unknown-unknown" ] diff --git a/src/lib.rs b/src/lib.rs index 83eb263f..c103bb91 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1012,11 +1012,13 @@ type CE = ::CE; #[cfg(test)] mod tests { + use self::traits::CurveCycleEquipped; + use super::*; use crate::{ provider::{ - non_hiding_zeromorph::ZMPCS, traits::DlogGroup, Bn256Engine, Bn256EngineKZG, Bn256EngineZM, - GrumpkinEngine, PallasEngine, Secp256k1Engine, Secq256k1Engine, VestaEngine, + non_hiding_zeromorph::ZMPCS, Bn256Engine, Bn256EngineKZG, Bn256EngineZM, PallasEngine, + Secp256k1Engine, }, traits::{evaluation::EvaluationEngineTrait, snark::default_ck_hint}, }; @@ -1082,8 +1084,6 @@ mod tests { fn test_pp_digest_with(circuit1: &T1, circuit2: &T2, expected: &Expect) where E1: CurveCycleEquipped, - E1::GE: DlogGroup, - as Engine>::GE: DlogGroup, T1: StepCircuit, T2: StepCircuit< as Engine>::Scalar>, EE1: EvaluationEngineTrait, @@ -1112,62 +1112,38 @@ mod tests { #[test] fn test_pp_digest() { - let trivial_circuit1 = TrivialCircuit::<::Scalar>::default(); - let trivial_circuit2 = TrivialCircuit::<::Scalar>::default(); - let cubic_circuit1 = CubicCircuit::<::Scalar>::default(); - test_pp_digest_with::, EE<_>>( - &trivial_circuit1, - &trivial_circuit2, + &TrivialCircuit::default(), + &TrivialCircuit::default(), &expect!["492fd902cd7174159bc9a6f827d92eb54ff25efa9d0673dffdb0efd02995df01"], ); test_pp_digest_with::, EE<_>>( - &cubic_circuit1, - &trivial_circuit2, + &CubicCircuit::default(), + &TrivialCircuit::default(), &expect!["9b0701d9422658e3f74a85ab3e485c06f3ecca9c2b1800aab80004034d754f01"], ); - let trivial_circuit1_grumpkin = TrivialCircuit::<::Scalar>::default(); - let trivial_circuit2_grumpkin = TrivialCircuit::<::Scalar>::default(); - let cubic_circuit1_grumpkin = CubicCircuit::<::Scalar>::default(); - - // These tests should not need be different on the "asm" feature for bn256. - // See https://github.com/privacy-scaling-explorations/halo2curves/issues/100 for why they are - closing the issue there - // should eliminate the discrepancy here. test_pp_digest_with::, EE<_>>( - &trivial_circuit1_grumpkin, - &trivial_circuit2_grumpkin, + &TrivialCircuit::default(), + &TrivialCircuit::default(), &expect!["1267235eb3d139e466dd9c814eaf73f01b063ccb4cad04848c0eb62f079a9601"], ); + test_pp_digest_with::, EE<_>>( - &cubic_circuit1_grumpkin, - &trivial_circuit2_grumpkin, + &CubicCircuit::default(), + &TrivialCircuit::default(), &expect!["57afac2edd20d39b202151906e41154ba186c9dde497448d1332dc6de2f76302"], ); - test_pp_digest_with::, EE<_>>( - &trivial_circuit1_grumpkin, - &trivial_circuit2_grumpkin, - &expect!["070d247d83e17411d65c12260980ebcc59df88d3882d84eb62e6ab466e381503"], - ); - test_pp_digest_with::, EE<_>>( - &cubic_circuit1_grumpkin, - &trivial_circuit2_grumpkin, - &expect!["47c2caa008323b588b47ab8b6c0e94f980599188abe117c4d21ffff81494f303"], - ); - - let trivial_circuit1_secp = TrivialCircuit::<::Scalar>::default(); - let trivial_circuit2_secp = TrivialCircuit::<::Scalar>::default(); - let cubic_circuit1_secp = CubicCircuit::<::Scalar>::default(); test_pp_digest_with::, EE<_>>( - &trivial_circuit1_secp, - &trivial_circuit2_secp, + &TrivialCircuit::default(), + &TrivialCircuit::default(), &expect!["04b5d1798be6d74b3701390b87078e70ebf3ddaad80c375319f320cedf8bca00"], ); test_pp_digest_with::, EE<_>>( - &cubic_circuit1_secp, - &trivial_circuit2_secp, + &CubicCircuit::default(), + &TrivialCircuit::default(), &expect!["346b5f27cf24c79386f4de7a8bfb58970181ae7f0de7d2e3f10ad5dfd8fc2302"], ); } diff --git a/src/spartan/sumcheck/mod.rs b/src/spartan/sumcheck/mod.rs index 704f47cf..4a59b175 100644 --- a/src/spartan/sumcheck/mod.rs +++ b/src/spartan/sumcheck/mod.rs @@ -316,10 +316,14 @@ impl SumcheckProof { // eval 0: bound_func is A(low) let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i]); + let poly_A_right_term = poly_A[len + i] - poly_A[i]; + let poly_B_right_term = poly_B[len + i] - poly_B[i]; + let poly_C_right_term = poly_C[len + i] - poly_C[i]; + // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; + let poly_A_bound_point = poly_A[len + i] + poly_A_right_term; + let poly_B_bound_point = poly_B[len + i] + poly_B_right_term; + let poly_C_bound_point = poly_C[len + i] + poly_C_right_term; let eval_point_2 = comb_func( &poly_A_bound_point, &poly_B_bound_point, @@ -327,9 +331,9 @@ impl SumcheckProof { ); // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; + let poly_A_bound_point = poly_A_bound_point + poly_A_right_term; + let poly_B_bound_point = poly_B_bound_point + poly_B_right_term; + let poly_C_bound_point = poly_C_bound_point + poly_C_right_term; let eval_point_3 = comb_func( &poly_A_bound_point, &poly_B_bound_point, @@ -361,11 +365,16 @@ impl SumcheckProof { // eval 0: bound_func is A(low) let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]); + let poly_A_right_term = poly_A[len + i] - poly_A[i]; + let poly_B_right_term = poly_B[len + i] - poly_B[i]; + let poly_C_right_term = poly_C[len + i] - poly_C[i]; + let poly_D_right_term = poly_D[len + i] - poly_D[i]; + // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D[len + i] + poly_D[len + i] - poly_D[i]; + let poly_A_bound_point = poly_A[len + i] + poly_A_right_term; + let poly_B_bound_point = poly_B[len + i] + poly_B_right_term; + let poly_C_bound_point = poly_C[len + i] + poly_C_right_term; + let poly_D_bound_point = poly_D[len + i] + poly_D_right_term; let eval_point_2 = comb_func( &poly_A_bound_point, &poly_B_bound_point, @@ -374,10 +383,10 @@ impl SumcheckProof { ); // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; - let poly_D_bound_point = poly_D_bound_point + poly_D[len + i] - poly_D[i]; + let poly_A_bound_point = poly_A_bound_point + poly_A_right_term; + let poly_B_bound_point = poly_B_bound_point + poly_B_right_term; + let poly_C_bound_point = poly_C_bound_point + poly_C_right_term; + let poly_D_bound_point = poly_D_bound_point + poly_D_right_term; let eval_point_3 = comb_func( &poly_A_bound_point, &poly_B_bound_point,