diff --git a/Cargo.toml b/Cargo.toml
index 5bf485719..66b9c7487 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,6 +25,7 @@ members = [
     "monty-31",
     "poseidon",
     "poseidon2",
+    "poseidon2-air",
     "rescue",
     "sha256",
     "symmetric",
diff --git a/audits/Least Authority-Polygon Plonky3 Audit Report-18 July 2024.pdf b/audits/Least Authority- Polygon Plonky3 Updated Final Audit Report.pdf
similarity index 57%
rename from audits/Least Authority-Polygon Plonky3 Audit Report-18 July 2024.pdf
rename to audits/Least Authority- Polygon Plonky3 Updated Final Audit Report.pdf
index 673ee5c3b..abb5826aa 100644
Binary files a/audits/Least Authority-Polygon Plonky3 Audit Report-18 July 2024.pdf and b/audits/Least Authority- Polygon Plonky3 Updated Final Audit Report.pdf differ
diff --git a/baby-bear/Cargo.toml b/baby-bear/Cargo.toml
index dc713fa8c..0b4e1601e 100644
--- a/baby-bear/Cargo.toml
+++ b/baby-bear/Cargo.toml
@@ -13,12 +13,12 @@ p3-mds = { path = "../mds" }
 p3-monty-31 = { path = "../monty-31" }
 p3-poseidon2 = { path = "../poseidon2" }
 p3-symmetric = { path = "../symmetric" }
-num-bigint = { version = "0.4.3", default-features = false }
 rand = "0.8.5"
 serde = { version = "1.0", default-features = false, features = ["derive"] }

 [dev-dependencies]
 p3-field-testing = { path = "../field-testing" }
+p3-dft = { path = "../dft" }
 rand = { version = "0.8.5", features = ["min_const_gen"] }
 criterion = "0.5.1"
 rand_chacha = "0.3.1"
diff --git a/baby-bear/benches/bench_field.rs b/baby-bear/benches/bench_field.rs
index f9238bafa..65edc7fc9 100644
--- a/baby-bear/benches/bench_field.rs
+++ b/baby-bear/benches/bench_field.rs
@@ -1,9 +1,12 @@
+use std::any::type_name;
+
 use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
 use p3_baby_bear::BabyBear;
-use p3_field::AbstractField;
+use p3_field::{AbstractField, Field};
 use p3_field_testing::bench_func::{
     benchmark_add_latency, benchmark_add_throughput, benchmark_inv, benchmark_iter_sum,
-    benchmark_sub_latency, benchmark_sub_throughput,
+    benchmark_mul_latency, benchmark_mul_throughput, benchmark_sub_latency,
+    benchmark_sub_throughput,
 };

 type F = BabyBear;
@@ -33,5 +36,20 @@ fn bench_field(c: &mut Criterion) {
     });
 }

-criterion_group!(baby_bear_arithmetic, bench_field);
+fn bench_packedfield(c: &mut Criterion) {
+    let name = type_name::<<F as Field>::Packing>().to_string();
+    // Note that each round of throughput has 10 operations
+    // So we should have 10 * more repetitions for latency tests.
+ const REPS: usize = 100; + const L_REPS: usize = 10 * REPS; + + benchmark_add_latency::<::Packing, L_REPS>(c, &name); + benchmark_add_throughput::<::Packing, REPS>(c, &name); + benchmark_sub_latency::<::Packing, L_REPS>(c, &name); + benchmark_sub_throughput::<::Packing, REPS>(c, &name); + benchmark_mul_latency::<::Packing, L_REPS>(c, &name); + benchmark_mul_throughput::<::Packing, REPS>(c, &name); +} + +criterion_group!(baby_bear_arithmetic, bench_field, bench_packedfield); criterion_main!(baby_bear_arithmetic); diff --git a/baby-bear/benches/extension.rs b/baby-bear/benches/extension.rs index a631ed186..0585c5e65 100644 --- a/baby-bear/benches/extension.rs +++ b/baby-bear/benches/extension.rs @@ -1,23 +1,32 @@ use criterion::{criterion_group, criterion_main, Criterion}; use p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; -use p3_field_testing::bench_func::{benchmark_inv, benchmark_mul, benchmark_square}; +use p3_field_testing::bench_func::{ + benchmark_inv, benchmark_mul_latency, benchmark_mul_throughput, benchmark_square, +}; type EF4 = BinomialExtensionField; type EF5 = BinomialExtensionField; +// Note that each round of throughput has 10 operations +// So we should have 10 * more repetitions for latency tests. +const REPS: usize = 100; +const L_REPS: usize = 10 * REPS; + fn bench_quartic_extension(c: &mut Criterion) { let name = "BinomialExtensionField"; benchmark_square::(c, name); benchmark_inv::(c, name); - benchmark_mul::(c, name); + benchmark_mul_throughput::(c, name); + benchmark_mul_latency::(c, name); } fn bench_qunitic_extension(c: &mut Criterion) { let name = "BinomialExtensionField"; benchmark_square::(c, name); benchmark_inv::(c, name); - benchmark_mul::(c, name); + benchmark_mul_throughput::(c, name); + benchmark_mul_latency::(c, name); } criterion_group!( diff --git a/baby-bear/src/aarch64_neon/packing.rs b/baby-bear/src/aarch64_neon/packing.rs index 4cede7c22..03c5b9bed 100644 --- a/baby-bear/src/aarch64_neon/packing.rs +++ b/baby-bear/src/aarch64_neon/packing.rs @@ -17,12 +17,10 @@ pub type PackedBabyBearNeon = PackedMontyField31Neon; #[cfg(test)] mod tests { - use p3_field::AbstractField; use p3_field_testing::test_packed_field; - use p3_monty_31::PackedMontyField31Neon; use super::WIDTH; - use crate::{BabyBear, BabyBearParameters}; + use crate::BabyBear; const SPECIAL_VALS: [BabyBear; WIDTH] = BabyBear::new_array([0x00000000, 0x00000001, 0x00000002, 0x78000000]); @@ -32,38 +30,4 @@ mod tests { crate::PackedBabyBearNeon::zero(), p3_monty_31::PackedMontyField31Neon::(super::SPECIAL_VALS) ); - - #[test] - fn test_cube_vs_mul() { - let vec = PackedMontyField31Neon::(BabyBear::new_array([ - 0x4efd5eaf, 0x311b8e0c, 0x74dd27c1, 0x449613f0, - ])); - let res0 = vec * vec.square(); - let res1 = vec.cube(); - assert_eq!(res0, res1); - } - - #[test] - fn test_cube_vs_scalar() { - let arr = BabyBear::new_array([0x57155037, 0x71bdcc8e, 0x301f94d, 0x435938a6]); - - let vec = PackedMontyField31Neon::(arr); - let vec_res = vec.cube(); - - #[allow(clippy::needless_range_loop)] - for i in 0..WIDTH { - assert_eq!(vec_res.0[i], arr[i].cube()); - } - } - - #[test] - fn test_cube_vs_scalar_special_vals() { - let vec = PackedMontyField31Neon::(SPECIAL_VALS); - let vec_res = vec.cube(); - - #[allow(clippy::needless_range_loop)] - for i in 0..WIDTH { - assert_eq!(vec_res.0[i], SPECIAL_VALS[i].cube()); - } - } } diff --git a/baby-bear/src/baby_bear.rs b/baby-bear/src/baby_bear.rs index 37f3bcf26..97c5a0f40 100644 --- a/baby-bear/src/baby_bear.rs +++ 
b/baby-bear/src/baby_bear.rs @@ -66,14 +66,24 @@ impl FieldParameters for BabyBearParameters { impl TwoAdicData for BabyBearParameters { const TWO_ADICITY: usize = 27; - type ArrayLike = [BabyBear; Self::TWO_ADICITY + 1]; + type ArrayLike = &'static [BabyBear]; - const TWO_ADIC_GENERATORS: Self::ArrayLike = BabyBear::new_array([ + const TWO_ADIC_GENERATORS: Self::ArrayLike = &BabyBear::new_array([ 0x1, 0x78000000, 0x67055c21, 0x5ee99486, 0xbb4c4e4, 0x2d4cc4da, 0x669d6090, 0x17b56c64, 0x67456167, 0x688442f9, 0x145e952d, 0x4fe61226, 0x4c734715, 0x11c33e2a, 0x62c3d2b1, 0x77cad399, 0x54c131f4, 0x4cabd6a6, 0x5cf5713f, 0x3e9430e8, 0xba067a3, 0x18adc27d, 0x21fd55bc, 0x4b859b3d, 0x3bd57996, 0x4483d85a, 0x3a26eef8, 0x1a427a41, ]); + + const ROOTS_8: Self::ArrayLike = &BabyBear::new_array([0x5ee99486, 0x67055c21, 0xc9ea3ba]); + const INV_ROOTS_8: Self::ArrayLike = &BabyBear::new_array([0x6b615c47, 0x10faa3e0, 0x19166b7b]); + + const ROOTS_16: Self::ArrayLike = &BabyBear::new_array([ + 0xbb4c4e4, 0x5ee99486, 0x4b49e08, 0x67055c21, 0x5376917a, 0xc9ea3ba, 0x563112a7, + ]); + const INV_ROOTS_16: Self::ArrayLike = &BabyBear::new_array([ + 0x21ceed5a, 0x6b615c47, 0x24896e87, 0x10faa3e0, 0x734b61f9, 0x19166b7b, 0x6c4b3b1d, + ]); } impl BinomialExtensionData<4> for BabyBearParameters { @@ -101,8 +111,8 @@ impl BinomialExtensionData<5> for BabyBearParameters { mod tests { use core::array; - use p3_field::{AbstractField, Field, PrimeField32, PrimeField64, TwoAdicField}; - use p3_field_testing::{test_field, test_two_adic_field}; + use p3_field::{PrimeField32, PrimeField64, TwoAdicField}; + use p3_field_testing::{test_field, test_field_dft, test_two_adic_field}; use super::*; @@ -215,4 +225,13 @@ mod tests { test_field!(crate::BabyBear); test_two_adic_field!(crate::BabyBear); + + test_field_dft!(radix2dit, crate::BabyBear, p3_dft::Radix2Dit<_>); + test_field_dft!(bowers, crate::BabyBear, p3_dft::Radix2Bowers); + test_field_dft!(parallel, crate::BabyBear, p3_dft::Radix2DitParallel); + test_field_dft!( + recur_dft, + crate::BabyBear, + p3_monty_31::dft::RecursiveDft<_> + ); } diff --git a/bn254-fr/Cargo.toml b/bn254-fr/Cargo.toml index fbf601c22..df6c7f52b 100644 --- a/bn254-fr/Cargo.toml +++ b/bn254-fr/Cargo.toml @@ -13,7 +13,7 @@ ff = { version = "0.13", features = ["derive", "derive_bits"] } num-bigint = { version = "0.4.3", default-features = false } rand = "0.8.5" serde = { version = "1.0", default-features = false, features = ["derive"] } -halo2curves = { version = "0.6.1", features = ["bits", "derive_serde"] } +halo2curves = { version = "0.7.0", features = ["bits", "derive_serde"] } [dev-dependencies] p3-field-testing = { path = "../field-testing" } diff --git a/circle/Cargo.toml b/circle/Cargo.toml index 4d3d63e03..fc697f2fa 100644 --- a/circle/Cargo.toml +++ b/circle/Cargo.toml @@ -24,7 +24,6 @@ p3-keccak = { path = "../keccak" } p3-mds = { path = "../mds" } p3-mersenne-31 = { path = "../mersenne-31" } p3-merkle-tree = { path = "../merkle-tree" } -p3-poseidon = { path = "../poseidon" } p3-symmetric = { path = "../symmetric" } hashbrown = "0.14.3" diff --git a/circle/benches/cfft.rs b/circle/benches/cfft.rs index 722b9f857..78998caf9 100644 --- a/circle/benches/cfft.rs +++ b/circle/benches/cfft.rs @@ -1,5 +1,3 @@ -use std::any::type_name; - use criterion::measurement::Measurement; use criterion::{criterion_group, criterion_main, BenchmarkGroup, BenchmarkId, Criterion}; use p3_baby_bear::BabyBear; @@ -8,18 +6,10 @@ use p3_dft::{Radix2Bowers, Radix2Dit, Radix2DitParallel, TwoAdicSubgroupDft}; use 
p3_field::TwoAdicField; use p3_matrix::dense::RowMajorMatrix; use p3_mersenne_31::Mersenne31; +use p3_util::pretty_name; use rand::distributions::{Distribution, Standard}; use rand::thread_rng; -fn pretty_name() -> String { - let name = type_name::(); - let mut result = String::new(); - for qual in name.split_inclusive(&['<', '>', ',']) { - result.push_str(qual.split("::").last().unwrap()); - } - result -} - fn bench_lde(c: &mut Criterion) { let log_n = 18; let log_w = 8; diff --git a/circle/src/cfft.rs b/circle/src/cfft.rs index b01c3e59b..5eeefb0fc 100644 --- a/circle/src/cfft.rs +++ b/circle/src/cfft.rs @@ -273,7 +273,6 @@ pub fn circle_basis(p: Point, log_n: usize) -> Vec { mod tests { use itertools::iproduct; use p3_field::extension::BinomialExtensionField; - use p3_matrix::dense::RowMajorMatrix; use p3_mersenne_31::Mersenne31; use rand::{random, thread_rng}; diff --git a/circle/src/folding.rs b/circle/src/folding.rs index 367ab1f9e..9b4df876f 100644 --- a/circle/src/folding.rs +++ b/circle/src/folding.rs @@ -132,7 +132,6 @@ mod tests { use itertools::iproduct; use p3_field::extension::BinomialExtensionField; use p3_matrix::dense::RowMajorMatrix; - use p3_matrix::Matrix; use p3_mersenne_31::Mersenne31; use rand::{random, thread_rng}; diff --git a/circle/src/pcs.rs b/circle/src/pcs.rs index d98f74bd9..aed3b779f 100644 --- a/circle/src/pcs.rs +++ b/circle/src/pcs.rs @@ -15,7 +15,7 @@ use p3_matrix::{Dimensions, Matrix}; use p3_maybe_rayon::prelude::*; use p3_util::log2_strict_usize; use serde::{Deserialize, Serialize}; -use tracing::{info_span, instrument}; +use tracing::info_span; use crate::deep_quotient::{deep_quotient_reduce_row, extract_lambda}; use crate::domain::CircleDomain; @@ -131,7 +131,6 @@ where } } - #[instrument(skip_all)] fn open( &self, // For each round, @@ -468,7 +467,7 @@ where #[cfg(test)] mod tests { use p3_challenger::{HashChallenger, SerializingChallenger32}; - use p3_commit::{ExtensionMmcs, Pcs}; + use p3_commit::ExtensionMmcs; use p3_field::extension::BinomialExtensionField; use p3_keccak::Keccak256Hash; use p3_merkle_tree::FieldMerkleTreeMmcs; diff --git a/commit/Cargo.toml b/commit/Cargo.toml index 2fd6f87ca..1d7649403 100644 --- a/commit/Cargo.toml +++ b/commit/Cargo.toml @@ -20,6 +20,5 @@ serde = { version = "1.0", default-features = false } p3-dft = { path = "../dft", optional = true } [dev-dependencies] -p3-baby-bear = { path = "../baby-bear" } p3-dft = { path = "../dft" } rand = "0.8.5" diff --git a/dft/Cargo.toml b/dft/Cargo.toml index 284a930b2..99e64a98f 100644 --- a/dft/Cargo.toml +++ b/dft/Cargo.toml @@ -10,8 +10,10 @@ p3-matrix = { path = "../matrix" } p3-maybe-rayon = { path = "../maybe-rayon" } p3-util = { path = "../util" } tracing = "0.1.37" +itertools = "0.13.0" [dev-dependencies] +p3-monty-31 = { path = "../monty-31" } p3-baby-bear = { path = "../baby-bear" } p3-goldilocks = { path = "../goldilocks" } p3-mersenne-31 = { path = "../mersenne-31" } diff --git a/dft/benches/fft.rs b/dft/benches/fft.rs index 2ba4fcce5..088efbb60 100644 --- a/dft/benches/fft.rs +++ b/dft/benches/fft.rs @@ -1,5 +1,3 @@ -use std::any::type_name; - use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use p3_baby_bear::BabyBear; use p3_dft::{Radix2Bowers, Radix2Dit, Radix2DitParallel, TwoAdicSubgroupDft}; @@ -8,6 +6,8 @@ use p3_field::TwoAdicField; use p3_goldilocks::Goldilocks; use p3_matrix::dense::RowMajorMatrix; use p3_mersenne_31::{Mersenne31, Mersenne31ComplexRadix2Dit, Mersenne31Dft}; +use p3_monty_31::dft::RecursiveDft; +use 
p3_util::pretty_name; use rand::distributions::{Distribution, Standard}; use rand::thread_rng; @@ -15,12 +15,13 @@ fn bench_fft(c: &mut Criterion) { // log_sizes correspond to the sizes of DFT we want to benchmark; // for the DFT over the quadratic extension "Mersenne31Complex" a // fairer comparison is to use half sizes, which is the log minus 1. - let log_sizes = &[14, 16, 18]; + let log_sizes = &[14, 16, 18, 20, 22]; let log_half_sizes = &[13, 15, 17]; - const BATCH_SIZE: usize = 100; + const BATCH_SIZE: usize = 256; fft::, BATCH_SIZE>(c, log_sizes); + fft::, BATCH_SIZE>(c, log_sizes); fft::(c, log_sizes); fft::(c, log_sizes); fft::, BATCH_SIZE>(c, log_sizes); @@ -34,11 +35,13 @@ fn bench_fft(c: &mut Criterion) { m31_fft::, BATCH_SIZE>(c, log_sizes); m31_fft::(c, log_sizes); - ifft::, BATCH_SIZE>(c); + ifft::, BATCH_SIZE>(c, log_sizes); - coset_lde::(c); - coset_lde::(c); - coset_lde::(c); + coset_lde::, BATCH_SIZE>(c, log_sizes); + coset_lde::, BATCH_SIZE>(c, log_sizes); + coset_lde::(c, log_sizes); + coset_lde::(c, log_sizes); + coset_lde::(c, log_sizes); } fn fft(c: &mut Criterion, log_sizes: &[usize]) @@ -47,10 +50,10 @@ where Dft: TwoAdicSubgroupDft, Standard: Distribution, { - let mut group = c.benchmark_group(&format!( - "fft::<{}, {}, {}>", - type_name::(), - type_name::(), + let mut group = c.benchmark_group(format!( + "fft/{}/{}/ncols={}", + pretty_name::(), + pretty_name::(), BATCH_SIZE )); group.sample_size(10); @@ -75,9 +78,9 @@ where Dft: TwoAdicSubgroupDft>, Standard: Distribution, { - let mut group = c.benchmark_group(&format!( + let mut group = c.benchmark_group(format!( "m31_fft::<{}, {}>", - type_name::(), + pretty_name::(), BATCH_SIZE )); group.sample_size(10); @@ -96,22 +99,22 @@ where } } -fn ifft(c: &mut Criterion) +fn ifft(c: &mut Criterion, log_sizes: &[usize]) where F: TwoAdicField, Dft: TwoAdicSubgroupDft, Standard: Distribution, { - let mut group = c.benchmark_group(&format!( - "ifft::<{}, {}, {}>", - type_name::(), - type_name::(), + let mut group = c.benchmark_group(format!( + "ifft/{}/{}/ncols={}", + pretty_name::(), + pretty_name::(), BATCH_SIZE )); group.sample_size(10); let mut rng = thread_rng(); - for n_log in [14, 16, 18] { + for n_log in log_sizes { let n = 1 << n_log; let messages = RowMajorMatrix::rand(&mut rng, n, BATCH_SIZE); @@ -125,22 +128,22 @@ where } } -fn coset_lde(c: &mut Criterion) +fn coset_lde(c: &mut Criterion, log_sizes: &[usize]) where F: TwoAdicField, Dft: TwoAdicSubgroupDft, Standard: Distribution, { - let mut group = c.benchmark_group(&format!( - "coset_lde::<{}, {}, {}>", - type_name::(), - type_name::(), + let mut group = c.benchmark_group(format!( + "coset_lde/{}/{}/ncols={}", + pretty_name::(), + pretty_name::(), BATCH_SIZE )); group.sample_size(10); let mut rng = thread_rng(); - for n_log in [14, 16, 18] { + for n_log in log_sizes { let n = 1 << n_log; let messages = RowMajorMatrix::rand(&mut rng, n, BATCH_SIZE); diff --git a/dft/src/lib.rs b/dft/src/lib.rs index 1e546cb45..0c4cba06b 100644 --- a/dft/src/lib.rs +++ b/dft/src/lib.rs @@ -9,8 +9,6 @@ mod naive; mod radix_2_bowers; mod radix_2_dit; mod radix_2_dit_parallel; -#[cfg(test)] -mod testing; mod traits; mod util; diff --git a/dft/src/radix_2_bowers.rs b/dft/src/radix_2_bowers.rs index e7b568742..c821515bf 100644 --- a/dft/src/radix_2_bowers.rs +++ b/dft/src/radix_2_bowers.rs @@ -126,48 +126,3 @@ fn butterfly_layer>( }); }); } - -#[cfg(test)] -mod tests { - use p3_baby_bear::BabyBear; - use p3_goldilocks::Goldilocks; - - use crate::radix_2_bowers::Radix2Bowers; - use 
crate::testing::*; - - #[test] - fn dft_matches_naive() { - test_dft_matches_naive::(); - } - - #[test] - fn coset_dft_matches_naive() { - test_coset_dft_matches_naive::(); - } - - #[test] - fn idft_matches_naive() { - test_idft_matches_naive::(); - } - - #[test] - fn coset_idft_matches_naive() { - test_coset_idft_matches_naive::(); - test_coset_idft_matches_naive::(); - } - - #[test] - fn lde_matches_naive() { - test_lde_matches_naive::(); - } - - #[test] - fn coset_lde_matches_naive() { - test_coset_lde_matches_naive::(); - } - - #[test] - fn dft_idft_consistency() { - test_dft_idft_consistency::(); - } -} diff --git a/dft/src/radix_2_dit.rs b/dft/src/radix_2_dit.rs index e48ae63a1..0bbe0e2c1 100644 --- a/dft/src/radix_2_dit.rs +++ b/dft/src/radix_2_dit.rs @@ -67,48 +67,3 @@ fn dit_layer(mat: &mut RowMajorMatrixViewMut<'_, F>, layer: usize, twi }); }); } - -#[cfg(test)] -mod tests { - use p3_baby_bear::BabyBear; - use p3_goldilocks::Goldilocks; - - use crate::testing::*; - use crate::Radix2Dit; - - #[test] - fn dft_matches_naive() { - test_dft_matches_naive::>(); - } - - #[test] - fn coset_dft_matches_naive() { - test_coset_dft_matches_naive::>(); - } - - #[test] - fn idft_matches_naive() { - test_idft_matches_naive::>(); - } - - #[test] - fn coset_idft_matches_naive() { - test_coset_idft_matches_naive::>(); - test_coset_idft_matches_naive::>(); - } - - #[test] - fn lde_matches_naive() { - test_lde_matches_naive::>(); - } - - #[test] - fn coset_lde_matches_naive() { - test_coset_lde_matches_naive::>(); - } - - #[test] - fn dft_idft_consistency() { - test_dft_idft_consistency::>(); - } -} diff --git a/dft/src/radix_2_dit_parallel.rs b/dft/src/radix_2_dit_parallel.rs index d960f194e..a0b2ac4e0 100644 --- a/dft/src/radix_2_dit_parallel.rs +++ b/dft/src/radix_2_dit_parallel.rs @@ -1,5 +1,6 @@ use alloc::vec::Vec; +use itertools::izip; use p3_field::{Field, Powers, TwoAdicField}; use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView}; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut}; @@ -146,19 +147,22 @@ fn dit_layer( twiddles: &[F], ) { let layer_rev = log_h - 1 - layer; + let layer_pow = 1 << layer_rev; let half_block_size = 1 << layer; let block_size = half_block_size * 2; + let width = submat.width(); debug_assert!(submat.height() >= block_size); - for block_start in (0..submat.height()).step_by(block_size) { - for i in 0..half_block_size { - let hi = block_start + i; - let lo = hi + half_block_size; - let twiddle = twiddles[i << layer_rev]; + for block in submat.values.chunks_mut(block_size * width) { + let (lows, highs) = block.split_at_mut(half_block_size * width); - let (hi_chunk, lo_chunk) = submat.row_pair_mut(hi, lo); - DitButterfly(twiddle).apply_to_rows(hi_chunk, lo_chunk); + for (lo, hi, &twiddle) in izip!( + lows.chunks_mut(width), + highs.chunks_mut(width), + twiddles.iter().step_by(layer_pow) + ) { + DitButterfly(twiddle).apply_to_rows(lo, hi); } } } @@ -178,56 +182,12 @@ fn dit_layer_rev( let width = submat.width(); debug_assert!(submat.height() >= block_size); - for (block_i, block_start) in (0..submat.height()).step_by(block_size).enumerate() { - let twiddle = twiddles_rev[block_i]; - - let block = &mut submat.values[block_start * width..(block_start + block_size) * width]; + for (block, &twiddle) in submat + .values + .chunks_mut(block_size * width) + .zip(twiddles_rev) + { let (lo, hi) = block.split_at_mut(half_block_size * width); DitButterfly(twiddle).apply_to_rows(lo, hi) } } - -#[cfg(test)] -mod tests { - use p3_baby_bear::BabyBear; - use 
p3_goldilocks::Goldilocks; - - use crate::testing::*; - use crate::Radix2DitParallel; - - #[test] - fn dft_matches_naive() { - test_dft_matches_naive::(); - } - - #[test] - fn coset_dft_matches_naive() { - test_coset_dft_matches_naive::(); - } - - #[test] - fn idft_matches_naive() { - test_idft_matches_naive::(); - } - - #[test] - fn coset_idft_matches_naive() { - test_coset_idft_matches_naive::(); - test_coset_idft_matches_naive::(); - } - - #[test] - fn lde_matches_naive() { - test_lde_matches_naive::(); - } - - #[test] - fn coset_lde_matches_naive() { - test_coset_lde_matches_naive::(); - } - - #[test] - fn dft_idft_consistency() { - test_dft_idft_consistency::(); - } -} diff --git a/dft/src/traits.rs b/dft/src/traits.rs index 4c02de966..cbc422895 100644 --- a/dft/src/traits.rs +++ b/dft/src/traits.rs @@ -6,7 +6,7 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::util::swap_rows; use p3_matrix::Matrix; -use crate::util::divide_by_height; +use crate::util::{coset_shift_cols, divide_by_height}; pub trait TwoAdicSubgroupDft: Clone + Default { // Effectively this is either RowMajorMatrix or BitReversedMatrixView. @@ -41,13 +41,7 @@ pub trait TwoAdicSubgroupDft: Clone + Default { // = \sum_j (c_j s^j) (g^i)^j // which has the structure of an ordinary DFT, except each coefficient c_j is first replaced // by c_j s^j. - mat.rows_mut() - .zip(shift.powers()) - .for_each(|(row, weight)| { - row.iter_mut().for_each(|coeff| { - *coeff *= weight; - }) - }); + coset_shift_cols(&mut mat, shift); self.dft_batch(mat) } @@ -83,15 +77,7 @@ pub trait TwoAdicSubgroupDft: Clone + Default { /// subgroup itself. fn coset_idft_batch(&self, mut mat: RowMajorMatrix, shift: F) -> RowMajorMatrix { mat = self.idft_batch(mat); - - mat.rows_mut() - .zip(shift.inverse().powers()) - .for_each(|(row, weight)| { - row.iter_mut().for_each(|coeff| { - *coeff *= weight; - }) - }); - + coset_shift_cols(&mut mat, shift.inverse()); mat } diff --git a/dft/src/util.rs b/dft/src/util.rs index 1e7e07727..fe8531b74 100644 --- a/dft/src/util.rs +++ b/dft/src/util.rs @@ -18,6 +18,8 @@ pub fn divide_by_height + BorrowMut<[F]>>( /// so in actuality we're interleaving zero rows. #[inline] pub fn bit_reversed_zero_pad(mat: &mut RowMajorMatrix, added_bits: usize) { + // TODO: This is copypasta from matrix/dense.rs that should be + // refactored; it is only used by Radix2Bowers. if added_bits == 0 { return; } @@ -38,3 +40,14 @@ pub fn bit_reversed_zero_pad(mat: &mut RowMajorMatrix, added_bits: } *mat = RowMajorMatrix::new(values, w); } + +/// Multiply each element of row `i` of `mat` by `shift**i`. 
+pub(crate) fn coset_shift_cols(mat: &mut RowMajorMatrix, shift: F) { + mat.rows_mut() + .zip(shift.powers()) + .for_each(|(row, weight)| { + row.iter_mut().for_each(|coeff| { + *coeff *= weight; + }) + }); +} diff --git a/field-testing/Cargo.toml b/field-testing/Cargo.toml index 90fab46d2..a117fe26c 100644 --- a/field-testing/Cargo.toml +++ b/field-testing/Cargo.toml @@ -5,7 +5,9 @@ edition = "2021" license = "MIT OR Apache-2.0" [dependencies] +p3-dft = { path="../dft" } p3-field = { path = "../field" } +p3-matrix = { path="../matrix" } rand = { version = "0.8.5", features = ["min_const_gen"] } rand_chacha = "0.3.1" criterion = "0.5.1" diff --git a/field-testing/src/bench_func.rs b/field-testing/src/bench_func.rs index 24609ca4b..09b5455c2 100644 --- a/field-testing/src/bench_func.rs +++ b/field-testing/src/bench_func.rs @@ -2,7 +2,7 @@ use alloc::format; use alloc::vec::Vec; use criterion::{black_box, BatchSize, Criterion}; -use p3_field::Field; +use p3_field::{AbstractField, Field}; use rand::distributions::Standard; use rand::prelude::Distribution; use rand::Rng; @@ -29,18 +29,6 @@ where }); } -pub fn benchmark_mul(c: &mut Criterion, name: &str) -where - Standard: Distribution, -{ - let mut rng = rand::thread_rng(); - let x = rng.gen::(); - let y = rng.gen::(); - c.bench_function(&format!("{} mul", name), |b| { - b.iter(|| black_box(black_box(x) * black_box(y))) - }); -} - /// Benchmark the time taken to sum an array [F; N] using .sum() method. /// Repeat the summation REPS times. pub fn benchmark_iter_sum( @@ -65,45 +53,49 @@ pub fn benchmark_iter_sum( }); } -pub fn benchmark_add_latency(c: &mut Criterion, name: &str) -where - Standard: Distribution, +pub fn benchmark_add_latency( + c: &mut Criterion, + name: &str, +) where + Standard: Distribution, { - c.bench_function(&format!("{} add-latency/{}", name, N), |b| { + c.bench_function(&format!("add-latency/{} {}", N, name), |b| { b.iter_batched( || { let mut rng = rand::thread_rng(); let mut vec = Vec::new(); for _ in 0..N { - vec.push(rng.gen::()) + vec.push(rng.gen::()) } vec }, - |x| x.iter().fold(F::zero(), |x, y| x + *y), + |x| x.iter().fold(AF::zero(), |x, y| x + *y), BatchSize::SmallInput, ) }); } -pub fn benchmark_add_throughput(c: &mut Criterion, name: &str) -where - Standard: Distribution, +pub fn benchmark_add_throughput( + c: &mut Criterion, + name: &str, +) where + Standard: Distribution, { - c.bench_function(&format!("{} add-throughput/{}", name, N), |b| { + c.bench_function(&format!("add-throughput/{} {}", N, name), |b| { b.iter_batched( || { let mut rng = rand::thread_rng(); ( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), ) }, |(mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h, mut i, mut j)| { @@ -128,45 +120,49 @@ where }); } -pub fn benchmark_sub_latency(c: &mut Criterion, name: &str) -where - Standard: Distribution, +pub fn benchmark_sub_latency( + c: &mut Criterion, + name: &str, +) where + Standard: Distribution, { - c.bench_function(&format!("{} sub-latency/{}", name, N), |b| { + c.bench_function(&format!("sub-latency/{} {}", N, name), |b| { b.iter_batched( || { let mut rng = rand::thread_rng(); let mut vec = Vec::new(); for _ in 0..N { - vec.push(rng.gen::()) + vec.push(rng.gen::()) } vec }, - |x| x.iter().fold(F::zero(), |x, y| x - *y), + |x| 
x.iter().fold(AF::zero(), |x, y| x - *y), BatchSize::SmallInput, ) }); } -pub fn benchmark_sub_throughput(c: &mut Criterion, name: &str) -where - Standard: Distribution, +pub fn benchmark_sub_throughput( + c: &mut Criterion, + name: &str, +) where + Standard: Distribution, { - c.bench_function(&format!("{} sub-throughput/{}", name, N), |b| { + c.bench_function(&format!("sub-throughput/{} {}", N, name), |b| { b.iter_batched( || { let mut rng = rand::thread_rng(); ( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), ) }, |(mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h, mut i, mut j)| { @@ -190,3 +186,70 @@ where ) }); } + +pub fn benchmark_mul_latency( + c: &mut Criterion, + name: &str, +) where + Standard: Distribution, +{ + c.bench_function(&format!("mul-latency/{} {}", N, name), |b| { + b.iter_batched( + || { + let mut rng = rand::thread_rng(); + let mut vec = Vec::new(); + for _ in 0..N { + vec.push(rng.gen::()) + } + vec + }, + |x| x.iter().fold(AF::zero(), |x, y| x * *y), + BatchSize::SmallInput, + ) + }); +} + +pub fn benchmark_mul_throughput( + c: &mut Criterion, + name: &str, +) where + Standard: Distribution, +{ + c.bench_function(&format!("mul-throughput/{} {}", N, name), |b| { + b.iter_batched( + || { + let mut rng = rand::thread_rng(); + ( + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + ) + }, + |(mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h, mut i, mut j)| { + for _ in 0..N { + (a, b, c, d, e, f, g, h, i, j) = ( + a * b, + b * c, + c * d, + d * e, + e * f, + f * g, + g * h, + h * i, + i * j, + j * a, + ); + } + (a, b, c, d, e, f, g, h, i, j) + }, + BatchSize::SmallInput, + ) + }); +} diff --git a/dft/src/testing.rs b/field-testing/src/dft_testing.rs similarity index 65% rename from dft/src/testing.rs rename to field-testing/src/dft_testing.rs index af3ca6399..00031717e 100644 --- a/dft/src/testing.rs +++ b/field-testing/src/dft_testing.rs @@ -1,12 +1,11 @@ +use p3_dft::{NaiveDft, TwoAdicSubgroupDft}; use p3_field::TwoAdicField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use rand::distributions::{Distribution, Standard}; use rand::thread_rng; -use crate::{NaiveDft, TwoAdicSubgroupDft}; - -pub(crate) fn test_dft_matches_naive() +pub fn test_dft_matches_naive() where F: TwoAdicField, Standard: Distribution, @@ -23,7 +22,7 @@ where } } -pub(crate) fn test_coset_dft_matches_naive() +pub fn test_coset_dft_matches_naive() where F: TwoAdicField, Standard: Distribution, @@ -41,7 +40,7 @@ where } } -pub(crate) fn test_idft_matches_naive() +pub fn test_idft_matches_naive() where F: TwoAdicField, Standard: Distribution, @@ -53,12 +52,12 @@ where let h = 1 << log_h; let mat = RowMajorMatrix::::rand(&mut rng, h, 3); let idft_naive = NaiveDft.idft_batch(mat.clone()); - let idft_result = dft.idft_batch(mat); - assert_eq!(idft_naive, idft_result); + let idft_result = dft.idft_batch(mat.clone()); + assert_eq!(idft_naive, idft_result.to_row_major_matrix()); } } -pub(crate) fn test_coset_idft_matches_naive() +pub fn test_coset_idft_matches_naive() where F: TwoAdicField, Standard: Distribution, @@ -72,11 +71,11 @@ where let shift = F::generator(); let idft_naive = NaiveDft.coset_idft_batch(mat.clone(), shift); let 
idft_result = dft.coset_idft_batch(mat, shift); - assert_eq!(idft_naive, idft_result); + assert_eq!(idft_naive, idft_result.to_row_major_matrix()); } } -pub(crate) fn test_lde_matches_naive() +pub fn test_lde_matches_naive() where F: TwoAdicField, Standard: Distribution, @@ -93,7 +92,7 @@ where } } -pub(crate) fn test_coset_lde_matches_naive() +pub fn test_coset_lde_matches_naive() where F: TwoAdicField, Standard: Distribution, @@ -111,7 +110,7 @@ where } } -pub(crate) fn test_dft_idft_consistency() +pub fn test_dft_idft_consistency() where F: TwoAdicField, Standard: Distribution, @@ -124,6 +123,48 @@ where let original = RowMajorMatrix::::rand(&mut rng, h, 3); let dft_output = dft.dft_batch(original.clone()); let idft_output = dft.idft_batch(dft_output.to_row_major_matrix()); - assert_eq!(original, idft_output); + assert_eq!(original, idft_output.to_row_major_matrix()); } } + +#[macro_export] +macro_rules! test_field_dft { + ($mod:ident, $field:ty, $dft:ty) => { + mod $mod { + #[test] + fn dft_matches_naive() { + $crate::test_dft_matches_naive::<$field, $dft>(); + } + + #[test] + fn coset_dft_matches_naive() { + $crate::test_coset_dft_matches_naive::<$field, $dft>(); + } + + #[test] + fn idft_matches_naive() { + $crate::test_idft_matches_naive::<$field, $dft>(); + } + + #[test] + fn coset_idft_matches_naive() { + $crate::test_coset_idft_matches_naive::<$field, $dft>(); + } + + #[test] + fn lde_matches_naive() { + $crate::test_lde_matches_naive::<$field, $dft>(); + } + + #[test] + fn coset_lde_matches_naive() { + $crate::test_coset_lde_matches_naive::<$field, $dft>(); + } + + #[test] + fn dft_idft_consistency() { + $crate::test_dft_idft_consistency::<$field, $dft>(); + } + } + }; +} diff --git a/field-testing/src/lib.rs b/field-testing/src/lib.rs index 9b47c2fba..a4b92bc31 100644 --- a/field-testing/src/lib.rs +++ b/field-testing/src/lib.rs @@ -5,9 +5,11 @@ extern crate alloc; pub mod bench_func; +pub mod dft_testing; pub mod packedfield_testing; pub use bench_func::*; +pub use dft_testing::*; use num_bigint::BigUint; use num_traits::identities::One; use p3_field::{ diff --git a/field-testing/src/packedfield_testing.rs b/field-testing/src/packedfield_testing.rs index 56e8c023b..92456ad16 100644 --- a/field-testing/src/packedfield_testing.rs +++ b/field-testing/src/packedfield_testing.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use p3_field::{Field, PackedField, PackedValue}; +use p3_field::{AbstractField, Field, PackedField, PackedValue}; use rand::distributions::{Distribution, Standard}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; @@ -208,6 +208,21 @@ where PF::two() * vec0, "Error when comparing x.double() to 2 * x." ); + assert_eq!( + vec0.exp_const_u64::<3>(), + vec0 * vec0 * vec0, + "Error when comparing x.exp_const_u64::<3> to x*x*x." + ); + assert_eq!( + vec0.exp_const_u64::<5>(), + vec0 * vec0 * vec0 * vec0 * vec0, + "Error when comparing x.exp_const_u64::<5> to x*x*x*x*x." + ); + assert_eq!( + vec0.exp_const_u64::<7>(), + vec0 * vec0 * vec0 * vec0 * vec0 * vec0 * vec0, + "Error when comparing x.exp_const_u64::<7> to x*x*x*x*x*x*x." 
+ ); } pub fn test_distributivity() @@ -290,6 +305,21 @@ where let vec_special_neg = -vec_special; let arr_special_neg = vec_special_neg.as_slice(); + let vec_exp_3 = vec0.exp_const_u64::<3>(); + let arr_exp_3 = vec_exp_3.as_slice(); + let vec_special_exp_3 = vec_special.exp_const_u64::<3>(); + let arr_special_exp_3 = vec_special_exp_3.as_slice(); + + let vec_exp_5 = vec0.exp_const_u64::<5>(); + let arr_exp_5 = vec_exp_5.as_slice(); + let vec_special_exp_5 = vec_special.exp_const_u64::<5>(); + let arr_special_exp_5 = vec_special_exp_5.as_slice(); + + let vec_exp_7 = vec0.exp_const_u64::<7>(); + let arr_exp_7 = vec_exp_7.as_slice(); + let vec_special_exp_7 = vec_special.exp_const_u64::<7>(); + let arr_special_exp_7 = vec_special_exp_7.as_slice(); + let special_vals = special_vals.as_slice(); for i in 0..PF::WIDTH { assert_eq!( @@ -355,6 +385,36 @@ where "Error when testing consistency of neg for special values for packed and scalar at location {}.", i ); + assert_eq!(arr_exp_3[i], + arr0[i].exp_const_u64::<3>(), + "Error when testing exp_const_u64::<3> consistency of packed and scalar at location {}.", + i + ); + assert_eq!(arr_special_exp_3[i], + special_vals[i].exp_const_u64::<3>(), + "Error when testing consistency of exp_const_u64::<3> for special values for packed and scalar at location {}.", + i + ); + assert_eq!(arr_exp_5[i], + arr0[i].exp_const_u64::<5>(), + "Error when testing exp_const_u64::<5> consistency of packed and scalar at location {}.", + i + ); + assert_eq!(arr_special_exp_5[i], + special_vals[i].exp_const_u64::<5>(), + "Error when testing consistency of exp_const_u64::<5> for special values for packed and scalar at location {}.", + i + ); + assert_eq!(arr_exp_7[i], + arr0[i].exp_const_u64::<7>(), + "Error when testing exp_const_u64::<7> consistency of packed and scalar at location {}.", + i + ); + assert_eq!(arr_special_exp_7[i], + special_vals[i].exp_const_u64::<7>(), + "Error when testing consistency of exp_const_u64::<7> for special values for packed and scalar at location {}.", + i + ); } } diff --git a/goldilocks/benches/extension.rs b/goldilocks/benches/extension.rs index 74e698856..0a87cb7e9 100644 --- a/goldilocks/benches/extension.rs +++ b/goldilocks/benches/extension.rs @@ -1,15 +1,23 @@ use criterion::{criterion_group, criterion_main, Criterion}; use p3_field::extension::BinomialExtensionField; -use p3_field_testing::bench_func::{benchmark_inv, benchmark_mul, benchmark_square}; +use p3_field_testing::bench_func::{ + benchmark_inv, benchmark_mul_latency, benchmark_mul_throughput, benchmark_square, +}; use p3_goldilocks::Goldilocks; type EF2 = BinomialExtensionField; +// Note that each round of throughput has 10 operations +// So we should have 10 * more repetitions for latency tests. 
+const REPS: usize = 50; +const L_REPS: usize = 10 * REPS; + fn bench_qudratic_extension(c: &mut Criterion) { let name = "BinomialExtensionField"; benchmark_square::(c, name); benchmark_inv::(c, name); - benchmark_mul::(c, name); + benchmark_mul_throughput::(c, name); + benchmark_mul_latency::(c, name); } criterion_group!(bench_goldilocks_ef2, bench_qudratic_extension); diff --git a/goldilocks/src/lib.rs b/goldilocks/src/lib.rs index d5386a91a..979ba17c8 100644 --- a/goldilocks/src/lib.rs +++ b/goldilocks/src/lib.rs @@ -453,7 +453,7 @@ pub(crate) const fn to_goldilocks_array(input: [u64; N]) -> [Gol #[cfg(test)] mod tests { - use p3_field_testing::{test_field, test_two_adic_field}; + use p3_field_testing::{test_field, test_field_dft, test_two_adic_field}; use super::*; @@ -543,4 +543,8 @@ mod tests { test_field!(crate::Goldilocks); test_two_adic_field!(crate::Goldilocks); + + test_field_dft!(radix2dit, crate::Goldilocks, p3_dft::Radix2Dit<_>); + test_field_dft!(bowers, crate::Goldilocks, p3_dft::Radix2Bowers); + test_field_dft!(parallel, crate::Goldilocks, p3_dft::Radix2DitParallel); } diff --git a/keccak-air/Cargo.toml b/keccak-air/Cargo.toml index 2804df4c2..60a289515 100644 --- a/keccak-air/Cargo.toml +++ b/keccak-air/Cargo.toml @@ -14,6 +14,7 @@ tracing = "0.1.37" [dev-dependencies] p3-baby-bear = { path = "../baby-bear" } +p3-koala-bear = { path = "../koala-bear" } p3-challenger = { path = "../challenger" } p3-circle = { path = "../circle" } p3-commit = { path = "../commit" } @@ -25,9 +26,10 @@ p3-keccak = { path = "../keccak" } p3-mds = { path = "../mds" } p3-merkle-tree = { path = "../merkle-tree" } p3-mersenne-31 = { path = "../mersenne-31" } +p3-monty-31 = { path = "../monty-31" } p3-poseidon = { path = "../poseidon" } p3-poseidon2 = { path = "../poseidon2" } -p3-sha256 = { path = "../sha256", features = ["asm"] } +p3-sha256 = { path = "../sha256" } p3-symmetric = { path = "../symmetric" } p3-uni-stark = { path = "../uni-stark" } rand = "0.8.5" @@ -39,3 +41,4 @@ tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] } # We should be able to enable p3-maybe-rayon/parallel directly; this just doesn't # seem to work when using cargo with the -p or --package option. 
parallel = ["p3-maybe-rayon/parallel"] +asm = ["p3-sha256/asm"] diff --git a/keccak-air/examples/prove_baby_bear_keccak.rs b/keccak-air/examples/prove_baby_bear_keccak.rs index 51d872ffb..95887e750 100644 --- a/keccak-air/examples/prove_baby_bear_keccak.rs +++ b/keccak-air/examples/prove_baby_bear_keccak.rs @@ -3,12 +3,13 @@ use std::fmt::Debug; use p3_baby_bear::BabyBear; use p3_challenger::{HashChallenger, SerializingChallenger32}; use p3_commit::ExtensionMmcs; -use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; use p3_fri::{FriConfig, TwoAdicFriPcs}; use p3_keccak::Keccak256Hash; use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_matrix::Matrix; use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_monty_31::dft::RecursiveDft; use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; use p3_uni_stark::{prove, verify, StarkConfig}; use rand::random; @@ -47,9 +48,6 @@ fn main() -> Result<(), impl Debug> { type ChallengeMmcs = ExtensionMmcs; let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); - type Dft = Radix2DitParallel; - let dft = Dft {}; - type Challenger = SerializingChallenger32>; let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); @@ -61,6 +59,9 @@ fn main() -> Result<(), impl Debug> { proof_of_work_bits: 16, mmcs: challenge_mmcs, }; + type Dft = RecursiveDft; + let dft = Dft::new(trace.height() << fri_config.log_blowup); + type Pcs = TwoAdicFriPcs; let pcs = Pcs::new(dft, val_mmcs, fri_config); diff --git a/keccak-air/examples/prove_baby_bear_poseidon2.rs b/keccak-air/examples/prove_baby_bear_poseidon2.rs index a2c7bb15a..816172e5f 100644 --- a/keccak-air/examples/prove_baby_bear_poseidon2.rs +++ b/keccak-air/examples/prove_baby_bear_poseidon2.rs @@ -3,12 +3,13 @@ use std::fmt::Debug; use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; use p3_challenger::DuplexChallenger; use p3_commit::ExtensionMmcs; -use p3_dft::Radix2DitParallel; use p3_field::extension::BinomialExtensionField; use p3_field::Field; use p3_fri::{FriConfig, TwoAdicFriPcs}; use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_matrix::Matrix; use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_monty_31::dft::RecursiveDft; use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use p3_uni_stark::{prove, verify, StarkConfig}; @@ -59,14 +60,14 @@ fn main() -> Result<(), impl Debug> { type ChallengeMmcs = ExtensionMmcs; let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); - type Dft = Radix2DitParallel; - let dft = Dft {}; - type Challenger = DuplexChallenger; let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); let trace = generate_trace_rows::(inputs); + type Dft = RecursiveDft; + let dft = Dft::new(trace.height()); + let fri_config = FriConfig { log_blowup: 1, num_queries: 100, diff --git a/keccak-air/examples/prove_baby_bear_sha256.rs b/keccak-air/examples/prove_baby_bear_sha256.rs index 35126c308..469c16b2e 100644 --- a/keccak-air/examples/prove_baby_bear_sha256.rs +++ b/keccak-air/examples/prove_baby_bear_sha256.rs @@ -18,7 +18,7 @@ use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Registry}; -const NUM_HASHES: usize = 1365; +const NUM_HASHES: usize = 1_365; fn main() -> Result<(), impl Debug> { let env_filter = EnvFilter::builder() diff --git a/keccak-air/examples/prove_baby_bear_sha256_compress.rs b/keccak-air/examples/prove_baby_bear_sha256_compress.rs new file mode 100644 index 
000000000..6b5bc2fb0 --- /dev/null +++ b/keccak-air/examples/prove_baby_bear_sha256_compress.rs @@ -0,0 +1,75 @@ +use std::fmt::Debug; + +use p3_baby_bear::BabyBear; +use p3_challenger::{HashChallenger, SerializingChallenger32}; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_sha256::{Sha256, Sha256Compress}; +use p3_symmetric::SerializingHasher32; +use p3_uni_stark::{prove, verify, StarkConfig}; +use rand::random; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +const NUM_HASHES: usize = 1_365; + +fn main() -> Result<(), impl Debug> { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + + type Val = BabyBear; + type Challenge = BinomialExtensionField; + + type ByteHash = Sha256; + type FieldHash = SerializingHasher32; + let byte_hash = ByteHash {}; + let field_hash = FieldHash::new(byte_hash); + + type MyCompress = Sha256Compress; + let compress = MyCompress {}; + + type ValMmcs = FieldMerkleTreeMmcs; + let val_mmcs = ValMmcs::new(field_hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + type Challenger = SerializingChallenger32>; + + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(dft, val_mmcs, fri_config); + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let mut challenger = Challenger::from_hasher(vec![], byte_hash); + let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]); + + let mut challenger = Challenger::from_hasher(vec![], byte_hash); + verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![]) +} diff --git a/keccak-air/examples/prove_koala_bear_poseidon2.rs b/keccak-air/examples/prove_koala_bear_poseidon2.rs new file mode 100644 index 000000000..b9febd1a1 --- /dev/null +++ b/keccak-air/examples/prove_koala_bear_poseidon2.rs @@ -0,0 +1,87 @@ +use std::fmt::Debug; + +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_keccak_air::{generate_trace_rows, KeccakAir}; +use p3_koala_bear::{DiffusionMatrixKoalaBear, KoalaBear}; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::{prove, verify, StarkConfig}; +use rand::{random, thread_rng}; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +const NUM_HASHES: usize = 1365; + +fn main() -> Result<(), impl Debug> { + let env_filter = EnvFilter::builder() + 
.with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + + type Val = KoalaBear; + type Challenge = BinomialExtensionField; + + type Perm = Poseidon2; + let perm = Perm::new_from_rng_128( + Poseidon2ExternalMatrixGeneral, + DiffusionMatrixKoalaBear::default(), + &mut thread_rng(), + ); + + type MyHash = PaddingFreeSponge; + let hash = MyHash::new(perm.clone()); + + type MyCompress = TruncatedPermutation; + let compress = MyCompress::new(perm.clone()); + + type ValMmcs = FieldMerkleTreeMmcs< + ::Packing, + ::Packing, + MyHash, + MyCompress, + 8, + >; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + type Challenger = DuplexChallenger; + + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::(inputs); + + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(dft, val_mmcs, fri_config); + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let mut challenger = Challenger::new(perm.clone()); + let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]); + + let mut challenger = Challenger::new(perm); + verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![]) +} diff --git a/koala-bear/Cargo.toml b/koala-bear/Cargo.toml index d23730e32..1e0a7fb8c 100644 --- a/koala-bear/Cargo.toml +++ b/koala-bear/Cargo.toml @@ -13,14 +13,12 @@ p3-mds = { path = "../mds" } p3-monty-31 = { path = "../monty-31" } p3-poseidon2 = { path = "../poseidon2" } p3-symmetric = { path = "../symmetric" } -num-bigint = { version = "0.4.3", default-features = false } rand = "0.8.5" serde = { version = "1.0", default-features = false, features = ["derive"] } [dev-dependencies] +p3-dft = { path = "../dft" } p3-field-testing = { path = "../field-testing" } -ark-ff = { version = "^0.4.0", default-features = false } -zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } rand = { version = "0.8.5", features = ["min_const_gen"] } criterion = "0.5.1" rand_chacha = "0.3.1" @@ -29,4 +27,4 @@ rand_xoshiro = "0.6.0" [[bench]] name = "bench_field" -harness = false \ No newline at end of file +harness = false diff --git a/koala-bear/benches/bench_field.rs b/koala-bear/benches/bench_field.rs index bab1c3ad0..522259c97 100644 --- a/koala-bear/benches/bench_field.rs +++ b/koala-bear/benches/bench_field.rs @@ -1,8 +1,11 @@ +use std::any::type_name; + use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use p3_field_testing::bench_func::{ benchmark_add_latency, benchmark_add_throughput, benchmark_inv, benchmark_iter_sum, - benchmark_sub_latency, benchmark_sub_throughput, + benchmark_mul_latency, benchmark_mul_throughput, benchmark_sub_latency, + benchmark_sub_throughput, }; use p3_koala_bear::KoalaBear; @@ -33,5 +36,20 @@ fn bench_field(c: &mut Criterion) { }); } -criterion_group!(koala_bear_arithmetic, bench_field); +fn bench_packedfield(c: &mut Criterion) { + let name = type_name::<::Packing>().to_string(); + // Note that each round of throughput has 10 operations + // So we should have 10 * more repetitions for latency tests. 
+ const REPS: usize = 100; + const L_REPS: usize = 10 * REPS; + + benchmark_add_latency::<::Packing, L_REPS>(c, &name); + benchmark_add_throughput::<::Packing, REPS>(c, &name); + benchmark_sub_latency::<::Packing, L_REPS>(c, &name); + benchmark_sub_throughput::<::Packing, REPS>(c, &name); + benchmark_mul_latency::<::Packing, L_REPS>(c, &name); + benchmark_mul_throughput::<::Packing, REPS>(c, &name); +} + +criterion_group!(koala_bear_arithmetic, bench_field, bench_packedfield); criterion_main!(koala_bear_arithmetic); diff --git a/koala-bear/src/aarch64_neon/packing.rs b/koala-bear/src/aarch64_neon/packing.rs index 86e16745e..add4a48a5 100644 --- a/koala-bear/src/aarch64_neon/packing.rs +++ b/koala-bear/src/aarch64_neon/packing.rs @@ -17,12 +17,10 @@ pub type PackedKoalaBearNeon = PackedMontyField31Neon; #[cfg(test)] mod tests { - use p3_field::AbstractField; use p3_field_testing::test_packed_field; - use p3_monty_31::PackedMontyField31Neon; use super::WIDTH; - use crate::{KoalaBear, KoalaBearParameters}; + use crate::KoalaBear; const SPECIAL_VALS: [KoalaBear; WIDTH] = KoalaBear::new_array([0x00000000, 0x00000001, 0x00000002, 0x7f000000]); @@ -32,38 +30,4 @@ mod tests { crate::PackedKoalaBearNeon::zero(), p3_monty_31::PackedMontyField31Neon::(super::SPECIAL_VALS) ); - - #[test] - fn test_cube_vs_mul() { - let vec = PackedMontyField31Neon::(KoalaBear::new_array([ - 0x4efd5eaf, 0x311b8e0c, 0x74dd27c1, 0x449613f0, - ])); - let res0 = vec * vec.square(); - let res1 = vec.cube(); - assert_eq!(res0, res1); - } - - #[test] - fn test_cube_vs_scalar() { - let arr = KoalaBear::new_array([0x57155037, 0x71bdcc8e, 0x301f94d, 0x435938a6]); - - let vec = PackedMontyField31Neon::(arr); - let vec_res = vec.cube(); - - #[allow(clippy::needless_range_loop)] - for i in 0..WIDTH { - assert_eq!(vec_res.0[i], arr[i].cube()); - } - } - - #[test] - fn test_cube_vs_scalar_special_vals() { - let vec = PackedMontyField31Neon::(SPECIAL_VALS); - let vec_res = vec.cube(); - - #[allow(clippy::needless_range_loop)] - for i in 0..WIDTH { - assert_eq!(vec_res.0[i], SPECIAL_VALS[i].cube()); - } - } } diff --git a/koala-bear/src/koala_bear.rs b/koala-bear/src/koala_bear.rs index b1b637909..e9ccd7ca8 100644 --- a/koala-bear/src/koala_bear.rs +++ b/koala-bear/src/koala_bear.rs @@ -66,14 +66,24 @@ impl FieldParameters for KoalaBearParameters { impl TwoAdicData for KoalaBearParameters { const TWO_ADICITY: usize = 24; - type ArrayLike = [KoalaBear; Self::TWO_ADICITY + 1]; + type ArrayLike = &'static [KoalaBear]; - const TWO_ADIC_GENERATORS: Self::ArrayLike = KoalaBear::new_array([ + const TWO_ADIC_GENERATORS: Self::ArrayLike = &KoalaBear::new_array([ 0x1, 0x7f000000, 0x7e010002, 0x6832fe4a, 0x8dbd69c, 0xa28f031, 0x5c4a5b99, 0x29b75a80, 0x17668b8a, 0x27ad539b, 0x334d48c7, 0x7744959c, 0x768fc6fa, 0x303964b2, 0x3e687d4d, 0x45a60e61, 0x6e2f4d7a, 0x163bd499, 0x6c4a8a45, 0x143ef899, 0x514ddcad, 0x484ef19b, 0x205d63c3, 0x68e7dd49, 0x6ac49f88, ]); + + const ROOTS_8: Self::ArrayLike = &KoalaBear::new_array([0x6832fe4a, 0x7e010002, 0x174e3650]); + const INV_ROOTS_8: Self::ArrayLike = &KoalaBear::new_array([0x67b1c9b1, 0xfeffff, 0x16cd01b7]); + + const ROOTS_16: Self::ArrayLike = &KoalaBear::new_array([ + 0x8dbd69c, 0x6832fe4a, 0x27ae21e2, 0x7e010002, 0x3a89a025, 0x174e3650, 0x27dfce22, + ]); + const INV_ROOTS_16: Self::ArrayLike = &KoalaBear::new_array([ + 0x572031df, 0x67b1c9b1, 0x44765fdc, 0xfeffff, 0x5751de1f, 0x16cd01b7, 0x76242965, + ]); } impl BinomialExtensionData<4> for KoalaBearParameters { @@ -91,7 +101,7 @@ impl 
BinomialExtensionData<4> for KoalaBearParameters { #[cfg(test)] mod tests { use p3_field::{PrimeField32, PrimeField64, TwoAdicField}; - use p3_field_testing::{test_field, test_two_adic_field}; + use p3_field_testing::{test_field, test_field_dft, test_two_adic_field}; use super::*; @@ -195,4 +205,13 @@ mod tests { test_field!(crate::KoalaBear); test_two_adic_field!(crate::KoalaBear); + + test_field_dft!(radix2dit, crate::KoalaBear, p3_dft::Radix2Dit<_>); + test_field_dft!(bowers, crate::KoalaBear, p3_dft::Radix2Bowers); + test_field_dft!(parallel, crate::KoalaBear, p3_dft::Radix2DitParallel); + test_field_dft!( + recur_dft, + crate::KoalaBear, + p3_monty_31::dft::RecursiveDft<_> + ); } diff --git a/matrix/benches/transpose_benchmark.rs b/matrix/benches/transpose_benchmark.rs index 0b58aaf81..670172173 100644 --- a/matrix/benches/transpose_benchmark.rs +++ b/matrix/benches/transpose_benchmark.rs @@ -1,20 +1,33 @@ use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion, Throughput}; use p3_matrix::dense::RowMajorMatrix; +use rand::thread_rng; fn transpose_benchmark(c: &mut Criterion) { const SMALL_DIMS: [(usize, usize); 4] = [(4, 4), (8, 8), (10, 10), (12, 12)]; const LARGE_DIMS: [(usize, usize); 4] = [(20, 8), (21, 8), (22, 8), (23, 8)]; let inner = |g: &mut BenchmarkGroup<_>, dims: &[(usize, usize)]| { + let mut rng = thread_rng(); for (lg_nrows, lg_ncols) in dims { let nrows = 1 << lg_nrows; let ncols = 1 << lg_ncols; - let matrix = RowMajorMatrix::new(vec![0u32; nrows * ncols], ncols); + let mut matrix1 = RowMajorMatrix::::rand(&mut rng, nrows, ncols); + let mut matrix2 = RowMajorMatrix::default(nrows, ncols); + let name = format!("2^{lg_nrows} x 2^{lg_ncols}"); g.throughput(Throughput::Bytes( (nrows * ncols * core::mem::size_of::()) as u64, )); - g.bench_function(&name, |b| b.iter(|| matrix.transpose())); + g.bench_function(&name, |b| b.iter(|| matrix1.transpose_into(&mut matrix2))); + + if nrows != ncols { + let matrix2 = RowMajorMatrix::rand(&mut rng, ncols, nrows); + let name = format!("2^{lg_ncols} x 2^{lg_nrows}"); + g.throughput(Throughput::Bytes( + (nrows * ncols * core::mem::size_of::()) as u64, + )); + g.bench_function(&name, |b| b.iter(|| matrix2.transpose_into(&mut matrix1))); + } } }; diff --git a/merkle-tree/src/merkle_tree.rs b/merkle-tree/src/merkle_tree.rs index 11967de26..21c22dfb3 100644 --- a/merkle-tree/src/merkle_tree.rs +++ b/merkle-tree/src/merkle_tree.rs @@ -50,23 +50,22 @@ impl, const DIGEST_ELEMS: usize> assert_eq!(P::WIDTH, PW::WIDTH, "Packing widths must match"); + let mut leaves_largest_first = leaves + .iter() + .sorted_by_key(|l| Reverse(l.height())) + .peekable(); + // check height property assert!( - leaves - .iter() + leaves_largest_first + .clone() .map(|m| m.height()) - .sorted() .tuple_windows() .all(|(curr, next)| curr == next || curr.next_power_of_two() != next.next_power_of_two()), "matrix heights that round up to the same power of two must be equal" ); - let mut leaves_largest_first = leaves - .iter() - .sorted_by_key(|l| Reverse(l.height())) - .peekable(); - let max_height = leaves_largest_first.peek().unwrap().height(); let tallest_matrices = leaves_largest_first .peeking_take_while(|m| m.height() == max_height) diff --git a/mersenne-31/benches/extension.rs b/mersenne-31/benches/extension.rs index 85019728b..f7c663194 100644 --- a/mersenne-31/benches/extension.rs +++ b/mersenne-31/benches/extension.rs @@ -1,23 +1,30 @@ use criterion::{criterion_group, criterion_main, Criterion}; use 
p3_field::extension::{BinomialExtensionField, Complex}; -use p3_field_testing::bench_func::{benchmark_inv, benchmark_mul, benchmark_square}; +use p3_field_testing::bench_func::{ + benchmark_inv, benchmark_mul_latency, benchmark_mul_throughput, benchmark_square, +}; use p3_mersenne_31::Mersenne31; type EF2 = BinomialExtensionField, 2>; type EF3 = BinomialExtensionField, 3>; +const REPS: usize = 100; +const L_REPS: usize = 10 * REPS; + fn bench_qudratic_extension(c: &mut Criterion) { let name = "BinomialExtensionField, 2>"; benchmark_square::(c, name); benchmark_inv::(c, name); - benchmark_mul::(c, name); + benchmark_mul_throughput::(c, name); + benchmark_mul_latency::(c, name); } fn bench_cubic_extension(c: &mut Criterion) { let name = "BinomialExtensionField, 3>"; benchmark_square::(c, name); benchmark_inv::(c, name); - benchmark_mul::(c, name); + benchmark_mul_throughput::(c, name); + benchmark_mul_latency::(c, name); } criterion_group!(bench_mersennecomplex_ef2, bench_qudratic_extension); diff --git a/monolith/Cargo.toml b/monolith/Cargo.toml index b5de02b51..b2a6982a9 100644 --- a/monolith/Cargo.toml +++ b/monolith/Cargo.toml @@ -7,7 +7,6 @@ license = "MIT OR Apache-2.0" [dependencies] generic-array = "1.0" p3-field = { path = "../field" } -p3-goldilocks = { path = "../goldilocks" } p3-mersenne-31 = { path = "../mersenne-31" } p3-mds = { path = "../mds" } p3-symmetric = { path = "../symmetric" } diff --git a/monty-31/Cargo.toml b/monty-31/Cargo.toml index 9e454b3f1..8e335af9f 100644 --- a/monty-31/Cargo.toml +++ b/monty-31/Cargo.toml @@ -8,10 +8,17 @@ license = "MIT OR Apache-2.0" nightly-features = [] [dependencies] +itertools = "0.13.0" +p3-dft = { path = "../dft" } p3-field = { path = "../field" } +p3-matrix = { path = "../matrix" } +p3-maybe-rayon = { path = "../maybe-rayon" } p3-mds = { path = "../mds" } p3-poseidon2 = { path = "../poseidon2" } p3-symmetric = { path = "../symmetric" } +p3-util = { path = "../util" } num-bigint = { version = "0.4.3", default-features = false } rand = "0.8.5" serde = { version = "1.0", default-features = false, features = ["derive"] } +tracing = "0.1.37" +transpose = "0.2.3" diff --git a/monty-31/src/data_traits.rs b/monty-31/src/data_traits.rs index 000344a3f..32e5028f6 100644 --- a/monty-31/src/data_traits.rs +++ b/monty-31/src/data_traits.rs @@ -89,12 +89,30 @@ pub trait TwoAdicData: MontyParameters { /// Largest n such that 2^n divides p - 1. const TWO_ADICITY: usize; - /// ArrayLike should usually be [MontyField31; TWO_ADICITY + 1]. + /// ArrayLike should usually be &'static [MontyField31]. type ArrayLike: AsRef<[MontyField31]> + Sized; /// A list of generators of 2-adic subgroups. /// The i'th element must be a 2^i root of unity and the i'th element squared must be the i-1'th element. const TWO_ADIC_GENERATORS: Self::ArrayLike; + + /// Precomputation of the first 3 8th-roots of unity. + /// + /// Must agree with the 8th-root in TWO_ADIC_GENERATORS, i.e. + /// ROOTS_8[0] == TWO_ADIC_GENERATORS[3] + const ROOTS_8: Self::ArrayLike; + + /// Precomputation of the inverses of ROOTS_8. + const INV_ROOTS_8: Self::ArrayLike; + + /// Precomputation of the first 7 16th-roots of unity. + /// + /// Must agree with the 16th-root in TWO_ADIC_GENERATORS, i.e. + /// ROOTS_16[0] == TWO_ADIC_GENERATORS[4] + const ROOTS_16: Self::ArrayLike; + + /// Precomputation of the inverses of ROOTS_16. + const INV_ROOTS_16: Self::ArrayLike; } /// TODO: This should be deleted long term once we have improved our API for defining extension fields. 
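The ROOTS_8/ROOTS_16 tables introduced above are documented to agree with TWO_ADIC_GENERATORS (e.g. ROOTS_8[0] == TWO_ADIC_GENERATORS[3]). As a minimal sketch of that invariant, assuming the tables hold the successive powers g, g^2, ..., g^{k/2 - 1} of the order-k generator (which is what the doc comments and the BabyBear constants suggest), and using check_roots as a hypothetical helper that is not part of this patch:

use p3_baby_bear::BabyBear;
use p3_field::{AbstractField, Field, TwoAdicField};

// Hypothetical consistency check (assumption: roots[i] == g^(i+1) where
// g = two_adic_generator(log2(k)), and inv_roots holds the matching inverses).
fn check_roots(log_k: usize, roots: &[BabyBear], inv_roots: &[BabyBear]) {
    let g = BabyBear::two_adic_generator(log_k);
    for (i, (&root, &inv_root)) in roots.iter().zip(inv_roots).enumerate() {
        let expected = g.exp_u64(i as u64 + 1);
        assert_eq!(root, expected);
        assert_eq!(inv_root, expected.inverse());
    }
}

For BabyBear this would be called with log_k = 3 for the ROOTS_8/INV_ROOTS_8 tables and log_k = 4 for ROOTS_16/INV_ROOTS_16, taking the slices from the TwoAdicData impl above.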
diff --git a/monty-31/src/dft/backward.rs b/monty-31/src/dft/backward.rs new file mode 100644 index 000000000..e6091b09e --- /dev/null +++ b/monty-31/src/dft/backward.rs @@ -0,0 +1,175 @@ +//! Discrete Fourier Transform, in-place, decimation-in-time +//! +//! Straightforward recursive algorithm, "unrolled" up to size 256. +//! +//! Inspired by Bernstein's djbfft: https://cr.yp.to/djbfft.html + +extern crate alloc; +use alloc::vec::Vec; + +use itertools::izip; + +use crate::{monty_reduce, MontyField31, MontyParameters, TwoAdicData}; + +impl MontyField31 { + #[inline(always)] + fn backward_butterfly(x: Self, y: Self, w: Self) -> (Self, Self) { + let t = y * w; + (x + t, x - t) + } + + #[inline] + fn backward_pass(a: &mut [Self], roots: &[Self]) { + let half_n = a.len() / 2; + assert_eq!(roots.len(), half_n - 1); + + // Safe because 0 <= half_n < a.len() + let (top, tail) = unsafe { a.split_at_mut_unchecked(half_n) }; + + let s = top[0] + tail[0]; + let t = top[0] - tail[0]; + top[0] = s; + tail[0] = t; + + izip!(&mut top[1..], &mut tail[1..], roots).for_each(|(hi, lo, &root)| { + (*hi, *lo) = Self::backward_butterfly(*hi, *lo, root); + }); + } + + #[inline(always)] + fn backward_2(a: &mut [Self]) { + assert_eq!(a.len(), 2); + + let s = a[0] + a[1]; + let t = a[0] - a[1]; + a[0] = s; + a[1] = t; + } + + #[inline(always)] + fn backward_4(a: &mut [Self]) { + assert_eq!(a.len(), 4); + + // Read in bit-reversed order + let a0 = a[0]; + let a2 = a[1]; + let a1 = a[2]; + let a3 = a[3]; + + // Expanding the calculation of t3 saves one instruction + let t1 = MP::PRIME + a1.value - a3.value; + let t3 = MontyField31::new_monty(monty_reduce::( + t1 as u64 * MP::INV_ROOTS_8.as_ref()[1].value as u64, + )); + let t5 = a1 + a3; + let t4 = a0 + a2; + let t2 = a0 - a2; + + a[0] = t4 + t5; + a[1] = t2 + t3; + a[2] = t4 - t5; + a[3] = t2 - t3; + } + + #[inline(always)] + fn backward_8(a: &mut [Self]) { + assert_eq!(a.len(), 8); + + // Safe because a.len() == 8 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::backward_4(a0); + Self::backward_4(a1); + + Self::backward_pass(a, MP::INV_ROOTS_8.as_ref()); + } + + #[inline(always)] + fn backward_16(a: &mut [Self]) { + assert_eq!(a.len(), 16); + + // Safe because a.len() == 16 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::backward_8(a0); + Self::backward_8(a1); + + Self::backward_pass(a, MP::INV_ROOTS_16.as_ref()); + } + + #[inline(always)] + fn backward_32(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 32); + + // Safe because a.len() == 32 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::backward_16(a0); + Self::backward_16(a1); + + Self::backward_pass(a, &root_table[0]); + } + + #[inline(always)] + fn backward_64(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 64); + + // Safe because a.len() == 64 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::backward_32(a0, &root_table[1..]); + Self::backward_32(a1, &root_table[1..]); + + Self::backward_pass(a, &root_table[0]); + } + + #[inline(always)] + fn backward_128(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 128); + + // Safe because a.len() == 128 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::backward_64(a0, &root_table[1..]); + Self::backward_64(a1, &root_table[1..]); + + Self::backward_pass(a, &root_table[0]); + } + + #[inline(always)] + fn backward_256(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 256); + + // Safe because 
a.len() == 256 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::backward_128(a0, &root_table[1..]); + Self::backward_128(a1, &root_table[1..]); + + Self::backward_pass(a, &root_table[0]); + } + + #[inline] + pub fn backward_fft(a: &mut [Self], root_table: &[Vec]) { + let n = a.len(); + if n == 1 { + return; + } + + assert_eq!(n, 1 << (root_table.len() + 1)); + match n { + 256 => Self::backward_256(a, root_table), + 128 => Self::backward_128(a, root_table), + 64 => Self::backward_64(a, root_table), + 32 => Self::backward_32(a, root_table), + 16 => Self::backward_16(a), + 8 => Self::backward_8(a), + 4 => Self::backward_4(a), + 2 => Self::backward_2(a), + _ => { + debug_assert!(n > 64); + + // Safe because a.len() > 64 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(n / 2) }; + Self::backward_fft(a0, &root_table[1..]); + Self::backward_fft(a1, &root_table[1..]); + + Self::backward_pass(a, &root_table[0]); + } + } + } +} diff --git a/monty-31/src/dft/forward.rs b/monty-31/src/dft/forward.rs new file mode 100644 index 000000000..f13e0d575 --- /dev/null +++ b/monty-31/src/dft/forward.rs @@ -0,0 +1,201 @@ +//! Discrete Fourier Transform, in-place, decimation-in-frequency +//! +//! Straightforward recursive algorithm, "unrolled" up to size 256. +//! +//! Inspired by Bernstein's djbfft: https://cr.yp.to/djbfft.html + +extern crate alloc; +use alloc::vec::Vec; + +use itertools::izip; +use p3_field::{AbstractField, TwoAdicField}; +use p3_util::log2_strict_usize; + +use crate::{monty_reduce, FieldParameters, MontyField31, MontyParameters, TwoAdicData}; + +impl MontyField31 { + /// Given a field element `gen` of order n where `n = 2^lg_n`, + /// return a vector of vectors `table` where table[i] is the + /// vector of twiddle factors for an fft of length n/2^i. The values + /// gen^0 = 1 are skipped, as are g_i^k for k >= i/2 as these are + /// just the negatives of the other roots (using g_i^{i/2} = -1). 
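// Worked example (a sketch, not part of the diff) of the table shape for the function
// below with n = 8 and g an 8th root of unity: nth_roots = [g, g^2, g^3] and
//   table[0] = [g, g^2, g^3]   // twiddles for the length-8 layer
//   table[1] = [g^2]           // twiddles for the length-4 layer (g^2 is a 4th root of unity)
// The length-2 layer only ever needs the twiddle 1, which is why it is omitted and the
// table has lg(n) - 1 = 2 rows. Illustratively, for a concrete MontyField31 field `F`:
//   let table = F::roots_of_unity_table(8);
//   assert_eq!(table.len(), 2);
//   assert_eq!(table[0].len(), 3);
//   assert_eq!(table[1][0], table[0][1]); // both equal g^2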
+ pub fn roots_of_unity_table(n: usize) -> Vec> { + let lg_n = log2_strict_usize(n); + let gen = Self::two_adic_generator(lg_n); + let half_n = 1 << (lg_n - 1); + // nth_roots = [g, g^2, g^3, ..., g^{n/2 - 1}] + let nth_roots: Vec<_> = gen.powers().take(half_n).skip(1).collect(); + + (0..(lg_n - 1)) + .map(|i| { + nth_roots + .iter() + .skip((1 << i) - 1) + .step_by(1 << i) + .copied() + .collect() + }) + .collect() + } +} + +impl MontyField31 { + #[inline(always)] + fn forward_butterfly(x: Self, y: Self, w: Self) -> (Self, Self) { + let t = MP::PRIME + x.value - y.value; + ( + x + y, + Self::new_monty(monty_reduce::(t as u64 * w.value as u64)), + ) + } + + #[inline] + fn forward_pass(a: &mut [Self], roots: &[Self]) { + let half_n = a.len() / 2; + assert_eq!(roots.len(), half_n - 1); + + // Safe because 0 <= half_n < a.len() + let (top, tail) = unsafe { a.split_at_mut_unchecked(half_n) }; + + let s = top[0] + tail[0]; + let t = top[0] - tail[0]; + top[0] = s; + tail[0] = t; + + izip!(&mut top[1..], &mut tail[1..], roots).for_each(|(hi, lo, &root)| { + (*hi, *lo) = Self::forward_butterfly(*hi, *lo, root); + }); + } + + #[inline(always)] + fn forward_2(a: &mut [Self]) { + assert_eq!(a.len(), 2); + + let s = a[0] + a[1]; + let t = a[0] - a[1]; + a[0] = s; + a[1] = t; + } + + #[inline(always)] + fn forward_4(a: &mut [Self]) { + assert_eq!(a.len(), 4); + + // Expanding the calculation of t3 saves one instruction + let t1 = MP::PRIME + a[1].value - a[3].value; + let t3 = MontyField31::new_monty(monty_reduce::( + t1 as u64 * MP::ROOTS_8.as_ref()[1].value as u64, + )); + let t5 = a[1] + a[3]; + let t4 = a[0] + a[2]; + let t2 = a[0] - a[2]; + + // Return in bit-reversed order + a[0] = t4 + t5; + a[1] = t4 - t5; + a[2] = t2 + t3; + a[3] = t2 - t3; + } + + #[inline(always)] + fn forward_8(a: &mut [Self]) { + assert_eq!(a.len(), 8); + + Self::forward_pass(a, MP::ROOTS_8.as_ref()); + + // Safe because a.len() == 8 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::forward_4(a0); + Self::forward_4(a1); + } + + #[inline(always)] + fn forward_16(a: &mut [Self]) { + assert_eq!(a.len(), 16); + + Self::forward_pass(a, MP::ROOTS_16.as_ref()); + + // Safe because a.len() == 16 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::forward_8(a0); + Self::forward_8(a1); + } + + #[inline(always)] + fn forward_32(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 32); + + Self::forward_pass(a, &root_table[0]); + + // Safe because a.len() == 32 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::forward_16(a0); + Self::forward_16(a1); + } + + #[inline(always)] + fn forward_64(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 64); + + Self::forward_pass(a, &root_table[0]); + + // Safe because a.len() == 64 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::forward_32(a0, &root_table[1..]); + Self::forward_32(a1, &root_table[1..]); + } + + #[inline(always)] + fn forward_128(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 128); + + Self::forward_pass(a, &root_table[0]); + + // Safe because a.len() == 128 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) }; + Self::forward_64(a0, &root_table[1..]); + Self::forward_64(a1, &root_table[1..]); + } + + #[inline(always)] + fn forward_256(a: &mut [Self], root_table: &[Vec]) { + assert_eq!(a.len(), 256); + + Self::forward_pass(a, &root_table[0]); + + // Safe because a.len() == 256 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(a.len() / 2) 
}; + Self::forward_128(a0, &root_table[1..]); + Self::forward_128(a1, &root_table[1..]); + } + + #[inline] + pub fn forward_fft(a: &mut [Self], root_table: &[Vec]) { + let n = a.len(); + if n == 1 { + return; + } + + assert_eq!(n, 1 << (root_table.len() + 1)); + match n { + 256 => Self::forward_256(a, root_table), + 128 => Self::forward_128(a, root_table), + 64 => Self::forward_64(a, root_table), + 32 => Self::forward_32(a, root_table), + 16 => Self::forward_16(a), + 8 => Self::forward_8(a), + 4 => Self::forward_4(a), + 2 => Self::forward_2(a), + _ => { + debug_assert!(n > 64); + Self::forward_pass(a, &root_table[0]); + + // Safe because a.len() > 64 + let (a0, a1) = unsafe { a.split_at_mut_unchecked(n / 2) }; + + Self::forward_fft(a0, &root_table[1..]); + Self::forward_fft(a1, &root_table[1..]); + } + } + } +} diff --git a/monty-31/src/dft/mod.rs b/monty-31/src/dft/mod.rs new file mode 100644 index 000000000..37e97c229 --- /dev/null +++ b/monty-31/src/dft/mod.rs @@ -0,0 +1,297 @@ +//! An implementation of the FFT for `MontyField31` +extern crate alloc; + +use alloc::vec; +use alloc::vec::Vec; +use core::cell::RefCell; +use core::mem::transmute; + +use itertools::izip; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::{AbstractField, Field}; +use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; +use p3_maybe_rayon::prelude::*; +use tracing::{debug_span, instrument}; + +mod backward; +mod forward; + +use crate::{FieldParameters, MontyField31, MontyParameters, TwoAdicData}; + +/// Multiply each element of column `j` of `mat` by `shift**j`. +#[instrument(level = "debug", skip_all)] +fn coset_shift_and_scale_rows( + out: &mut [F], + out_ncols: usize, + mat: &[F], + ncols: usize, + shift: F, + scale: F, +) { + let powers = shift.shifted_powers(scale).take(ncols).collect::>(); + out.par_chunks_exact_mut(out_ncols) + .zip(mat.par_chunks_exact(ncols)) + .for_each(|(out_row, in_row)| { + izip!(out_row.iter_mut(), in_row, &powers).for_each(|(out, &coeff, &weight)| { + *out = coeff * weight; + }); + }); +} + +/// Recursive DFT, decimation-in-frequency in the forward direction, +/// decimation-in-time in the backward (inverse) direction. +#[derive(Clone, Debug, Default)] +pub struct RecursiveDft { + /// Memoized twiddle factors for each length log_n. + /// + /// TODO: The use of RefCell means this can't be shared across + /// threads; consider using RwLock or finding a better design + /// instead. 
+ twiddles: RefCell>>, + inv_twiddles: RefCell>>, +} + +impl RecursiveDft> { + pub fn new(n: usize) -> Self { + let res = Self { + twiddles: RefCell::default(), + inv_twiddles: RefCell::default(), + }; + res.update_twiddles(n); + res + } + + #[inline] + fn decimation_in_freq_dft( + mat: &mut [MontyField31], + ncols: usize, + twiddles: &[Vec>], + ) { + if ncols > 1 { + let lg_fft_len = p3_util::log2_ceil_usize(ncols); + let roots_idx = (twiddles.len() + 1) - lg_fft_len; + let twiddles = &twiddles[roots_idx..]; + + mat.par_chunks_exact_mut(ncols) + .for_each(|v| MontyField31::forward_fft(v, twiddles)) + } + } + + #[inline] + fn decimation_in_time_dft( + mat: &mut [MontyField31], + ncols: usize, + twiddles: &[Vec>], + ) { + if ncols > 1 { + let lg_fft_len = p3_util::log2_ceil_usize(ncols); + let roots_idx = (twiddles.len() + 1) - lg_fft_len; + let twiddles = &twiddles[roots_idx..]; + + mat.par_chunks_exact_mut(ncols) + .for_each(|v| MontyField31::backward_fft(v, twiddles)) + } + } + + /// Compute twiddle factors, or take memoized ones if already available. + #[instrument(skip_all)] + fn update_twiddles(&self, fft_len: usize) { + // TODO: This recomputes the entire table from scratch if we + // need it to be bigger, which is wasteful. + + // As we don't save the twiddles for the final layer where + // the only twiddle is 1, roots_of_unity_table(fft_len) + // returns a vector of twiddles of length log_2(fft_len) - 1. + let curr_max_fft_len = 2 << self.twiddles.borrow().len(); + if fft_len > curr_max_fft_len { + let new_twiddles = MontyField31::roots_of_unity_table(fft_len); + // We can obtain the inverse twiddles by reversing and + // negating the twiddles. + let new_inv_twiddles = new_twiddles + .iter() + .map(|ts| { + ts.iter() + .rev() + // A twiddle t is never zero, so negation simplifies + // to P - t. + .map(|&t| MontyField31::new_monty(MP::PRIME - t.value)) + .collect() + }) + .collect(); + self.twiddles.replace(new_twiddles); + self.inv_twiddles.replace(new_inv_twiddles); + } + } +} + +/// DFT implementation that uses DIT for the inverse "backward" +/// direction and DIF for the "forward" direction. +/// +/// The API mandates that the LDE is applied column-wise on the +/// _row-major_ input. This is awkward for memory coherence, so the +/// algorithm here transposes the input and operates on the rows in +/// the typical way, then transposes back again for the output. Even +/// for modestly large inputs, the cost of the two tranposes +/// outweighed by the improved performance from operating row-wise. +/// +/// The choice of DIT for inverse and DIF for "forward" transform mean +/// that a (coset) LDE +/// +/// - IDFT / zero extend / DFT +/// +/// expands to +/// +/// - bit-reverse input +/// - invDFT DIT +/// - result is in "correct" order +/// - coset shift and zero extend result +/// - DFT DIF on result +/// - output is bit-reversed, as required for FRI. +/// +/// Hence the only bit-reversal that needs to take place is on the input. 
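// Illustrative sketch (not part of the diff) of the steps listed above for a single
// column `v` of height n, extended by `added_bits`. The names `v`, `inv_tw`, `tw`,
// `shift` and the helper `zero_pad` are illustrative, and twiddle handling is elided:
//
//     reverse_slice_index_bits(&mut v);              // natural -> bit-reversed order
//     MontyField31::backward_fft(&mut v, inv_tw);    // DIT inverse DFT, natural-order output
//     let inv_n = MontyField31::from_canonical_usize(n).inverse();
//     for (c, w) in v.iter_mut().zip(shift.shifted_powers(inv_n)) {
//         *c *= w;                                   // 1/n normalisation and coset shift in one pass
//     }
//     let mut v = zero_pad(v, n << added_bits);      // zero-extend the coefficient vector
//     MontyField31::forward_fft(&mut v, tw);         // DIF DFT, bit-reversed output (as FRI expects)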
+/// +impl TwoAdicSubgroupDft> + for RecursiveDft> +{ + type Evaluations = BitReversedMatrixView>>; + + #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))] + fn dft_batch(&self, mut mat: RowMajorMatrix>) -> Self::Evaluations + where + MP: MontyParameters + FieldParameters + TwoAdicData, + { + let nrows = mat.height(); + let ncols = mat.width(); + if nrows <= 1 { + return mat.bit_reverse_rows(); + } + + let mut scratch = debug_span!("allocate scratch space") + .in_scope(|| RowMajorMatrix::default(nrows, ncols)); + + self.update_twiddles(nrows); + let twiddles = self.twiddles.borrow(); + + // transpose input + debug_span!("pre-transpose", nrows, ncols) + .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows)); + + debug_span!("dft batch", n_dfts = ncols, fft_len = nrows) + .in_scope(|| Self::decimation_in_freq_dft(&mut scratch.values, nrows, &twiddles)); + + // transpose output + debug_span!("post-transpose", nrows = ncols, ncols = nrows) + .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols)); + + mat.bit_reverse_rows() + } + + #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))] + fn idft_batch(&self, mat: RowMajorMatrix>) -> RowMajorMatrix> + where + MP: MontyParameters + FieldParameters + TwoAdicData, + { + let nrows = mat.height(); + let ncols = mat.width(); + if nrows <= 1 { + return mat; + } + + let mut scratch = debug_span!("allocate scratch space") + .in_scope(|| RowMajorMatrix::default(nrows, ncols)); + + let mut mat = + debug_span!("initial bitrev").in_scope(|| mat.bit_reverse_rows().to_row_major_matrix()); + + self.update_twiddles(nrows); + let inv_twiddles = self.inv_twiddles.borrow(); + + // transpose input + debug_span!("pre-transpose", nrows, ncols) + .in_scope(|| transpose::transpose(&mat.values, &mut scratch.values, ncols, nrows)); + + debug_span!("idft", n_dfts = ncols, fft_len = nrows) + .in_scope(|| Self::decimation_in_time_dft(&mut scratch.values, nrows, &inv_twiddles)); + + // transpose output + debug_span!("post-transpose", nrows = ncols, ncols = nrows) + .in_scope(|| transpose::transpose(&scratch.values, &mut mat.values, nrows, ncols)); + + let inv_len = MontyField31::from_canonical_usize(nrows).inverse(); + debug_span!("scale").in_scope(|| mat.scale(inv_len)); + mat + } + + #[instrument(skip_all, fields(dims = %mat.dimensions(), added_bits))] + fn coset_lde_batch( + &self, + mat: RowMajorMatrix>, + added_bits: usize, + shift: MontyField31, + ) -> Self::Evaluations { + let nrows = mat.height(); + let ncols = mat.width(); + let result_nrows = nrows << added_bits; + + if nrows == 1 { + let dupd_rows = core::iter::repeat(mat.values) + .take(result_nrows) + .flatten() + .collect(); + return RowMajorMatrix::new(dupd_rows, ncols).bit_reverse_rows(); + } + + let input_size = nrows * ncols; + let output_size = result_nrows * ncols; + + let mat = mat.bit_reverse_rows().to_row_major_matrix(); + + // Allocate space for the output and the intermediate state. + // NB: The unsafe version below takes well under 1ms, whereas doing + // vec![MontyField31::zero(); output_size]) + // takes 100s of ms. Safety is expensive. 
+ let (mut output, mut padded) = debug_span!("allocate scratch space").in_scope(|| unsafe { + // Safety: These are pretty dodgy, but work because MontyField31 is #[repr(transparent)] + let output = transmute::, Vec>>(vec![0u32; output_size]); + let padded = transmute::, Vec>>(vec![0u32; output_size]); + (output, padded) + }); + + // `coeffs` will hold the result of the inverse FFT; use the + // output storage as scratch space. + let coeffs = &mut output[..input_size]; + + debug_span!("pre-transpose", nrows, ncols) + .in_scope(|| transpose::transpose(&mat.values, coeffs, ncols, nrows)); + + // Apply inverse DFT; result is not yet normalised. + self.update_twiddles(result_nrows); + let inv_twiddles = self.inv_twiddles.borrow(); + debug_span!("inverse dft batch", n_dfts = ncols, fft_len = nrows) + .in_scope(|| Self::decimation_in_time_dft(coeffs, nrows, &inv_twiddles)); + + // At this point the inverse FFT of each column of `mat` appears + // as a row in `coeffs`. + + // Normalise inverse DFT and coset shift in one go. + let inv_len = MontyField31::from_canonical_usize(nrows).inverse(); + coset_shift_and_scale_rows(&mut padded, result_nrows, coeffs, nrows, shift, inv_len); + + // `padded` is implicitly zero padded since it was initialised + // to zeros when declared above. + + let twiddles = self.twiddles.borrow(); + + // Apply DFT + debug_span!("dft batch", n_dfts = ncols, fft_len = result_nrows) + .in_scope(|| Self::decimation_in_freq_dft(&mut padded, result_nrows, &twiddles)); + + // transpose output + debug_span!("post-transpose", nrows = ncols, ncols = result_nrows) + .in_scope(|| transpose::transpose(&padded, &mut output, result_nrows, ncols)); + + RowMajorMatrix::new(output, ncols).bit_reverse_rows() + } +} diff --git a/monty-31/src/lib.rs b/monty-31/src/lib.rs index 1b6444853..254c1feed 100644 --- a/monty-31/src/lib.rs +++ b/monty-31/src/lib.rs @@ -9,12 +9,12 @@ )] mod data_traits; +pub mod dft; mod extension; mod mds; mod monty_31; mod poseidon2; mod utils; - pub use data_traits::*; pub use mds::*; pub use monty_31::*; diff --git a/monty-31/src/x86_64_avx2/packing.rs b/monty-31/src/x86_64_avx2/packing.rs index 2a6961d26..f72383fb7 100644 --- a/monty-31/src/x86_64_avx2/packing.rs +++ b/monty-31/src/x86_64_avx2/packing.rs @@ -160,21 +160,76 @@ fn add(lhs: __m256i, rhs: __m256i) -> __m256i { // definition of Q and μ, we have Q P = μ C P = P^-1 C P = C (mod B). We also have // C - Q P = C (mod P), so thus D = C B^-1 (mod P). // -// It remains to show that R is in the correct range. It suffices to show that -P <= D < P. We know +// It remains to show that R is in the correct range. It suffices to show that -P < D < P. We know // that 0 <= C < P B and 0 <= Q P < P B. Then -P B < C - QP < P B and -P < D < P, as desired. // // [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann, Cambridge University Press, // 2010, algorithm 2.7. +// We provide 2 variants of Montgomery reduction depending on if the inputs are unsigned or signed. +// The unsigned variant follows steps 1 and 2 in the above protocol to produce D in (-P, ..., P). +// For the signed variant we assume -PB/2 < C < PB/2 and let Q := μ C mod B be the unique +// representative in [-B/2, ..., B/2 - 1]. The division in step 2 is clearly still exact and +// |C - Q P| <= |C| + |Q||P| < PB so D still lies in (-P, ..., P). + +/// Perform a partial Montgomery reduction on each 64 bit element. +/// Input must lie in {0, ..., 2^32P}. +/// The output will lie in {-P, ..., P} and be stored in the upper 32 bits. 
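// A scalar sketch (not part of the diff; the function name is illustrative) of the
// signed variant described above, with B = 2^32: given |c| < P * 2^31, it returns
// d = (c - q*P) / 2^32 with d in (-P, P) and d = c * 2^-32 (mod P).
fn partial_monty_red_signed_scalar(c: i64, p: u32, mu: u32) -> i64 {
    // q = mu * c mod 2^32, taken as the signed representative in [-2^31, 2^31).
    let q = mu.wrapping_mul(c as u32) as i32 as i64;
    // Since mu = P^-1 mod 2^32, c - q*p is an exact multiple of 2^32, so the arithmetic
    // shift is an exact division; |c - q*p| < |c| + |q|*p < P * 2^32 gives |d| < P.
    (c - q * (p as i64)) >> 32
}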
+#[inline] +#[must_use] +fn partial_monty_red_unsigned_to_signed(input: __m256i) -> __m256i { + unsafe { + let q = x86_64::_mm256_mul_epu32(input, MPAVX2::PACKED_MU); + let q_p = x86_64::_mm256_mul_epu32(q, MPAVX2::PACKED_P); + + // By construction, the bottom 32 bits of input and q_p are equal. + // Thus _mm256_sub_epi32 and _mm256_sub_epi64 should act identically. + // However for some reason, the compiler gets confused if we use _mm256_sub_epi64 + // and outputs a load of nonsense, see: https://godbolt.org/z/3W8M7Tv84. + x86_64::_mm256_sub_epi32(input, q_p) + } +} + +/// Perform a partial Montgomery reduction on each 64 bit element. +/// Input must lie in {-2^{31}P, ..., 2^31P}. +/// The output will lie in {-P, ..., P} and be stored in the upper 32 bits. +#[inline] +#[must_use] +fn partial_monty_red_signed_to_signed(input: __m256i) -> __m256i { + unsafe { + let q = x86_64::_mm256_mul_epi32(input, MPAVX2::PACKED_MU); + let q_p = x86_64::_mm256_mul_epi32(q, MPAVX2::PACKED_P); + + // Unlike the previous case the compiler output is essentially identical + // between _mm256_sub_epi32 and _mm256_sub_epi64. We use _mm256_sub_epi32 + // again just for consistency. + x86_64::_mm256_sub_epi32(input, q_p) + } +} + +/// Multiply the MontyField31 field elements in the even index entries. +/// lhs[2i], rhs[2i] must be unsigned 32-bit integers such that +/// lhs[2i] * rhs[2i] lies in {0, ..., 2^32P}. +/// The output will lie in {-P, ..., P} and be stored in output[2i + 1]. #[inline] #[must_use] -#[allow(non_snake_case)] -fn monty_d(lhs: __m256i, rhs: __m256i) -> __m256i { +fn monty_mul(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { let prod = x86_64::_mm256_mul_epu32(lhs, rhs); - let q = x86_64::_mm256_mul_epu32(prod, MPAVX2::PACKED_MU); - let q_P = x86_64::_mm256_mul_epu32(q, MPAVX2::PACKED_P); - x86_64::_mm256_sub_epi64(prod, q_P) + partial_monty_red_unsigned_to_signed::(prod) + } +} + +/// Multiply the MontyField31 field elements in the even index entries. +/// lhs[2i], rhs[2i] must be signed 32-bit integers such that +/// lhs[2i] * rhs[2i] lies in {-2^31P, ..., 2^31P}. +/// The output will lie in {-P, ..., P} stored in output[2i + 1]. +#[inline] +#[must_use] +fn monty_mul_signed(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { + let prod = x86_64::_mm256_mul_epi32(lhs, rhs); + partial_monty_red_signed_to_signed::(prod) } } @@ -216,8 +271,8 @@ fn mul(lhs: __m256i, rhs: __m256i) -> __m256i { let lhs_odd = movehdup_epi32(lhs); let rhs_odd = movehdup_epi32(rhs); - let d_evn = monty_d::(lhs_evn, rhs_evn); - let d_odd = monty_d::(lhs_odd, rhs_odd); + let d_evn = monty_mul::(lhs_evn, rhs_evn); + let d_odd = monty_mul::(lhs_odd, rhs_odd); let d_evn_hi = movehdup_epi32(d_evn); let t = x86_64::_mm256_blend_epi32::<0b10101010>(d_evn_hi, d_odd); @@ -227,6 +282,80 @@ fn mul(lhs: __m256i, rhs: __m256i) -> __m256i { } } +/// Square the MontyField31 field elements in the even index entries. +/// Inputs must be signed 32-bit integers. +/// Outputs will be a signed integer in (-P, ..., P) copied into both the even and odd indices. +#[inline] +#[must_use] +fn shifted_square(input: __m256i) -> __m256i { + // Note that we do not need a restriction on the size of input[i]^2 as + // 2^30 < P and |i32| <= 2^31 and so => input[i]^2 <= 2^62 < 2^32P. + unsafe { + let square = x86_64::_mm256_mul_epi32(input, input); + let square_red = partial_monty_red_unsigned_to_signed::(square); + movehdup_epi32(square_red) + } +} + +/// Cube the MontyField31 field elements in the even index entries. 
+/// Inputs must be signed 32-bit integers in [-P, ..., P]. +/// Outputs will be a signed integer in (-P, ..., P) stored in the odd indices. +#[inline] +#[must_use] +fn packed_exp_3(input: __m256i) -> __m256i { + let square = shifted_square::(input); + monty_mul_signed::(square, input) +} + +/// Take the fifth power of the MontyField31 field elements in the even index entries. +/// Inputs must be signed 32-bit integers in [-P, ..., P]. +/// Outputs will be a signed integer in (-P, ..., P) stored in the odd indices. +#[inline] +#[must_use] +fn packed_exp_5(input: __m256i) -> __m256i { + let square = shifted_square::(input); + let quad = shifted_square::(square); + monty_mul_signed::(quad, input) +} + +/// Take the seventh power of the MontyField31 field elements in the even index entries. +/// Inputs must lie in [-P, ..., P]. +/// Outputs will also lie in (-P, ..., P) stored in the odd indices. +#[inline] +#[must_use] +fn packed_exp_7(input: __m256i) -> __m256i { + let square = shifted_square::(input); + let cube = monty_mul_signed::(square, input); + let cube_shifted = movehdup_epi32(cube); + let quad = shifted_square::(square); + + monty_mul_signed::(quad, cube_shifted) +} + +/// Apply func to the even and odd indices of the input vector. +/// func should only depend in the 32 bit entries in the even indices. +/// The output of func must lie in (-P, ..., P) and be stored in the odd indices. +/// The even indices of the output of func will not be read. +/// The input should conform to the requirements of `func`. +#[inline] +#[must_use] +unsafe fn apply_func_to_even_odd( + input: __m256i, + func: fn(__m256i) -> __m256i, +) -> __m256i { + let input_evn = input; + let input_odd = movehdup_epi32(input); + + let d_evn = func(input_evn); + let d_odd = func(input_odd); + + let d_evn_hi = movehdup_epi32(d_evn); + let t = x86_64::_mm256_blend_epi32::<0b10101010>(d_evn_hi, d_odd); + + let u = x86_64::_mm256_add_epi32(t, MPAVX2::PACKED_P); + x86_64::_mm256_min_epu32(t, u) +} + /// Negate a vector of MontyField31 field elements in canonical form. /// If the inputs are not in canonical form, the result is undefined. #[inline] @@ -404,6 +533,49 @@ impl AbstractField for PackedMontyField31AVX2 { fn generator() -> Self { MontyField31::generator().into() } + + #[inline] + fn cube(&self) -> Self { + let val = self.to_vector(); + unsafe { + // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form. + let res = apply_func_to_even_odd::(val, packed_exp_3::); + Self::from_vector(res) + } + } + + #[must_use] + #[inline(always)] + fn exp_const_u64(&self) -> Self { + // We provide specialised code for the powers 3, 5, 7 as these turn up regularly. + // The other powers could be specialised similarly but we ignore this for now. + // These ideas could also be used to speed up the more generic exp_u64. + match POWER { + 0 => Self::one(), + 1 => *self, + 2 => self.square(), + 3 => self.cube(), + 4 => self.square().square(), + 5 => { + let val = self.to_vector(); + unsafe { + // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form. + let res = apply_func_to_even_odd::(val, packed_exp_5::); + Self::from_vector(res) + } + } + 6 => self.square().cube(), + 7 => { + let val = self.to_vector(); + unsafe { + // Safety: `apply_func_to_even_odd` returns values in canonical form when given values in canonical form. 
+ let res = apply_func_to_even_odd::(val, packed_exp_7::); + Self::from_vector(res) + } + } + _ => self.exp_u64(POWER), + } + } } impl Add> for PackedMontyField31AVX2 { diff --git a/poseidon2-air/Cargo.toml b/poseidon2-air/Cargo.toml new file mode 100644 index 000000000..a503e98d6 --- /dev/null +++ b/poseidon2-air/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "p3-poseidon2-air" +version = "0.1.0" +edition = "2021" +license = "MIT OR Apache-2.0" + +[dependencies] +p3-air = { path = "../air" } +p3-field = { path = "../field" } +p3-matrix = { path = "../matrix" } +p3-maybe-rayon = { path = "../maybe-rayon" } +p3-util = { path = "../util" } +#rand = { version = "0.8.5", features = ["min_const_gen"] } +rand = "0.8.5" +tracing = "0.1.37" + +[dev-dependencies] +p3-koala-bear = { path = "../koala-bear" } +p3-challenger = { path = "../challenger" } +p3-commit = { path = "../commit" } +p3-dft = { path = "../dft" } +p3-fri = { path = "../fri" } +p3-keccak = { path = "../keccak" } +p3-mds = { path = "../mds" } +p3-merkle-tree = { path = "../merkle-tree" } +p3-mersenne-31 = { path = "../mersenne-31" } +p3-poseidon = { path = "../poseidon" } +p3-poseidon2 = { path = "../poseidon2" } +p3-symmetric = { path = "../symmetric" } +p3-uni-stark = { path = "../uni-stark" } +tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } +tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] } + +[features] +# TODO: Consider removing, at least when this gets split off into another repository. +# We should be able to enable p3-maybe-rayon/parallel directly; this just doesn't +# seem to work when using cargo with the -p or --package option. +#parallel = ["p3-maybe-rayon/parallel"] diff --git a/poseidon2-air/examples/prove_poseidon2_koala_bear_keccak.rs b/poseidon2-air/examples/prove_poseidon2_koala_bear_keccak.rs new file mode 100644 index 000000000..0cff32ba4 --- /dev/null +++ b/poseidon2-air/examples/prove_poseidon2_koala_bear_keccak.rs @@ -0,0 +1,96 @@ +use std::fmt::Debug; + +use p3_challenger::{HashChallenger, SerializingChallenger32}; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_keccak::Keccak256Hash; +use p3_koala_bear::KoalaBear; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon2_air::{generate_trace_rows, Poseidon2Air}; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; +use p3_uni_stark::{prove, verify, StarkConfig}; +use rand::{random, thread_rng}; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +const WIDTH: usize = 16; +const SBOX_DEGREE: usize = 3; +const SBOX_REGISTERS: usize = 1; +const HALF_FULL_ROUNDS: usize = 4; +const PARTIAL_ROUNDS: usize = 20; + +const NUM_HASHES: usize = 1 << 16; + +fn main() -> Result<(), impl Debug> { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + + type Val = KoalaBear; + type Challenge = BinomialExtensionField; + + type ByteHash = Keccak256Hash; + type FieldHash = SerializingHasher32; + let byte_hash = ByteHash {}; + let field_hash = FieldHash::new(Keccak256Hash {}); + + type MyCompress = CompressionFunctionFromHasher; + let compress = MyCompress::new(byte_hash); + + type 
ValMmcs = FieldMerkleTreeMmcs; + let val_mmcs = ValMmcs::new(field_hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Challenger = SerializingChallenger32>; + + let air: Poseidon2Air< + Val, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = Poseidon2Air::new_from_rng(&mut thread_rng()); + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::< + Val, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(inputs); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(dft, val_mmcs, fri_config); + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let mut challenger = Challenger::from_hasher(vec![], byte_hash); + let proof = prove(&config, &air, &mut challenger, trace, &vec![]); + + let mut challenger = Challenger::from_hasher(vec![], byte_hash); + verify(&config, &air, &mut challenger, &proof, &vec![]) +} diff --git a/poseidon2-air/examples/prove_poseidon2_koala_bear_poseidon2.rs b/poseidon2-air/examples/prove_poseidon2_koala_bear_poseidon2.rs new file mode 100644 index 000000000..36fd60a0c --- /dev/null +++ b/poseidon2-air/examples/prove_poseidon2_koala_bear_poseidon2.rs @@ -0,0 +1,108 @@ +use std::fmt::Debug; + +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_koala_bear::{DiffusionMatrixKoalaBear, KoalaBear}; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; +use p3_poseidon2_air::{generate_trace_rows, Poseidon2Air}; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::{prove, verify, StarkConfig}; +use rand::{random, thread_rng}; +use tracing_forest::util::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Registry}; + +const WIDTH: usize = 16; +const SBOX_DEGREE: usize = 3; +const SBOX_REGISTERS: usize = 1; +const HALF_FULL_ROUNDS: usize = 4; +const PARTIAL_ROUNDS: usize = 20; + +const NUM_HASHES: usize = 1 << 16; + +fn main() -> Result<(), impl Debug> { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + + type Val = KoalaBear; + type Challenge = BinomialExtensionField; + + type Perm = Poseidon2; + let perm = Perm::new_from_rng_128( + Poseidon2ExternalMatrixGeneral, + DiffusionMatrixKoalaBear::default(), + &mut thread_rng(), + ); + + type MyHash = PaddingFreeSponge; + let hash = MyHash::new(perm.clone()); + + type MyCompress = TruncatedPermutation; + let compress = MyCompress::new(perm.clone()); + + type ValMmcs = FieldMerkleTreeMmcs< + ::Packing, + ::Packing, + MyHash, + MyCompress, + 8, + >; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2DitParallel; + let dft = Dft {}; + + type Challenger = DuplexChallenger; + + let air: Poseidon2Air< + Val, + WIDTH, + SBOX_DEGREE, + 
SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = Poseidon2Air::new_from_rng(&mut thread_rng()); + let inputs = (0..NUM_HASHES).map(|_| random()).collect::>(); + let trace = generate_trace_rows::< + Val, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(inputs); + + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 100, + proof_of_work_bits: 16, + mmcs: challenge_mmcs, + }; + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(dft, val_mmcs, fri_config); + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs); + + let mut challenger = Challenger::new(perm.clone()); + let proof = prove(&config, &air, &mut challenger, trace, &vec![]); + + let mut challenger = Challenger::new(perm); + verify(&config, &air, &mut challenger, &proof, &vec![]) +} diff --git a/poseidon2-air/src/air.rs b/poseidon2-air/src/air.rs new file mode 100644 index 000000000..73ce73b7c --- /dev/null +++ b/poseidon2-air/src/air.rs @@ -0,0 +1,301 @@ +use alloc::vec::Vec; +use core::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{AbstractField, Field}; +use p3_matrix::Matrix; +use rand::distributions::{Distribution, Standard}; +use rand::Rng; + +use crate::columns::{num_cols, Poseidon2Cols}; +use crate::{FullRound, PartialRound, SBox}; + +/// Assumes the field size is at least 16 bits. +/// +/// ***WARNING***: this is a stub for now, not ready to use. +#[derive(Debug)] +pub struct Poseidon2Air< + F: Field, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + beginning_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], + partial_round_constants: [F; PARTIAL_ROUNDS], + ending_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], +} + +impl< + F: Field, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, + > Poseidon2Air +{ + pub fn new_from_rng(rng: &mut R) -> Self + where + Standard: Distribution + Distribution<[F; WIDTH]>, + { + let beginning_full_round_constants = rng + .sample_iter(Standard) + .take(HALF_FULL_ROUNDS) + .collect::>() + .try_into() + .unwrap(); + let partial_round_constants = rng + .sample_iter(Standard) + .take(PARTIAL_ROUNDS) + .collect::>() + .try_into() + .unwrap(); + let ending_full_round_constants = rng + .sample_iter(Standard) + .take(HALF_FULL_ROUNDS) + .collect::>() + .try_into() + .unwrap(); + Self { + beginning_full_round_constants, + partial_round_constants, + ending_full_round_constants, + } + } +} + +impl< + F: Field, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, + > BaseAir + for Poseidon2Air +{ + fn width(&self) -> usize { + num_cols::() + } +} + +impl< + AB: AirBuilder, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, + > Air + for Poseidon2Air +{ + #[inline] + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &Poseidon2Cols< + AB::Var, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = (*local).borrow(); + + let mut state: [AB::Expr; WIDTH] = local.inputs.map(|x| x.into()); + + // assert_eq!( + // L::WIDTH, + // WIDTH, + // "The WIDTH for this STARK does not match the Linear Layer WIDTH." 
+ // ); + + // L::matmul_external(state); + for round in 0..HALF_FULL_ROUNDS { + eval_full_round( + &mut state, + &local.beginning_full_rounds[round], + &self.beginning_full_round_constants[round], + builder, + ); + } + + for round in 0..PARTIAL_ROUNDS { + eval_partial_round( + &mut state, + &local.partial_rounds[round], + &self.partial_round_constants[round], + builder, + ); + } + + for round in 0..HALF_FULL_ROUNDS { + eval_full_round( + &mut state, + &local.ending_full_rounds[round], + &self.ending_full_round_constants[round], + builder, + ); + } + } +} + +#[inline] +fn eval_full_round< + AB: AirBuilder, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, +>( + state: &mut [AB::Expr; WIDTH], + full_round: &FullRound, + round_constants: &[AB::F; WIDTH], + builder: &mut AB, +) { + for (i, (s, r)) in state.iter_mut().zip(round_constants.iter()).enumerate() { + *s = s.clone() + *r; + eval_sbox(&full_round.sbox[i], s, builder); + } + // L::matmul_external(state); +} + +#[inline] +fn eval_partial_round< + AB: AirBuilder, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, +>( + state: &mut [AB::Expr; WIDTH], + partial_round: &PartialRound, + round_constant: &AB::F, + builder: &mut AB, +) { + state[0] = state[0].clone() + *round_constant; + eval_sbox(&partial_round.sbox, &mut state[0], builder); + // L::matmul_internal(state, internal_matrix_diagonal); +} + +/// Evaluates the S-BOX over a degree-`1` expression `x`. +/// +/// # Panics +/// +/// This method panics if the number of `REGISTERS` is not chosen optimally for the given +/// `DEGREE` or if the `DEGREE` is not supported by the S-BOX. The supported degrees are +/// `3`, `5`, `7`, and `11`. +/// +/// # Efficiency Note +/// +/// This method computes the S-BOX by computing the cube of `x` and then successively +/// multiplying the running sum by the cube of `x` until the last multiplication where we use +/// the appropriate power to reach the final product: +/// +/// ```text +/// (x^3) * (x^3) * ... * (x^k) where k = d mod 3 +/// ``` +/// +/// The intermediate powers are stored in the auxiliary column registers. To maximize the +/// efficiency of the registers we try to do three multiplications per round. This algorithm +/// only multiplies the cube of `x` but a more optimal product would be to find the base-3 +/// decomposition of the `DEGREE` and use that to generate the addition chain. Even this is not +/// the optimal number of multiplications for all possible degrees, but for the S-BOX powers we +/// are interested in for Poseidon2 (namely `3`, `5`, `7`, and `11`), we get the optimal number +/// with this algorithm. We use the following register table: +/// +/// | `DEGREE` | `REGISTERS` | +/// |:--------:|:-----------:| +/// | `3` | `1` | +/// | `5` | `2` | +/// | `7` | `3` | +/// | `11` | `3` | +/// +/// We record this table in [`Self::OPTIMAL_REGISTER_COUNT`] and this choice of registers is +/// enforced by this method. +#[inline] +fn eval_sbox( + sbox: &SBox, + x: &mut AB::Expr, + builder: &mut AB, +) where + AB: AirBuilder, +{ + // assert_ne!(REGISTERS, 0, "The number of REGISTERS must be positive."); + // assert!(DEGREE <= 11, "The DEGREE must be less than or equal to 11."); + // assert_eq!( + // REGISTERS, + // Self::OPTIMAL_REGISTER_COUNT[DEGREE], + // "The number of REGISTERS must be optimal for the given DEGREE." 
+ // ); + + let x2 = x.square(); + let x3 = x2.clone() * x.clone(); + load(sbox, 0, x3.clone(), builder); + if REGISTERS == 1 { + *x = sbox.0[0].into(); + return; + } + if DEGREE == 11 { + (1..REGISTERS - 1).for_each(|j| load_product(sbox, j, &[0, 0, j - 1], builder)); + } else { + (1..REGISTERS - 1).for_each(|j| load_product(sbox, j, &[0, j - 1], builder)); + } + load_last_product(sbox, x.clone(), x2, x3, builder); + *x = sbox.0[REGISTERS - 1].into(); +} + +/// Loads `value` into the `i`-th S-BOX register. +#[inline] +fn load( + _sbox: &SBox, + _i: usize, + _value: AB::Expr, + _builder: &mut AB, +) where + AB: AirBuilder, +{ + // builder.assert_eq(sbox.0[i].into(), value); +} + +/// Loads the product over all `product` indices the into the `i`-th S-BOX register. +#[inline] +fn load_product( + _sbox: &SBox, + _i: usize, + _product: &[usize], + _builder: &mut AB, +) where + AB: AirBuilder, +{ + // assert!( + // product.len() <= 3, + // "Product is too big. We can only compute at most degree-3 constraints." + // ); + // load( + // sbox, + // i, + // product.iter().map(|j| AB::Expr::from(self.0[*j])).product(), + // builder, + // ); +} + +/// Loads the final product into the last S-BOX register. The final term in the product is +/// `pow(x, DEGREE % 3)`. +#[inline] +fn load_last_product( + _sbox: &SBox, + _x: AB::Expr, + _x2: AB::Expr, + _x3: AB::Expr, + _builder: &mut AB, +) where + AB: AirBuilder, +{ + // load( + // sbox, + // REGISTERS - 1, + // [x3, x, x2][DEGREE % 3].clone() * AB::Expr::from(self.0[REGISTERS - 2]), + // builder, + // ); +} diff --git a/poseidon2-air/src/columns.rs b/poseidon2-air/src/columns.rs new file mode 100644 index 000000000..715e5e692 --- /dev/null +++ b/poseidon2-air/src/columns.rs @@ -0,0 +1,164 @@ +use core::borrow::{Borrow, BorrowMut}; +use core::mem::size_of; + +/// Columns for Single-Row Poseidon2 STARK +/// +/// The columns of the STARK are divided into the three different round sections of the Poseidon2 +/// Permutation: beginning full rounds, partial rounds, and ending full rounds. For the full +/// rounds we store an [`SBox`] columnset for each state variable, and for the partial rounds we +/// store only for the first state variable. Because the matrix multiplications are linear +/// functions, we need only keep auxiliary columns for the S-BOX computations. +#[repr(C)] +pub struct Poseidon2Cols< + T, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + pub export: T, + + pub inputs: [T; WIDTH], + + /// Beginning Full Rounds + pub beginning_full_rounds: [FullRound; HALF_FULL_ROUNDS], + + /// Partial Rounds + pub partial_rounds: [PartialRound; PARTIAL_ROUNDS], + + /// Ending Full Rounds + pub ending_full_rounds: [FullRound; HALF_FULL_ROUNDS], +} + +/// Full Round Columns +#[repr(C)] +pub struct FullRound { + /// S-BOX Columns + pub sbox: [SBox; WIDTH], +} + +/// Partial Round Columns +#[repr(C)] +pub struct PartialRound< + T, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, +> { + /// S-BOX Columns + pub sbox: SBox, +} + +/// S-BOX Columns +/// +/// Use this column-set for an S-BOX that can be computed in `REGISTERS`-many columns. The S-BOX is +/// checked to ensure that `REGISTERS` is the optimal number of registers for the given `DEGREE` +/// for the degrees given in the Poseidon2 paper: `3`, `5`, `7`, and `11`. See [`Self::eval`] for +/// more information. 
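// Worked example (a sketch, not part of the diff) of how the auxiliary registers fill
// for the degrees in the table above, following the `eval_sbox` algorithm:
//   DEGREE = 3,  REGISTERS = 1: r0 = x^3                               (result is r0)
//   DEGREE = 5,  REGISTERS = 2: r0 = x^3, r1 = x^2 * r0 = x^5          (result is r1)
//   DEGREE = 7,  REGISTERS = 3: r0 = x^3, r1 = r0 * r0 = x^6,
//                               r2 = x  * r1 = x^7                     (result is r2)
//   DEGREE = 11, REGISTERS = 3: r0 = x^3, r1 = r0 * r0 * r0 = x^9,
//                               r2 = x^2 * r1 = x^11                   (result is r2)
// Each register value is a degree-<=3 product of previously constrained values, which is
// what keeps every constraint within degree 3.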
+#[repr(C)] +pub struct SBox(pub [T; REGISTERS]); + +pub const fn num_cols< + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>() -> usize { + size_of::>( + ) +} + +pub const fn make_col_map< + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>() -> Poseidon2Cols { + todo!() + // let indices_arr = indices_arr::< + // { num_cols::() }, + // >(); + // unsafe { + // transmute::< + // [usize; + // num_cols::()], + // Poseidon2Cols< + // usize, + // WIDTH, + // SBOX_DEGREE, + // SBOX_REGISTERS, + // HALF_FULL_ROUNDS, + // PARTIAL_ROUNDS, + // >, + // >(indices_arr) + // } +} + +impl< + T, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, + > Borrow> + for [T] +{ + fn borrow( + &self, + ) -> &Poseidon2Cols + { + // debug_assert_eq!(self.len(), NUM_COLS); + let (prefix, shorts, suffix) = unsafe { + self.align_to::>() + }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl< + T, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, + > + BorrowMut< + Poseidon2Cols, + > for [T] +{ + fn borrow_mut( + &mut self, + ) -> &mut Poseidon2Cols + { + // debug_assert_eq!(self.len(), NUM_COLS); + let (prefix, shorts, suffix) = unsafe { + self.align_to_mut::>() + }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} diff --git a/poseidon2-air/src/generation.rs b/poseidon2-air/src/generation.rs new file mode 100644 index 000000000..7c31a611f --- /dev/null +++ b/poseidon2-air/src/generation.rs @@ -0,0 +1,71 @@ +use alloc::vec; +use alloc::vec::Vec; + +use p3_field::PrimeField; +use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::*; +use tracing::instrument; + +use crate::columns::{num_cols, Poseidon2Cols}; + +// TODO: Take generic iterable +#[instrument(name = "generate Poseidon2 trace", skip_all)] +pub fn generate_trace_rows< + F: PrimeField, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>( + inputs: Vec<[F; WIDTH]>, +) -> RowMajorMatrix { + let n = inputs.len(); + assert!( + n.is_power_of_two(), + "Callers expected to pad inputs to a power of two" + ); + let ncols = num_cols::(); + let mut trace = RowMajorMatrix::new(vec![F::zero(); n * ncols], ncols); + let (prefix, rows, suffix) = unsafe { + trace.values.align_to_mut::>() + }; + assert!(prefix.is_empty(), "Alignment should match"); + assert!(suffix.is_empty(), "Alignment should match"); + assert_eq!(rows.len(), n); + + rows.par_iter_mut().zip(inputs).for_each(|(row, input)| { + generate_trace_rows_for_perm(row, input); + }); + + trace +} + +/// `rows` will normally consist of 24 rows, with an exception for the final row. 
+fn generate_trace_rows_for_perm< + F: PrimeField, + const WIDTH: usize, + const SBOX_DEGREE: usize, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>( + _row: &mut Poseidon2Cols< + F, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + _input: [F; WIDTH], +) { + // TODO +} diff --git a/poseidon2-air/src/lib.rs b/poseidon2-air/src/lib.rs new file mode 100644 index 000000000..51f929df7 --- /dev/null +++ b/poseidon2-air/src/lib.rs @@ -0,0 +1,13 @@ +//! And AIR for the Poseidon2 permutation. + +#![no_std] + +extern crate alloc; + +mod air; +mod columns; +mod generation; + +pub use air::*; +pub use columns::*; +pub use generation::*; diff --git a/poseidon2/benches/poseidon2.rs b/poseidon2/benches/poseidon2.rs index 35a85ad37..1343ce6be 100644 --- a/poseidon2/benches/poseidon2.rs +++ b/poseidon2/benches/poseidon2.rs @@ -3,7 +3,7 @@ use std::any::type_name; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; use p3_bn254_fr::{Bn254Fr, DiffusionMatrixBN254}; -use p3_field::{PrimeField, PrimeField64}; +use p3_field::{AbstractField, PrimeField, PrimeField64}; use p3_goldilocks::{DiffusionMatrixGoldilocks, Goldilocks}; use p3_koala_bear::{DiffusionMatrixKoalaBear, KoalaBear}; use p3_mersenne_31::{DiffusionMatrixMersenne31, Mersenne31}; @@ -19,11 +19,7 @@ fn bench_poseidon2(c: &mut Criterion) { poseidon2_p64::(c); poseidon2_p64::(c); - poseidon2_p64::(c); - poseidon2_p64::(c); poseidon2_p64::(c); - poseidon2_p64::(c); - poseidon2_p64::(c); poseidon2_p64::( c, @@ -50,8 +46,8 @@ fn poseidon2( ) where F: PrimeField, Standard: Distribution, - MdsLight: MdsLightPermutation + Default, - Diffusion: DiffusionPermutation + Default, + MdsLight: MdsLightPermutation + Default, + Diffusion: DiffusionPermutation + Default, { let mut rng = thread_rng(); let external_linear_layer = MdsLight::default(); @@ -64,10 +60,10 @@ fn poseidon2( internal_linear_layer, &mut rng, ); - let input = [F::zero(); WIDTH]; + let input = [F::Packing::zero(); WIDTH]; let name = format!( "poseidon2::<{}, {}, {}, {}>", - type_name::(), + type_name::(), D, rounds_f, rounds_p @@ -81,8 +77,8 @@ fn poseidon2_p64(c: &m where F: PrimeField64, Standard: Distribution, - MdsLight: MdsLightPermutation + Default, - Diffusion: DiffusionPermutation + Default, + MdsLight: MdsLightPermutation + Default, + Diffusion: DiffusionPermutation + Default, { let mut rng = thread_rng(); let external_linear_layer = MdsLight::default(); @@ -93,8 +89,13 @@ where internal_linear_layer, &mut rng, ); - let input = [F::zero(); WIDTH]; - let name = format!("poseidon2::<{}, {}>", type_name::(), D); + let input = [F::Packing::zero(); WIDTH]; + let name = format!( + "poseidon2::<{}, {}, {}>", + type_name::(), + D, + WIDTH + ); let id = BenchmarkId::new(name, WIDTH); c.bench_with_input(id, &input, |b, &input| b.iter(|| poseidon.permute(input))); } diff --git a/sha256/Cargo.toml b/sha256/Cargo.toml index 56edd8175..4b5b73fad 100644 --- a/sha256/Cargo.toml +++ b/sha256/Cargo.toml @@ -7,7 +7,7 @@ description = "Plonky3 hash trait implementations for the SHA2-256 hash function [dependencies] p3-symmetric = { path = "../symmetric" } -sha2 = { version = "0.10.8", default-features = false } +sha2 = { version = "0.10.8", default-features = false, features = ["compress"] } [features] default = [] diff --git a/sha256/src/lib.rs b/sha256/src/lib.rs index 36c4b0f81..96e4db346 100644 --- a/sha256/src/lib.rs +++ b/sha256/src/lib.rs 
@@ -6,9 +6,15 @@ extern crate alloc; use alloc::vec::Vec; -use p3_symmetric::CryptographicHasher; +use p3_symmetric::{CompressionFunction, CryptographicHasher, PseudoCompressionFunction}; +use sha2::digest::generic_array::GenericArray; +use sha2::digest::typenum::U64; use sha2::Digest; +pub const H256_256: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + /// The SHA2-256 hash function. #[derive(Copy, Clone, Debug)] pub struct Sha256; @@ -34,12 +40,34 @@ impl CryptographicHasher for Sha256 { } } +/// SHA2-256 without the padding (pre-processing), intended to be used +/// as a 2-to-1 [PseudoCompressionFunction]. +#[derive(Copy, Clone, Debug)] +pub struct Sha256Compress; + +impl PseudoCompressionFunction<[u8; 32], 2> for Sha256Compress { + fn compress(&self, input: [[u8; 32]; 2]) -> [u8; 32] { + let mut state = H256_256; + // GenericArray has same memory layout as [u8; 64] + let block: GenericArray = unsafe { core::mem::transmute(input) }; + sha2::compress256(&mut state, &[block]); + + let mut output = [0u8; 32]; + for (chunk, word) in output.chunks_exact_mut(4).zip(state) { + chunk.copy_from_slice(&word.to_be_bytes()); + } + output + } +} + +impl CompressionFunction<[u8; 32], 2> for Sha256Compress {} + #[cfg(test)] mod tests { use hex_literal::hex; - use p3_symmetric::CryptographicHasher; + use p3_symmetric::{CryptographicHasher, PseudoCompressionFunction}; - use crate::Sha256; + use crate::{Sha256, Sha256Compress}; #[test] fn test_hello_world() { @@ -53,4 +81,17 @@ mod tests { let sha256 = Sha256; assert_eq!(sha256.hash_iter(input.to_vec())[..], expected[..]); } + + #[test] + fn test_compress() { + let left = [0u8; 32]; + // `right` will simulate the SHA256 padding + let mut right = [0u8; 32]; + right[0] = 1 << 7; + right[30] = 1; // left has length 256 in bits, L = 0x100 + + let expected = Sha256.hash_iter(left); + let sha256_compress = Sha256Compress; + assert_eq!(sha256_compress.compress([left, right]), expected); + } } diff --git a/uni-stark/Cargo.toml b/uni-stark/Cargo.toml index dcc8024d0..83db1ab20 100644 --- a/uni-stark/Cargo.toml +++ b/uni-stark/Cargo.toml @@ -25,11 +25,8 @@ p3-fri = { path = "../fri" } p3-keccak = { path = "../keccak" } p3-mds = { path = "../mds" } p3-merkle-tree = { path = "../merkle-tree" } -p3-goldilocks = { path = "../goldilocks" } p3-mersenne-31 = { path = "../mersenne-31" } p3-poseidon2 = { path = "../poseidon2" } p3-symmetric = { path = "../symmetric" } rand = "0.8.5" -tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } -tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] } postcard = { version = "1.0.0", default-features = false, features = ["alloc"] } diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index 5c70568c1..3529f6cd8 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -123,26 +123,28 @@ where let zeta: SC::Challenge = challenger.sample(); let zeta_next = trace_domain.next_point(zeta).unwrap(); - let (opened_values, opening_proof) = pcs.open( - iter::empty() - .chain( - proving_key - .map(|proving_key| { - (&proving_key.preprocessed_data, vec![vec![zeta, zeta_next]]) - }) - .into_iter(), - ) - .chain([ - (&trace_data, vec![vec![zeta, zeta_next]]), - ( - "ient_data, - // open every chunk at zeta - (0..quotient_degree).map(|_| vec![zeta]).collect_vec(), - ), - ]) - .collect_vec(), - challenger, - ); + let (opened_values, opening_proof) = info_span!("open").in_scope(|| { + pcs.open( + iter::empty() + .chain( 
+ proving_key + .map(|proving_key| { + (&proving_key.preprocessed_data, vec![vec![zeta, zeta_next]]) + }) + .into_iter(), + ) + .chain([ + (&trace_data, vec![vec![zeta, zeta_next]]), + ( + "ient_data, + // open every chunk at zeta + (0..quotient_degree).map(|_| vec![zeta]).collect_vec(), + ), + ]) + .collect_vec(), + challenger, + ) + }); let mut opened_values = opened_values.iter(); // maybe get values for the preprocessed columns diff --git a/util/src/lib.rs b/util/src/lib.rs index 82625d0a8..05b21ce49 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -4,7 +4,9 @@ extern crate alloc; +use alloc::string::String; use alloc::vec::Vec; +use core::any::type_name; use core::hint::unreachable_unchecked; pub mod array_serialization; @@ -67,6 +69,7 @@ pub const fn reverse_bits_len(x: usize, bit_len: usize) -> usize { .0 } +/// Permutes `arr` such that each index is mapped to its reverse in binary. pub fn reverse_slice_index_bits(vals: &mut [F]) { let n = vals.len(); if n == 0 { @@ -155,3 +158,60 @@ pub fn transpose_vec(v: Vec>) -> Vec> { }) .collect() } + +/// Return a String containing the name of T but with all the crate +/// and module prefixes removed. +pub fn pretty_name() -> String { + let name = type_name::(); + let mut result = String::new(); + for qual in name.split_inclusive(&['<', '>', ',']) { + result.push_str(qual.split("::").last().unwrap()); + } + result +} + +#[cfg(test)] +mod tests { + use alloc::vec; + + use super::*; + + #[test] + fn test_reverse_bits_len() { + assert_eq!(reverse_bits_len(0b0000000000, 10), 0b0000000000); + assert_eq!(reverse_bits_len(0b0000000001, 10), 0b1000000000); + assert_eq!(reverse_bits_len(0b1000000000, 10), 0b0000000001); + assert_eq!(reverse_bits_len(0b00000, 5), 0b00000); + assert_eq!(reverse_bits_len(0b01011, 5), 0b11010); + } + + #[test] + fn test_reverse_index_bits() { + let mut arg = vec![10, 20, 30, 40]; + reverse_slice_index_bits(&mut arg); + assert_eq!(arg, vec![10, 30, 20, 40]); + + let mut input256: Vec = (0..256).collect(); + #[rustfmt::skip] + let output256: Vec = vec![ + 0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0, + 0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8, + 0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4, + 0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc, + 0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2, + 0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa, + 0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6, + 0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe, + 0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1, + 0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9, + 0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5, + 0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd, + 0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3, + 0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb, + 0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7, + 0x0f, 0x8f, 0x4f, 
0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff, + ]; + reverse_slice_index_bits(&mut input256[..]); + assert_eq!(input256, output256); + } +}
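// Illustrative usage (a sketch, not part of the diff) of the new `pretty_name` helper
// added above: crate and module prefixes are stripped from every path segment, including
// generic arguments. The exact rendering of `core::any::type_name` is not formally
// guaranteed, so this is indicative rather than a contract.
//
// #[test]
// fn pretty_name_strips_module_paths() {
//     // Assumes the usual "alloc::vec::Vec<core::option::Option<u32>>" rendering.
//     assert_eq!(pretty_name::<Vec<Option<u32>>>(), "Vec<Option<u32>>");
// }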