diff --git a/air/src/air.rs b/air/src/air.rs index 7f51249f3..873ea7da6 100644 --- a/air/src/air.rs +++ b/air/src/air.rs @@ -13,27 +13,10 @@ pub trait BaseAir: Sync { None } - fn stage_count(&self) -> usize { - 1 - } - /// The number of preprocessed columns in this AIR fn preprocessed_width(&self) -> usize { 0 } - - /// The number of columns in a given higher-stage trace. - fn multi_stage_width(&self, stage: u32) -> usize { - match stage { - 0 => self.width(), - _ => unimplemented!(), - } - } - - /// The number of challenges produced at the end of each stage - fn challenge_count(&self, _stage: u32) -> usize { - 0 - } } /// An AIR that works with a particular `AirBuilder`. @@ -135,16 +118,7 @@ pub trait AirBuilder: Sized { pub trait AirBuilderWithPublicValues: AirBuilder { type PublicVar: Into + Copy; - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - match stage { - 0 => self.public_values(), - _ => unimplemented!(), - } - } - - fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) - } + fn public_values(&self) -> &[Self::PublicVar]; } pub trait PairBuilder: AirBuilder { @@ -188,15 +162,6 @@ pub trait PermutationAirBuilder: ExtensionBuilder { fn permutation_randomness(&self) -> &[Self::RandomVar]; } -pub trait MultistageAirBuilder: AirBuilder { - type Challenge: Clone + Into; - - /// Traces from each stage. - fn multi_stage(&self, stage: usize) -> Self::M; - - /// Challenges from each stage, drawn from the base field - fn challenges(&self, stage: usize) -> &[Self::Challenge]; -} #[derive(Debug)] pub struct FilteredAirBuilder<'a, AB: AirBuilder> { pub inner: &'a mut AB, diff --git a/uni-stark/src/check_constraints.rs b/uni-stark/src/check_constraints.rs index 573d3cd24..9014dbc3e 100644 --- a/uni-stark/src/check_constraints.rs +++ b/uni-stark/src/check_constraints.rs @@ -1,13 +1,15 @@ use alloc::vec::Vec; use itertools::Itertools; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, MultistageAirBuilder, PairBuilder}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder}; use p3_field::Field; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::stack::VerticalPair; use p3_matrix::Matrix; use tracing::instrument; +use crate::traits::MultistageAirBuilder; + #[instrument(name = "check constraints", skip_all)] pub(crate) fn check_constraints( air: &A, @@ -131,8 +133,8 @@ where impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> { type PublicVar = Self::F; - fn stage_public_values(&self, stage: usize) -> &[Self::F] { - self.public_values[stage] + fn public_values(&self) -> &[Self::PublicVar] { + self.stage_public_values(0) } } @@ -145,6 +147,10 @@ impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> { impl<'a, F: Field> MultistageAirBuilder for DebugConstraintBuilder<'a, F> { type Challenge = Self::Expr; + fn stage_public_values(&self, stage: usize) -> &[Self::F] { + self.public_values[stage] + } + fn multi_stage(&self, stage: usize) -> Self::M { self.stages[stage] } diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index c43c8a50c..1d353251b 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -1,10 +1,11 @@ use alloc::vec::Vec; -use p3_air::{AirBuilder, AirBuilderWithPublicValues, MultistageAirBuilder, PairBuilder}; +use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; use p3_field::AbstractField; use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::stack::VerticalPair; +use crate::traits::MultistageAirBuilder; use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val}; #[derive(Debug)] @@ -69,29 +70,32 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { } impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraintFolder<'a, SC> { - type PublicVar = Self::F; + type PublicVar = Val; - fn stage_public_values(&self, stage: usize) -> &[Self::F] { - &self.public_values[stage] - } -} - -impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> { - fn preprocessed(&self) -> Self::M { - self.preprocessed.clone() + fn public_values(&self) -> &[Self::PublicVar] { + self.stage_public_values(0) } } impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for ProverConstraintFolder<'a, SC> { type Challenge = Val; - fn multi_stage(&self, stage: usize) -> Self::M { + fn multi_stage(&self, stage: usize) -> ::M { self.stages[stage].clone() } fn challenges(&self, stage: usize) -> &[Self::Challenge] { &self.challenges[stage] } + fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { + &self.public_values[stage] + } +} + +impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> { + fn preprocessed(&self) -> Self::M { + self.preprocessed.clone() + } } impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> { @@ -128,27 +132,30 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> } impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstraintFolder<'a, SC> { - type PublicVar = Self::F; + type PublicVar = Val; - fn stage_public_values(&self, stage: usize) -> &[Self::F] { - self.public_values[stage] - } -} - -impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> { - fn preprocessed(&self) -> Self::M { - self.preprocessed + fn public_values(&self) -> &[Self::PublicVar] { + self.stage_public_values(0) } } impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for VerifierConstraintFolder<'a, SC> { type Challenge = Val; - fn multi_stage(&self, stage: usize) -> Self::M { + fn multi_stage(&self, stage: usize) -> ::M { self.stages[stage] } fn challenges(&self, stage: usize) -> &[Self::Challenge] { &self.challenges[stage] } + fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { + self.public_values[stage] + } +} + +impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> { + fn preprocessed(&self) -> Self::M { + self.preprocessed + } } diff --git a/uni-stark/src/lib.rs b/uni-stark/src/lib.rs index 488c8f266..6a0d401e6 100644 --- a/uni-stark/src/lib.rs +++ b/uni-stark/src/lib.rs @@ -11,6 +11,7 @@ mod prover; mod symbolic_builder; mod symbolic_expression; mod symbolic_variable; +mod traits; mod verifier; mod zerofier_coset; @@ -26,5 +27,6 @@ pub use prover::*; pub use symbolic_builder::*; pub use symbolic_expression::*; pub use symbolic_variable::*; +pub use traits::*; pub use verifier::*; pub use zerofier_coset::*; diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index c615f46ea..938ffdc65 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -14,6 +14,7 @@ use p3_util::log2_strict_usize; use tracing::instrument; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; +use crate::traits::MultiStageAir; use crate::{ CallbackResult, Commitments, NextStageTraceCallback, PackedChallenge, PackedVal, Proof, ProverConstraintFolder, QuotientInputs, Stage, StarkGenericConfig, StarkProvingKey, State, Val, @@ -43,7 +44,8 @@ pub fn prove< ) -> Proof where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, { prove_with_key::<_, _, Panic>( config, @@ -74,12 +76,13 @@ pub fn prove_with_key< ) -> Proof where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, T: NextStageTraceCallback, { let degree = stage_0_trace.height(); let log_degree = log2_strict_usize(degree); - let stage_count = air.stage_count(); + let stage_count = >>::stage_count(air); let pcs = config.pcs(); let trace_domain = pcs.natural_domain_for_degree(degree); @@ -95,7 +98,7 @@ where let mut state = State::new(pcs, trace_domain, challenger, log_degree); let mut stage = Stage { trace: stage_0_trace, - challenge_count: air.challenge_count(0), + challenge_count: >>::challenge_count(air, 0), public_values: stage_0_public_values.to_owned(), }; @@ -118,7 +121,10 @@ where // go to the next stage stage = Stage { trace, - challenge_count: air.challenge_count(stage_id as u32), + challenge_count: >>::challenge_count( + air, + stage_id as u32, + ), public_values, }; } @@ -169,7 +175,8 @@ pub fn finish< ) -> Proof where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, { let log_quotient_degree = get_log_quotient_degree::, A>( air, diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 73f0cf92b..1f11c746a 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, MultistageAirBuilder, PairBuilder}; +use p3_air::{AirBuilder, AirBuilderWithPublicValues, PairBuilder}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use p3_util::log2_ceil_usize; @@ -9,13 +9,14 @@ use tracing::instrument; use crate::symbolic_expression::SymbolicExpression; use crate::symbolic_variable::SymbolicVariable; +use crate::traits::{MultiStageAir, MultistageAirBuilder}; use crate::Entry; #[instrument(name = "infer log of constraint degree", skip_all)] pub fn get_log_quotient_degree(air: &A, public_values_counts: &[usize]) -> usize where F: Field, - A: Air>, + A: MultiStageAir>, { // We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees. let constraint_degree = get_max_constraint_degree(air, public_values_counts).max(2); @@ -30,7 +31,7 @@ where pub fn get_max_constraint_degree(air: &A, public_values_counts: &[usize]) -> usize where F: Field, - A: Air>, + A: MultiStageAir>, { get_symbolic_constraints(air, public_values_counts) .iter() @@ -46,7 +47,7 @@ pub fn get_symbolic_constraints( ) -> Vec> where F: Field, - A: Air>, + A: MultiStageAir>, { let widths: Vec<_> = (0..air.stage_count()) .map(|i| air.multi_stage_width(i as u32)) @@ -174,20 +175,19 @@ impl AirBuilder for SymbolicAirBuilder { impl AirBuilderWithPublicValues for SymbolicAirBuilder { type PublicVar = SymbolicVariable; - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - &self.public_values[stage] - } -} -impl PairBuilder for SymbolicAirBuilder { - fn preprocessed(&self) -> Self::M { - self.preprocessed.clone() + fn public_values(&self) -> &[Self::PublicVar] { + self.stage_public_values(0) } } impl MultistageAirBuilder for SymbolicAirBuilder { type Challenge = Self::Var; + fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { + &self.public_values[stage] + } + fn multi_stage(&self, stage: usize) -> Self::M { self.stages[stage].clone() } @@ -196,3 +196,9 @@ impl MultistageAirBuilder for SymbolicAirBuilder { &self.challenges[stage] } } + +impl PairBuilder for SymbolicAirBuilder { + fn preprocessed(&self) -> Self::M { + self.preprocessed.clone() + } +} diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 8ff0fc90b..576f6f233 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -3,7 +3,7 @@ use alloc::vec::Vec; use core::iter; use itertools::{izip, Itertools}; -use p3_air::{Air, BaseAir}; +use p3_air::BaseAir; use p3_challenger::{CanObserve, CanSample, FieldChallenger}; use p3_commit::{Pcs, PolynomialSpace}; use p3_field::{AbstractExtensionField, AbstractField, Field}; @@ -12,6 +12,7 @@ use p3_matrix::stack::VerticalPair; use tracing::instrument; use crate::symbolic_builder::{get_log_quotient_degree, SymbolicAirBuilder}; +use crate::traits::MultiStageAir; use crate::{ PcsError, Proof, StarkGenericConfig, StarkVerifyingKey, Val, VerifierConstraintFolder, }; @@ -26,7 +27,8 @@ pub fn verify( ) -> Result<(), VerificationError>> where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, { verify_with_key(config, None, air, challenger, proof, vec![public_values]) } @@ -42,7 +44,8 @@ pub fn verify_with_key( ) -> Result<(), VerificationError>> where SC: StarkGenericConfig, - A: Air>> + for<'a> Air>, + A: MultiStageAir>> + + for<'a> MultiStageAir>, { let Proof { commitments, @@ -70,7 +73,9 @@ where let quotient_chunks_domains = quotient_domain.split_domains(quotient_degree); let air_widths = (0..stages) - .map(|stage| >>::multi_stage_width(air, stage as u32)) + .map(|stage| { + >>>::multi_stage_width(air, stage as u32) + }) .collect::>(); let air_fixed_width = >>::preprocessed_width(air); let valid_shape = opened_values.preprocessed_local.len() == air_fixed_width