Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to bindings and Scalars for customisable pickles evaluation #14989

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 7 additions & 77 deletions src/lib/crypto/kimchi_bindings/stubs/src/linearization.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
use kimchi::{
circuits::{
gate::GateType,
constraints::FeatureFlags,
expr::{Linearization, PolishToken},
expr::Linearization,
lookup::lookups::{LookupFeatures, LookupPatterns},
},
linearization::{constraints_expr, linearization_columns},
};

/// Converts the linearization of the kimchi circuit polynomial into a printable string.
pub fn linearization_strings<F: ark_ff::PrimeField + ark_ff::SquareRootField>(
// omit_custom_gate: bool,
custom_gate_type: Option<&Vec<PolishToken<F>>>,
omit_custom_gate: bool,
uses_custom_gates: bool,
) -> (String, Vec<(String, String)>) {
let features = if uses_custom_gates {
Expand All @@ -38,7 +36,7 @@ pub fn linearization_strings<F: ark_ff::PrimeField + ark_ff::SquareRootField>(
};
let evaluated_cols = linearization_columns::<F>(features.as_ref());
let (linearization, _powers_of_alpha) =
constraints_expr::<F>(/* omit_custom_gate */ false, custom_gate_type, features.as_ref(), true);
constraints_expr::<F>(omit_custom_gate, None, features.as_ref(), true);

let Linearization {
constant_term,
Expand All @@ -58,90 +56,22 @@ pub fn linearization_strings<F: ark_ff::PrimeField + ark_ff::SquareRootField>(
(constant, other_terms)
}

#[ocaml::func]
pub fn fp_linearization_strings_plus() -> (String, Vec<(String, String)>) {
// Define conditional gate in RPN
// w(0) = w(1) * w(3) + (1 - w(3)) * w(2)
use kimchi::circuits::expr::{PolishToken::*, *};
use kimchi::circuits::gate::CurrOrNext::Curr;
let conditional_gate = Some(vec![
Cell(Variable {
col: Column::Index(GateType::ForeignFieldAdd),
row: Curr,
}),
Cell(Variable {
col: Column::Witness(3),
row: Curr,
}),
Dup,
Mul,
Cell(Variable {
col: Column::Witness(3),
row: Curr,
}),
Sub,
Alpha,
Pow(1),
Cell(Variable {
col: Column::Witness(0),
row: Curr,
}),
Cell(Variable {
col: Column::Witness(3),
row: Curr,
}),
Cell(Variable {
col: Column::Witness(1),
row: Curr,
}),
Mul,
Literal(mina_curves::pasta::Fp::from(1u32)),
Cell(Variable {
col: Column::Witness(3),
row: Curr,
}),
Sub,
Cell(Variable {
col: Column::Witness(2),
row: Curr,
}),
Mul,
Add,
Sub,
Mul,
Add,
Mul,
]);

linearization_strings::<mina_curves::pasta::Fp>(conditional_gate.as_ref(), true)
}

#[ocaml::func]
pub fn fq_linearization_strings_plus() -> (String, Vec<(String, String)>) {
linearization_strings::<mina_curves::pasta::Fq>(None, false)
}


#[ocaml::func]
pub fn fp_linearization_strings_minus() -> (String, Vec<(String, String)>) {
// linearization_strings::<mina_curves::pasta::Fp>(true, true)
linearization_strings::<mina_curves::pasta::Fp>(None, true)
linearization_strings::<mina_curves::pasta::Fp>(true, true)
}

#[ocaml::func]
pub fn fq_linearization_strings_minus() -> (String, Vec<(String, String)>) {
// linearization_strings::<mina_curves::pasta::Fq>(true, false)
linearization_strings::<mina_curves::pasta::Fq>(None, false)
linearization_strings::<mina_curves::pasta::Fq>(true, false)
}

#[ocaml::func]
pub fn fp_linearization_strings() -> (String, Vec<(String, String)>) {
// linearization_strings::<mina_curves::pasta::Fp>(false, true)
linearization_strings::<mina_curves::pasta::Fp>(None, true)
linearization_strings::<mina_curves::pasta::Fp>(false, true)
}

#[ocaml::func]
pub fn fq_linearization_strings() -> (String, Vec<(String, String)>) {
// linearization_strings::<mina_curves::pasta::Fq>(false, false)
linearization_strings::<mina_curves::pasta::Fq>(None, false)
linearization_strings::<mina_curves::pasta::Fq>(false, false)
}
2 changes: 1 addition & 1 deletion src/lib/pickles/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ struct
, r.custom_gate_type )
| r :: rules ->
let feature_flags, custom_gate_type = go rules in
(* Note: For now we only support one choice when custom gates are defined *)
(* Note: For now we only support one choice when configurable gates are defined *)
if
Option.is_some custom_gate_type
|| Option.is_some r.custom_gate_type
Expand Down
3 changes: 3 additions & 0 deletions src/lib/pickles/plonk_checks/dune
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
;; local libraries
pickles_types
pickles_base
pickles.backend
pickles.composition_types
kimchi_backend.pasta.basic
kimchi_backend.pasta
kimchi_backend
kimchi_types
snarky.backendless
Expand Down
23 changes: 11 additions & 12 deletions src/lib/pickles/plonk_checks/gen_scalars/gen_scalars.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ let () =
(* turn off fragile pattern-matching warning from sexp ppx *)
[@@@warning "-4"]

type curr_or_next = Curr | Next
[@@deriving hash, eq, compare, sexp]
type curr_or_next = Kimchi_types.curr_or_next = Curr | Next [@@deriving hash, eq, compare, sexp]

module Gate_type = struct
module T = struct
Expand Down Expand Up @@ -60,19 +59,19 @@ module Lookup_pattern = struct
end

module Column = struct
open Core_kernel

module T = struct
type t =
type t = Kimchi_types.column =
| Witness of int
| Index of Gate_type.t
| Coefficient of int
| LookupTable
| Z
| LookupSorted of int
| LookupAggreg
| LookupTable
| LookupKindIndex of Lookup_pattern.t
| LookupRuntimeSelector
| LookupRuntimeTable
| Index of Gate_type.t
| Coefficient of int
| Permutation of int
[@@deriving hash, eq, compare, sexp]
end

Expand Down Expand Up @@ -248,7 +247,7 @@ let () =
output_string
{ocaml|
(* The constraints for overriden gate *)
module TickPlus : S = struct
module TickMinus : S = struct
let constant_term (type a)
({ add = ( + )
; sub = ( - )
Expand Down Expand Up @@ -277,10 +276,10 @@ let () =
a Env.t) =
|ocaml}

external fp_linearization_plus : bool -> string * (string * string) array
= "fp_linearization_strings_plus"
external fp_linearization_minus : unit -> string * (string * string) array
= "fp_linearization_strings_minus"

let fp_constant_term, fp_index_terms = fp_linearization_plus true
let fp_constant_term, fp_index_terms = fp_linearization_minus ()

let () = output_string fp_constant_term

Expand Down
113 changes: 102 additions & 11 deletions src/lib/pickles/plonk_checks/plonk_checks.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ module type Field_intf = sig
val inv : t -> t

val negate : t -> t
(*
module Constant : sig
(** The finite field over which the R1CS operates. *)
type t = field [@@deriving bin_io, sexp, hash, compare]

(** Return a constraint system constant representing the given value. *)
val constant : ('var, 'value) Typ.t -> 'value -> 'var
end *)
end

module type Field_with_if_intf = sig
Expand Down Expand Up @@ -197,6 +205,10 @@ let scalars_env (type boolean t) (module B : Bool_intf with type t = boolean)
i ()
| Coefficient i ->
get_eval coefficients.(i)
| Permutation _i ->
failwith "Not implemented"
| Z ->
failwith "Not implemented"
| LookupTable ->
get_eval (Opt.value_exn e.lookup_table)
| LookupSorted i ->
Expand Down Expand Up @@ -339,20 +351,84 @@ let scalars_env (type boolean t) (module B : Bool_intf with type t = boolean)
let perm_alpha0 : int = 21

module Make (Shifted_value : Shifted_value.S) (Sc : Scalars.S) = struct
let evaluate_rpn (type t u) (module F : Field_intf with type t = t)
~(* (module U : Field_intf with type t = u) *) (env : t Scalars.Env.t)
~(gate_rpn : u Kimchi_types.polish_token array)
~(* ~(gate_rpn : Kimchi_pasta_basic.Fp.t Kimchi_types.polish_token array) *)
map_constant =
(* ~(_evals : (_ * _, _) Plonk_types.Evals.In_circuit.t) *)
(* ~(gate_rpn : Kimchi_pasta_basic.Fp.t Kimchi_types.polish_token array) = *)
printf "HI\n" ;

let stack = Stack.create () in
Array.iteri gate_rpn ~f:(fun _idx token ->
Kimchi_types.(
match token with
| Alpha ->
Stack.push stack @@ env.alpha_pow 1
| Beta ->
Stack.push stack env.beta
| Gamma ->
Stack.push stack env.gamma
| JointCombiner ->
Stack.push stack env.joint_combiner
| EndoCoefficient ->
Stack.push stack env.endo_coefficient
| Mds mds ->
Stack.push stack @@ env.mds (mds.row, mds.col)
(* JES: CHECK this: is this (row, col) format *)
| VanishesOnZeroKnowledgeAndPreviousRows ->
Stack.push stack env.vanishes_on_zero_knowledge_and_previous_rows
| UnnormalizedLagrangeBasis i ->
Stack.push stack
@@ env.unnormalized_lagrange_basis
(i.zk_rows, Int32.to_int_exn i.offset)
| Literal x ->
Stack.push stack @@ map_constant x
| Dup ->
Stack.(push stack @@ top_exn stack)
| Cell v ->
Stack.push stack @@ env.var (v.col, v.row)
| Pow n ->
Stack.(
push stack
@@ pow2pow (module F) (pop_exn stack) (Int32.to_int_exn n))
| Add ->
Stack.(push stack @@ F.( + ) (pop_exn stack) (pop_exn stack))
| Mul ->
Stack.(push stack @@ F.( * ) (pop_exn stack) (pop_exn stack))
| Sub ->
Stack.(push stack @@ F.( - ) (pop_exn stack) (pop_exn stack))
| Store ->
failwith "Unsupported RPN token: Store"
| Load _ ->
failwith "Unsupported RPN token: Load"
| SkipIf _ ->
failwith "Unsupported RPN token: SkipIf"
| SkipIfNot _ ->
failwith "Unsupported RPN token: SkipIfNot") ) ;

Stack.pop_exn stack

(** Computes the ft evaluation at zeta.
(see https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html#the-evaluation-of-l)
*)
let ft_eval0 (type t) (module F : Field_intf with type t = t) ~domain
~(env : t Scalars.Env.t)
({ alpha = _
; beta
; gamma
; zeta
; joint_combiner = _
; feature_flags = _
; _
} :
_ Minimal.t ) (e : (_ * _, _) Plonk_types.Evals.In_circuit.t) p_eval0 =
let ft_eval0 (type t) (* (type u) *) (module F : Field_intf with type t = t)
~(* (module U : Field_intf with type t = u) *) domain
~(env : t Scalars.Env.t) ?custom_gate_type ?map_constant
(* ~(const_map : 'a -> 'b) *)
(* ~(custom_gate_type : Kimchi_pasta_basic.Fp.t Kimchi_types.polish_token array option) *)
(* ?custom_gate_type *)
({ alpha = _
; beta
; gamma
; zeta
; joint_combiner = _
; feature_flags = _
; _
} :
_ Minimal.t ) (e : (_ * _, _) Plonk_types.Evals.In_circuit.t) p_eval0
=
let open Plonk_types.Evals.In_circuit in
let e0 field = fst (field e) in
let e1 field = snd (field e) in
Expand Down Expand Up @@ -401,6 +477,21 @@ module Make (Shifted_value : Shifted_value.S) (Sc : Scalars.S) = struct
(* (1) Create ScalarsMinus, without ffAdd *)
let constant_term =
Sc.constant_term env
+
match custom_gate_type with
| Some custom_gate_type -> (
match map_constant with
| Some map_constant ->
evaluate_rpn
(module F)
~env ~gate_rpn:custom_gate_type ~map_constant
| None ->
failwith "Need constant mapping function"
(* evaluate_rpn
(module F)
~env ~gate_rpn:custom_gate_type ~map_constant:Fn.id *) )
| None ->
F.zero
(* (2) ~override_ffadd: optional... Sc.evaluate_custom_gate ~env ~polish_gate *)
in
ft_eval0 - constant_term
Expand Down
22 changes: 19 additions & 3 deletions src/lib/pickles/plonk_checks/plonk_checks.mli
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,27 @@ val scalars_env :
-> 't Scalars.Env.t

module Make (Shifted_value : Pickles_types.Shifted_value.S) (_ : Scalars.S) : sig
val evaluate_rpn :
(module Field_intf with type t = 't)
(* -> (module Field_intf with type t = 'u) *)
-> env:'t Scalars.Env.t
(* -> _evals:('t * 't, 'a) Pickles_types.Plonk_types.Evals.In_circuit.t *)
(* -> gate_rpn:Kimchi_pasta_basic.Fp.t Kimchi_types.polish_token array *)
-> gate_rpn:'u Kimchi_types.polish_token array
-> map_constant:('u -> 't)
-> (* -> map_constant:(Kimchi_pasta_basic.Fp.t -> 't) *)
't

val ft_eval0 :
't field
-> domain:< shifts : 't array ; .. >
't field (* -> 'u field *)
-> domain:< shifts : 't array ; .. > (* -> const_map:('a -> 'b) *)
-> env:'t Scalars.Env.t
-> ( 't
(* -> ?custom_gate_type:Kimchi_pasta_basic.Fp.t Kimchi_types.polish_token array *)
-> ?custom_gate_type:'u Kimchi_types.polish_token array
-> ?map_constant:('u -> 't)
-> (* -> ?map_constant:(Kimchi_pasta_basic.Fp.t -> 't) *)
(* -> ?custom_gate_type:Kimchi_pasta_basic.Fp.t Kimchi_types.polish_token array *)
( 't
, 't
, 'b )
Composition_types.Wrap.Proof_state.Deferred_values.Plonk.Minimal.t
Expand Down
Loading
Loading