From 1cd1e03705f525fd6db945e1043ded6e74fbc5e3 Mon Sep 17 00:00:00 2001 From: Oleksandr Brezhniev Date: Mon, 6 Nov 2023 12:55:34 +0000 Subject: [PATCH] Regenerated ff with goff. Fixes because of changed interfaces in ff --- babyjub/babyjub.go | 58 +- ff/arith.go | 13 + ff/asm.go | 5 +- ff/asm_noadx.go | 5 +- ff/doc.go | 44 +- ff/element.go | 1582 ++++++++++++++++++++++----------- ff/element_fuzz.go | 136 --- ff/element_mul_adx_amd64.s | 466 ---------- ff/element_mul_amd64.s | 103 ++- ff/element_ops_amd64.go | 75 ++ ff/element_ops_amd64.s | 91 +- ff/element_ops_noasm.go | 78 -- ff/element_ops_purego.go | 463 ++++++++++ ff/element_test.go | 1699 ++++++++++++++++++++++++++++-------- ff/vector.go | 253 ++++++ ff/vector_test.go | 91 ++ go.mod | 6 +- go.sum | 8 + mimc7/mimc7.go | 36 +- poseidon/constants.go | 8 +- poseidon/poseidon.go | 4 +- utils/utils.go | 4 +- 22 files changed, 3460 insertions(+), 1768 deletions(-) delete mode 100644 ff/element_fuzz.go delete mode 100644 ff/element_mul_adx_amd64.s delete mode 100644 ff/element_ops_noasm.go create mode 100644 ff/element_ops_purego.go create mode 100644 ff/vector.go create mode 100644 ff/vector_test.go diff --git a/babyjub/babyjub.go b/babyjub/babyjub.go index 95db039..fcec711 100644 --- a/babyjub/babyjub.go +++ b/babyjub/babyjub.go @@ -36,8 +36,8 @@ var B8 *Point func init() { A, _ = utils.NewIntFromString("168700") D, _ = utils.NewIntFromString("168696") - Aff = ff.NewElement().SetBigInt(A) - Dff = ff.NewElement().SetBigInt(D) + Aff = new(ff.Element).SetBigInt(A) + Dff = new(ff.Element).SetBigInt(D) Order, _ = utils.NewIntFromString( "21888242871839275222246405745257275088614511777268538073601725287587578984328") @@ -59,26 +59,26 @@ type PointProjective struct { // NewPointProjective creates a new Point in projective coordinates. func NewPointProjective() *PointProjective { - return &PointProjective{X: ff.NewElement().SetZero(), - Y: ff.NewElement().SetOne(), Z: ff.NewElement().SetOne()} + return &PointProjective{X: new(ff.Element).SetZero(), + Y: new(ff.Element).SetOne(), Z: new(ff.Element).SetOne()} } // Affine returns the Point from the projective representation func (p *PointProjective) Affine() *Point { - if p.Z.Equal(ff.NewElement().SetZero()) { + if p.Z.Equal(new(ff.Element).SetZero()) { return &Point{ X: big.NewInt(0), Y: big.NewInt(0), } } - zinv := ff.NewElement().Inverse(p.Z) - x := ff.NewElement().Mul(p.X, zinv) + zinv := new(ff.Element).Inverse(p.Z) + x := new(ff.Element).Mul(p.X, zinv) - y := ff.NewElement().Mul(p.Y, zinv) + y := new(ff.Element).Mul(p.Y, zinv) xBig := big.NewInt(0) - x.ToBigIntRegular(xBig) + x.BigInt(xBig) yBig := big.NewInt(0) - y.ToBigIntRegular(yBig) + y.BigInt(yBig) return &Point{ X: xBig, Y: yBig, @@ -90,26 +90,26 @@ func (p *PointProjective) Affine() *Point { func (p *PointProjective) Add(q, o *PointProjective) *PointProjective { // add-2008-bbjlp // https://hyperelliptic.org/EFD/g1p/auto-twisted-projective.html#doubling-dbl-2008-bbjlp - a := ff.NewElement().Mul(q.Z, o.Z) - b := ff.NewElement().Square(a) - c := ff.NewElement().Mul(q.X, o.X) - d := ff.NewElement().Mul(q.Y, o.Y) - e := ff.NewElement().Mul(Dff, c) + a := new(ff.Element).Mul(q.Z, o.Z) + b := new(ff.Element).Square(a) + c := new(ff.Element).Mul(q.X, o.X) + d := new(ff.Element).Mul(q.Y, o.Y) + e := new(ff.Element).Mul(Dff, c) e.Mul(e, d) - f := ff.NewElement().Sub(b, e) - g := ff.NewElement().Add(b, e) - x1y1 := ff.NewElement().Add(q.X, q.Y) - x2y2 := ff.NewElement().Add(o.X, o.Y) - x3 := ff.NewElement().Mul(x1y1, x2y2) + f := new(ff.Element).Sub(b, e) + g := new(ff.Element).Add(b, e) + x1y1 := new(ff.Element).Add(q.X, q.Y) + x2y2 := new(ff.Element).Add(o.X, o.Y) + x3 := new(ff.Element).Mul(x1y1, x2y2) x3.Sub(x3, c) x3.Sub(x3, d) x3.Mul(x3, a) x3.Mul(x3, f) - ac := ff.NewElement().Mul(Aff, c) - y3 := ff.NewElement().Sub(d, ac) + ac := new(ff.Element).Mul(Aff, c) + y3 := new(ff.Element).Sub(d, ac) y3.Mul(y3, a) y3.Mul(y3, g) - z3 := ff.NewElement().Mul(f, g) + z3 := new(ff.Element).Mul(f, g) p.X = x3 p.Y = y3 @@ -138,9 +138,9 @@ func (p *Point) Set(c *Point) *Point { // Projective returns a PointProjective from the Point func (p *Point) Projective() *PointProjective { return &PointProjective{ - X: ff.NewElement().SetBigInt(p.X), - Y: ff.NewElement().SetBigInt(p.Y), - Z: ff.NewElement().SetOne(), + X: new(ff.Element).SetBigInt(p.X), + Y: new(ff.Element).SetBigInt(p.Y), + Z: new(ff.Element).SetOne(), } } @@ -148,9 +148,9 @@ func (p *Point) Projective() *PointProjective { // which is also returned. func (p *Point) Mul(s *big.Int, q *Point) *Point { resProj := &PointProjective{ - X: ff.NewElement().SetZero(), - Y: ff.NewElement().SetOne(), - Z: ff.NewElement().SetOne(), + X: new(ff.Element).SetZero(), + Y: new(ff.Element).SetOne(), + Z: new(ff.Element).SetOne(), } exp := q.Projective() diff --git a/ff/arith.go b/ff/arith.go index 790067f..ddb7907 100644 --- a/ff/arith.go +++ b/ff/arith.go @@ -58,3 +58,16 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { hi, _ = bits.Add64(hi, e, carry) return } +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} diff --git a/ff/asm.go b/ff/asm.go index 2718ff3..936898e 100644 --- a/ff/asm.go +++ b/ff/asm.go @@ -21,4 +21,7 @@ package ff import "golang.org/x/sys/cpu" -var supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 +var ( + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx +) diff --git a/ff/asm_noadx.go b/ff/asm_noadx.go index 23c3a0b..ca17ea6 100644 --- a/ff/asm_noadx.go +++ b/ff/asm_noadx.go @@ -22,4 +22,7 @@ package ff // note: this is needed for test purposes, as dynamically changing supportAdx doesn't flag // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. -var supportAdx = false +var ( + supportAdx = false + _ = supportAdx +) diff --git a/ff/doc.go b/ff/doc.go index 114a4eb..1767790 100644 --- a/ff/doc.go +++ b/ff/doc.go @@ -16,28 +16,38 @@ // Package ff contains field arithmetic operations for modulus = 0x30644e...000001. // -// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@zkteam/modular_multiplication) +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication) // // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Example API signature -// // Mul z = x * y mod q -// func (z *Element) Mul(x, y *Element) *Element +// type Element [4]uint64 +// +// # Usage +// +// Example API signature: +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) -// -// Modulus -// 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 // base 16 -// 21888242871839275222246405745257275088548364400416034343698204186575808495617 // base 10 +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) +// +// Modulus q = +// +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package ff diff --git a/ff/element.go b/ff/element.go index c2ff2bc..028945a 100644 --- a/ff/element.go +++ b/ff/element.go @@ -25,85 +25,101 @@ import ( "math/bits" "reflect" "strconv" - "sync" + "strings" + + "github.com/bits-and-blooms/bitset" + "github.com/consensys/gnark-crypto/field/hash" + "github.com/consensys/gnark-crypto/field/pool" ) // Element represents a field element stored on 4 words (uint64) -// Element are assumed to be in Montgomery form in all methods -// field modulus q = // -// 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// Element are assumed to be in Montgomery form in all methods. +// +// Modulus q = +// +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 -// Limbs number of 64 bits words needed to represent Element -const Limbs = 4 +const ( + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 254 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element +) -// Bits number bits needed to represent Element -const Bits = 254 +// Field modulus q +const ( + q0 uint64 = 4891460686036598785 + q1 uint64 = 2896914383306846353 + q2 uint64 = 13281191951274694749 + q3 uint64 = 3486998266802970665 +) -// Bytes number bytes needed to represent Element -const Bytes = Limbs * 8 +var qElement = Element{ + q0, + q1, + q2, + q3, +} -// field modulus stored as big.Int -var _modulus big.Int +var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int -// q = // -// 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } -// q (modulus) -var qElement = Element{ - 4891460686036598785, - 2896914383306846353, - 13281191951274694749, - 3486998266802970665, -} - -// rSquare -var rSquare = Element{ - 1997599621687373223, - 6052339484930628067, - 10108755138030829701, - 150537098327114917, -} - -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg uint64 = 14042775128853446655 func init() { - _modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + _modulus.SetString("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", 16) } -// NewElement returns a new Element -func NewElement() *Element { - return &Element{} -} - -// NewElementFromUint64 returns a new Element from a uint64 value +// NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v NewElement -// v.SetUint64(...) -func NewElementFromUint64(v uint64) Element { +// +// var v Element +// v.SetUint64(...) +func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) return z } -// SetUint64 z = v, sets z LSB to v (non-Montgomery form) and convert z to Montgomery form +// SetUint64 sets z to v and returns z func (z *Element) SetUint64(v uint64) *Element { + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } -// Set z = x +// SetInt64 sets z to v and returns z +func (z *Element) SetInt64(v int64) *Element { + + // absolute value of v + m := v >> 63 + z.SetUint64(uint64((v ^ m) - m)) + + if m != 0 { + // v is negative + z.Neg(z) + } + + return z +} + +// Set z = x and returns z func (z *Element) Set(x *Element) *Element { z[0] = x[0] z[1] = x[1] @@ -114,21 +130,55 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported -// supported types: Element, *Element, uint64, int, string (interpreted as base10 integer), -// *big.Int, big.Int, []byte +// supported types: +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { + if i1 == nil { + return nil, errors.New("can't set ff.Element with ") + } + switch c1 := i1.(type) { case Element: return z.Set(&c1), nil case *Element: + if c1 == nil { + return nil, errors.New("can't set ff.Element with ") + } return z.Set(c1), nil + case uint8: + return z.SetUint64(uint64(c1)), nil + case uint16: + return z.SetUint64(uint64(c1)), nil + case uint32: + return z.SetUint64(uint64(c1)), nil + case uint: + return z.SetUint64(uint64(c1)), nil case uint64: return z.SetUint64(c1), nil + case int8: + return z.SetInt64(int64(c1)), nil + case int16: + return z.SetInt64(int64(c1)), nil + case int32: + return z.SetInt64(int64(c1)), nil + case int64: + return z.SetInt64(c1), nil case int: - return z.SetString(strconv.Itoa(c1)), nil + return z.SetInt64(int64(c1)), nil case string: - return z.SetString(c1), nil + return z.SetString(c1) case *big.Int: + if c1 == nil { + return nil, errors.New("can't set ff.Element with ") + } return z.SetBigInt(c1), nil case big.Int: return z.SetBigInt(&c1), nil @@ -157,7 +207,58 @@ func (z *Element) SetOne() *Element { return z } -// Div z = x*y^-1 mod q +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z + } + var borrow uint64 + z[0], borrow = bits.Sub64(q0, x[0], 0) + z[1], borrow = bits.Sub64(q1, x[1], borrow) + z[2], borrow = bits.Sub64(q2, x[2], borrow) + z[3], _ = bits.Sub64(q3, x[3], borrow) + return z +} + +// Div z = x*y⁻¹ (mod q) func (z *Element) Div(x, y *Element) *Element { var yInv Element yInv.Inverse(y) @@ -165,19 +266,14 @@ func (z *Element) Div(x, y *Element) *Element { return z } -// Bit returns the i'th bit, with lsb == bit 0. -// It is the responsability of the caller to convert from Montgomery to Regular form if needed -func (z *Element) Bit(i uint64) uint64 { - j := i / 64 - if j >= 4 { - return 0 - } - return uint64(z[j] >> (i % 64) & 1) +// Equal returns z == x; constant-time +func (z *Element) Equal(x *Element) bool { + return z.NotEqual(x) == 0 } -// Equal returns z == x -func (z *Element) Equal(x *Element) bool { - return (z[3] == x[3]) && (z[2] == x[2]) && (z[1] == x[1]) && (z[0] == x[0]) +// NotEqual returns 0 if and only if z == x; constant-time +func (z *Element) NotEqual(x *Element) uint64 { + return (z[3] ^ x[3]) | (z[2] ^ x[2]) | (z[1] ^ x[1]) | (z[0] ^ x[0]) } // IsZero returns z == 0 @@ -185,22 +281,38 @@ func (z *Element) IsZero() bool { return (z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsOne returns z == 1 +func (z *Element) IsOne() bool { + return ((z[3] ^ 1011752739694698287) | (z[2] ^ 7381016538464732718) | (z[1] ^ 3962172157175319849) | (z[0] ^ 12436184717236109307)) == 0 +} + +// IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { + zz := *z + zz.fromMont() + return zz.FitsOnOneWord() +} + +// Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. +func (z *Element) Uint64() uint64 { + return z.Bits()[0] +} + +// FitsOnOneWord reports whether z words (except the least significant word) are 0 +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. +func (z *Element) FitsOnOneWord() bool { return (z[3] | z[2] | z[1]) == 0 } // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -231,8 +343,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 11669102379873075201, 0) @@ -243,53 +354,79 @@ func (z *Element) LexicographicallyLargest() bool { return b == 0 } -// SetRandom sets z to a random element < q +// SetRandom sets z to a uniform random value in [0, q). +// +// This might error only if reading from crypto/rand.Reader errors, +// in which case, value of z is undefined. func (z *Element) SetRandom() (*Element, error) { - var bytes [32]byte - if _, err := io.ReadFull(rand.Reader, bytes[:]); err != nil { - return nil, err + // this code is generated for all modulus + // and derived from go/src/crypto/rand/util.go + + // l is number of limbs * 8; the number of bytes needed to reconstruct 4 uint64 + const l = 32 + + // bitLen is the maximum bit length needed to encode a value < q. + const bitLen = 254 + + // k is the maximum byte length needed to encode a value < q. + const k = (bitLen + 7) / 8 + + // b is the number of bits in the most significant byte of q-1. + b := uint(bitLen % 8) + if b == 0 { + b = 8 } - z[0] = binary.BigEndian.Uint64(bytes[0:8]) - z[1] = binary.BigEndian.Uint64(bytes[8:16]) - z[2] = binary.BigEndian.Uint64(bytes[16:24]) - z[3] = binary.BigEndian.Uint64(bytes[24:32]) - z[3] %= 3486998266802970665 - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + var bytes [l]byte + + for { + // note that bytes[k:l] is always 0 + if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { + return nil, err + } + + // Clear unused bits in in the most significant byte to increase probability + // that the candidate is < q. + bytes[k-1] &= uint8(int(1<> 1 - z[0] = z[0]>>1 | z[1]<<63 z[1] = z[1]>>1 | z[2]<<63 z[2] = z[2]>>1 | z[3]<<63 @@ -297,272 +434,234 @@ func (z *Element) Halve() { } -// API with assembly impl - -// Mul z = x * y mod q -// see https://hackmd.io/@zkteam/modular_multiplication -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} - -// Square z = x * x mod q -// see https://hackmd.io/@zkteam/modular_multiplication -func (z *Element) Square(x *Element) *Element { - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } -// Add z = x + y mod q -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -// Double z = x + x mod q, aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -// Sub z = x - y mod q -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - -// Neg z = q - x -func (z *Element) Neg(x *Element) *Element { - neg(z, x) +// Select is a constant-time conditional move. +// If c=0, z = x0. Else z = x1 +func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { + cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + z[0] = x0[0] ^ cC&(x0[0]^x1[0]) + z[1] = x0[1] ^ cC&(x0[1]^x1[1]) + z[2] = x0[2] ^ cC&(x0[2]^x1[2]) + z[3] = x0[3] ^ cC&(x0[3]^x1[3]) return z } -// Generic (no ADX instructions, no AMD64) versions of multiplication and squaring algorithms - +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) } } func _fromMontGeneric(z *Element) { // the following lines implement z = z * 1 // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation { // m = z[0]n'[0] mod W - m := z[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, z[0]) - C, z[0] = madd2(m, 2896914383306846353, z[1], C) - C, z[1] = madd2(m, 13281191951274694749, z[2], C) - C, z[2] = madd2(m, 3486998266802970665, z[3], C) + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) z[3] = C } { // m = z[0]n'[0] mod W - m := z[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, z[0]) - C, z[0] = madd2(m, 2896914383306846353, z[1], C) - C, z[1] = madd2(m, 13281191951274694749, z[2], C) - C, z[2] = madd2(m, 3486998266802970665, z[3], C) + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) z[3] = C } { // m = z[0]n'[0] mod W - m := z[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, z[0]) - C, z[0] = madd2(m, 2896914383306846353, z[1], C) - C, z[1] = madd2(m, 13281191951274694749, z[2], C) - C, z[2] = madd2(m, 3486998266802970665, z[3], C) + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) z[3] = C } { // m = z[0]n'[0] mod W - m := z[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, z[0]) - C, z[0] = madd2(m, 2896914383306846353, z[1], C) - C, z[1] = madd2(m, 13281191951274694749, z[2], C) - C, z[2] = madd2(m, 3486998266802970665, z[3], C) + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) z[3] = C } - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) } } -func _addGeneric(z, x, y *Element) { - var carry uint64 - - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) - } -} - -func _doubleGeneric(z, x *Element) { - var carry uint64 - - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) - } -} - -func _subGeneric(z, x, y *Element) { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], 4891460686036598785, 0) - z[1], c = bits.Add64(z[1], 2896914383306846353, c) - z[2], c = bits.Add64(z[2], 13281191951274694749, c) - z[3], _ = bits.Add64(z[3], 3486998266802970665, c) - } -} - -func _negGeneric(z, x *Element) { - if x.IsZero() { - z.SetZero() - return - } - var borrow uint64 - z[0], borrow = bits.Sub64(4891460686036598785, x[0], 0) - z[1], borrow = bits.Sub64(2896914383306846353, x[1], borrow) - z[2], borrow = bits.Sub64(13281191951274694749, x[2], borrow) - z[3], _ = bits.Sub64(3486998266802970665, x[3], borrow) -} - func _reduceGeneric(z *Element) { - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) - } -} - -func mulByConstant(z *Element, c uint8) { - switch c { - case 0: - z.SetZero() - return - case 1: - return - case 2: - z.Double(z) - return - case 3: - _z := *z - z.Double(z).Add(z, &_z) - case 5: - _z := *z - z.Double(z).Double(z).Add(z, &_z) - default: - var y Element - y.SetUint64(uint64(c)) - z.Mul(z, &y) + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) } } @@ -574,12 +673,12 @@ func BatchInvert(a []Element) []Element { return res } - zeroes := make([]bool, len(a)) + zeroes := bitset.New(uint(len(a))) accumulator := One() for i := 0; i < len(a); i++ { if a[i].IsZero() { - zeroes[i] = true + zeroes.Set(uint(i)) continue } res[i] = accumulator @@ -589,7 +688,7 @@ func BatchInvert(a []Element) []Element { accumulator.Inverse(&accumulator) for i := len(a) - 1; i >= 0; i-- { - if zeroes[i] { + if zeroes.Test(uint(i)) { continue } res[i].Mul(&res[i], &accumulator) @@ -620,18 +719,59 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } -// Exp z = x^exponent mod q -func (z *Element) Exp(x Element, exponent *big.Int) *Element { - var bZero big.Int - if exponent.Cmp(&bZero) == 0 { +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := hash.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + pool.BigInt.Put(vv) + + return res, nil +} + +// Exp z = xᵏ (mod q) +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { return z.SetOne() } + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = pool.BigInt.Get() + defer pool.BigInt.Put(e) + e.Neg(k) + } + z.Set(&x) - for i := exponent.BitLen() - 2; i >= 0; i-- { + for i := e.BitLen() - 2; i >= 0; i-- { z.Square(z) - if exponent.Bit(i) == 1 { + if e.Bit(i) == 1 { z.Mul(z, &x) } } @@ -639,39 +779,31 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { return z } -// ToMont converts z to Montgomery form -// sets and returns z = z * r^2 -func (z *Element) ToMont() *Element { - return z.Mul(z, &rSquare) +// rSquare where r is the Montgommery constant +// see section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +var rSquare = Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func (z *Element) toMont() *Element { + return z.Mul(z, &rSquare) } -// String returns the string form of an Element in Montgomery form +// String returns the decimal representation of z as generated by +// z.Text(10). func (z *Element) String() string { - zz := *z - zz.FromMont() - if zz.IsUint64() { - return strconv.FormatUint(zz[0], 10) - } else { - var zzNeg Element - zzNeg.Neg(z) - zzNeg.FromMont() - if zzNeg.IsUint64() { - return "-" + strconv.FormatUint(zzNeg[0], 10) - } - } - vv := bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(vv) - return zz.ToBigInt(vv).String() + return z.Text(10) } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte binary.BigEndian.PutUint64(b[24:32], z[0]) binary.BigEndian.PutUint64(b[16:24], z[1]) binary.BigEndian.PutUint64(b[8:16], z[2]) @@ -680,48 +812,123 @@ func (z *Element) ToBigInt(res *big.Int) *big.Int { return res.SetBytes(b[:]) } +// Text returns the string representation of z in the given base. +// Base must be between 2 and 36, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35. +// No prefix (such as "0x") is added to the string. If z is a nil +// pointer it returns "". +// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. +func (z *Element) Text(base int) string { + if base < 2 || base > 36 { + panic("invalid base") + } + if z == nil { + return "" + } + + const maxUint16 = 65535 + if base == 10 { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.fromMont() + if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { + return "-" + strconv.FormatUint(zzNeg[0], base) + } + } + zz := *z + zz.fromMont() + if zz.FitsOnOneWord() { + return strconv.FormatUint(zz[0], base) + } + vv := pool.BigInt.Get() + r := zz.toBigInt(vv).Text(base) + pool.BigInt.Put(vv) + return r +} + +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) +} + // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the regular (non montgomery) value -// of z as a big-endian byte array. -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } -// Marshal returns the regular (non montgomery) value -// of z as a big-endian byte slice. +// Marshal returns the value of z as a big-endian byte slice func (z *Element) Marshal() []byte { b := z.Bytes() return b[:] } +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + // SetBytes interprets e as the bytes of a big-endian unsigned integer, -// sets z to that value (in Montgomery form), and returns z. +// sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := pool.BigInt.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + pool.BigInt.Put(vv) return z } -// SetBigInt sets z to v (regular form) and returns z in Montgomery form +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid ff.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -738,21 +945,20 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := pool.BigInt.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + pool.BigInt.Put(vv) return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -770,25 +976,159 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } -// SetString creates a big.Int with s (in base 10) and calls SetBigInt on z -func (z *Element) SetString(s string) *Element { +// SetString creates a big.Int with number and calls SetBigInt on z +// +// The number prefix determines the actual base: A prefix of +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For base 16, lower and upper case letters are considered the same: +// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. +// +// An underscore character ”_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as a panic if there +// are no other errors. +// +// If the number is invalid this method leaves z unchanged and returns nil, error. +func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := pool.BigInt.Get() - if _, ok := vv.SetString(s, 10); !ok { - panic("Element.SetString failed -> can't parse number in base10 into a big.Int") + if _, ok := vv.SetString(number, 0); !ok { + return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) } + z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + pool.BigInt.Put(vv) - return z + return z, nil } +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + const maxSafeBound = 15 // we encode it as number if it's small + s := z.Text(10) + if len(s) <= maxSafeBound { + return []byte(s), nil + } + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// UnmarshalJSON accepts numbers and strings as input +// See Element.SetString for valid prefixes (0x, 0b, ...) +func (z *Element) UnmarshalJSON(data []byte) error { + s := string(data) + if len(s) > Bits*3 { + return errors.New("value too large (max = Element.Bits * 3)") + } + + // we accept numbers and strings, remove leading and trailing quotes if any + if len(s) > 0 && s[0] == '"' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '"' { + s = s[:len(s)-1] + } + + // get temporary big int from the pool + vv := pool.BigInt.Get() + + if _, ok := vv.SetString(s, 0); !ok { + return errors.New("can't parse into a big.Int: " + s) + } + + z.SetBigInt(vv) + + // release object into pool + pool.BigInt.Put(vv) + return nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid ff.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid ff.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + var ( _bLegendreExponentElement *big.Int _bSqrtExponentElement *big.Int @@ -811,13 +1151,13 @@ func (z *Element) Legendre() int { } // if l == 1 - if (l[3] == 1011752739694698287) && (l[2] == 7381016538464732718) && (l[1] == 3962172157175319849) && (l[0] == 12436184717236109307) { + if l.IsOne() { return 1 } return -1 } -// Sqrt z = √x mod q +// Sqrt z = √x (mod q) // if the square root doesn't exist (x is not a square mod q) // Sqrt leaves z unchanged and returns nil func (z *Element) Sqrt(x *Element) *Element { @@ -832,7 +1172,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -845,7 +1185,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(28) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -853,7 +1193,7 @@ func (z *Element) Sqrt(x *Element) *Element { if t.IsZero() { return z.SetZero() } - if !((t[3] == 1011752739694698287) && (t[2] == 7381016538464732718) && (t[1] == 3962172157175319849) && (t[0] == 12436184717236109307)) { + if !t.IsOne() { // t != 1, we don't have a square root return nil } @@ -862,7 +1202,7 @@ func (z *Element) Sqrt(x *Element) *Element { t = b // for t != 1 - for !((t[3] == 1011752739694698287) && (t[2] == 7381016538464732718) && (t[1] == 3962172157175319849) && (t[0] == 12436184717236109307)) { + for !t.IsOne() { t.Square(&t) m++ } @@ -870,7 +1210,7 @@ func (z *Element) Sqrt(x *Element) *Element { if m == 0 { return z.Set(&y) } - // t = g^(2^(r-m-1)) mod q + // t = g^(2^(r-m-1)) (mod q) ge := int(r - m - 1) t = g for ge > 0 { @@ -885,153 +1225,387 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" +const ( + k = 32 // word size / 2 + signBitSelector = uint64(1) << 63 + approxLowBitsN = k - 1 + approxHighBitsN = k + 1 +) + +const ( + inversionCorrectionFactorWord0 = 13488105295233737379 + inversionCorrectionFactorWord1 = 17373395488625725466 + inversionCorrectionFactorWord2 = 6831692495576925776 + inversionCorrectionFactorWord3 = 3282329835997625403 + invIterationsN = 18 +) + +// Inverse z = x⁻¹ (mod q) +// // if x == 0, sets and returns z = x func (z *Element) Inverse(x *Element) *Element { - if x.IsZero() { - z.SetZero() - return z - } + // Implements "Optimized Binary GCD for Modular Inversion" + // https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf - // initialize u = q - var u = Element{ - 4891460686036598785, - 2896914383306846353, - 13281191951274694749, - 3486998266802970665, - } + a := *x + b := Element{ + q0, + q1, + q2, + q3, + } // b := q - // initialize s = r^2 - var s = Element{ - 1997599621687373223, - 6052339484930628067, - 10108755138030829701, - 150537098327114917, - } + u := Element{1} - // r = 0 - r := Element{} + // Update factors: we get [u; v] ← [f₀ g₀; f₁ g₁] [u; v] + // cᵢ = fᵢ + 2³¹ - 1 + 2³² * (gᵢ + 2³¹ - 1) + var c0, c1 int64 - v := *x + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 - var carry, borrow uint64 - var bigger bool + var i uint - for { - for v[0]&1 == 0 { + var v, s Element - // v = v >> 1 + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] >>= 1 + // f₀, g₀, f₁, g₁ = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 - if s[0]&1 == 1 { + for j := 0; j < approxLowBitsN; j++ { - // s = s + q - s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) - s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + // -2ʲ < f₀, f₁ ≤ 2ʲ + // |f₀| + |f₁| < 2ʲ⁺¹ + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + // invariants unchanged + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ⁺¹ ≤ 2ʲ⁺¹ (only the weaker inequality is needed, strictly speaking) + // Started with f₀ > -2ʲ and f₁ ≤ 2ʲ, so f₀ - f₁ > -2ʲ⁺¹ + // Invariants unchanged for f₁ } - // s = s >> 1 + c1 *= 2 + // -2ʲ⁺¹ < f₁ ≤ 2ʲ⁺¹ + // So now |f₀| + |f₁| < 2ʲ⁺² + } - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] >>= 1 + s = a + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = negL(&a, aHi) } - for u[0]&1 == 0 { + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = negL(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [F₀, G₀; F₁, G₁] ← [f₀, g₀; f₁, g₁] [pf₀, pg₀; pf₁, pg₁], with capital letters denoting new combined values + // We get |F₀| = | f₀pf₀ + g₀pf₁ | ≤ |f₀pf₀| + |g₀pf₁| = |f₀| |pf₀| + |g₀| |pf₁| ≤ 2ᵏ⁻¹|pf₀| + 2ᵏ⁻¹|pf₁| + // = 2ᵏ⁻¹ (|pf₀| + |pf₁|) < 2ᵏ⁻¹ 2ᵏ = 2²ᵏ⁻¹ + // So |F₀| < 2²ᵏ⁻¹ meaning it fits in a 2k-bit signed register + + // c₀ aliases f₀, c₁ aliases g₁ + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + + // 0 ≤ u, v < 2²⁵⁵ + // |F₀|, |G₀| < 2⁶³ + u.linearComb(&u, c0, &v, g0) + // |F₁|, |G₁| < 2⁶³ + v.linearComb(&s, f1, &v, c1) - // u = u >> 1 + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } + } - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] >>= 1 + // For every iteration that we miss, v is not being multiplied by 2ᵏ⁻² + const pSq uint64 = 1 << (2 * (k - 1)) + a = Element{pSq} + // If the function is constant-time ish, this loop will not run (no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + // could optimize further with mul by word routine or by pre-computing a table since with k=26, + // we would multiply by pSq up to 13times; + // on x86, the assembly routine outperforms generic code for mul by word + // on arm64, we may loose up to ~5% for 6 limbs + v.Mul(&v, &a) + } - if r[0]&1 == 1 { + u.Set(x) // for correctness check - // r = r + q - r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) - r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) - } + // correctness check + v.Mul(&u, z) + if !v.IsOne() && !u.IsZero() { + return z.inverseExp(u) + } - // r = r >> 1 + return z +} - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] >>= 1 +// inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) } + } - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + return z +} - if bigger { +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) + if nBits <= 64 { + return x[0] + } - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - if borrow == 1 { + hiWordIndex := (nBits - 1) / 64 - // s = s + q - s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) - s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - } - } else { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) + + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed + + return lo | mid | hi +} - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) +// linearComb z = xC * x + yC * y; +// 0 ≤ x, y < 2²⁵⁴ +// |xC|, |yC| < 2⁶³ +func (z *Element) linearComb(x *Element, xC int64, y *Element, yC int64) { + // | (hi, z) | < 2 * 2⁶³ * 2²⁵⁴ = 2³¹⁸ + // therefore | hi | < 2⁶² ≤ 2⁶³ + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) +// montReduceSigned z = (xHi * r + x) * r⁻¹ using the SOS algorithm +// Requires |xHi| < 2⁶³. Most significant bit of xHi is the sign bit. +func (z *Element) montReduceSigned(x *Element, xHi uint64) { + const signBitRemover = ^signBitSelector + mustNeg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // If X is negative we would have initially stored it as 2⁶⁴ r + X (à la 2's complement) + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNeg + + C = madd0(m, q0, x[0]) + C, t[1] = madd2(m, q1, x[1], C) + C, t[2] = madd2(m, q2, x[2], C) + C, t[3] = madd2(m, q3, x[3], C) + + // m * qElement[3] ≤ (2⁶⁴ - 1) * (2⁶³ - 1) = 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1 + // x[3] + C ≤ 2*(2⁶⁴ - 1) = 2⁶⁵ - 2 + // On LHS, (C, t[3]) ≤ 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1 + 2⁶⁵ - 2 = 2¹²⁷ + 2⁶³ - 1 + // So on LHS, C ≤ 2⁶³ + t[4] = xHi + C + // xHi + C < 2⁶³ + 2⁶³ = 2⁶⁴ + + // + { + const i = 1 + m = t[i] * qInvNeg - if borrow == 1 { + C = madd0(m, q0, t[i+0]) + C, t[i+1] = madd2(m, q1, t[i+1], C) + C, t[i+2] = madd2(m, q2, t[i+2], C) + C, t[i+3] = madd2(m, q3, t[i+3], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) - r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNeg - } - } - if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z - } - if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + C = madd0(m, q0, t[i+0]) + C, t[i+1] = madd2(m, q1, t[i+1], C) + C, t[i+2] = madd2(m, q2, t[i+2], C) + C, t[i+3] = madd2(m, q3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNeg + + C = madd0(m, q0, t[i+0]) + C, z[0] = madd2(m, q1, t[i+1], C) + C, z[1] = madd2(m, q2, t[i+2], C) + z[3], z[2] = madd2(m, q3, t[i+3], C) + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + // + + if mustNeg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + var carry uint64 + + z[0], carry = bits.Add64(z[0], q0, 0) + z[1], carry = bits.Add64(z[1], q1, carry) + z[2], carry = bits.Add64(z[2], q2, carry) + z[3], _ = bits.Add64(neg1, q3, carry) } } +} + +const ( + updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) + updateFactorIdentityMatrixRow0 = 1 + updateFactorIdentityMatrixRow1 = 1 << 32 +) + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +// negL negates in place [x | xHi] and return the new most significant word xHi +func negL(x *Element, xHi uint64) uint64 { + var b uint64 + + x[0], b = bits.Sub64(0, x[0], 0) + x[1], b = bits.Sub64(0, x[1], b) + x[2], b = bits.Sub64(0, x[2], b) + x[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) + + return xHi +} + +// mulWNonModular multiplies by one word in non-montgomery, without reducing +func (z *Element) mulWNonModular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + + if y < 0 { + c = negL(z, c) + } + + return c +} + +// linearCombNonModular computes a linear combination without modular reduction +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWNonModular(y, yC) + xHi := z.mulWNonModular(x, xC) + + var carry uint64 + z[0], carry = bits.Add64(z[0], yTimes[0], 0) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ff/element_fuzz.go b/ff/element_fuzz.go deleted file mode 100644 index cfb088a..0000000 --- a/ff/element_fuzz.go +++ /dev/null @@ -1,136 +0,0 @@ -//go:build gofuzz -// +build gofuzz - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package ff - -import ( - "bytes" - "encoding/binary" - "io" - "math/big" - "math/bits" -) - -const ( - fuzzInteresting = 1 - fuzzNormal = 0 - fuzzDiscard = -1 -) - -// Fuzz arithmetic operations fuzzer -func Fuzz(data []byte) int { - r := bytes.NewReader(data) - - var e1, e2 Element - e1.SetRawBytes(r) - e2.SetRawBytes(r) - - { - // mul assembly - - var c, _c Element - a, _a, b, _b := e1, e1, e2, e2 - c.Mul(&a, &b) - _mulGeneric(&_c, &_a, &_b) - - if !c.Equal(&_c) { - panic("mul asm != mul generic on Element") - } - } - - { - // inverse - inv := e1 - inv.Inverse(&inv) - - var bInv, b1, b2 big.Int - e1.ToBigIntRegular(&b1) - bInv.ModInverse(&b1, Modulus()) - inv.ToBigIntRegular(&b2) - - if b2.Cmp(&bInv) != 0 { - panic("inverse operation doesn't match big int result") - } - } - - { - // a + -a == 0 - a, b := e1, e1 - b.Neg(&b) - a.Add(&a, &b) - if !a.IsZero() { - panic("a + -a != 0") - } - } - - return fuzzNormal - -} - -// SetRawBytes reads up to Bytes (bytes needed to represent Element) from reader -// and interpret it as big endian uint64 -// used for fuzzing purposes only -func (z *Element) SetRawBytes(r io.Reader) { - - buf := make([]byte, 8) - - for i := 0; i < len(z); i++ { - if _, err := io.ReadFull(r, buf); err != nil { - goto eof - } - z[i] = binary.BigEndian.Uint64(buf[:]) - } -eof: - z[3] %= qElement[3] - - if z.BiggerModulus() { - var b uint64 - z[0], b = bits.Sub64(z[0], qElement[0], 0) - z[1], b = bits.Sub64(z[1], qElement[1], b) - z[2], b = bits.Sub64(z[2], qElement[2], b) - z[3], b = bits.Sub64(z[3], qElement[3], b) - } - - return -} - -func (z *Element) BiggerModulus() bool { - if z[3] > qElement[3] { - return true - } - if z[3] < qElement[3] { - return false - } - - if z[2] > qElement[2] { - return true - } - if z[2] < qElement[2] { - return false - } - - if z[1] > qElement[1] { - return true - } - if z[1] < qElement[1] { - return false - } - - return z[0] >= qElement[0] -} diff --git a/ff/element_mul_adx_amd64.s b/ff/element_mul_adx_amd64.s deleted file mode 100644 index 494e7bf..0000000 --- a/ff/element_mul_adx_amd64.s +++ /dev/null @@ -1,466 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x43e1f593f0000001 -DATA q<>+8(SB)/8, $0x2833e84879b97091 -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xc2e1f593efffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R15,CX,BX) using temp registers (R13,SI,R12,R11) - REDUCE(R14,R15,CX,BX,R13,SI,R12,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R15,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ff/element_mul_amd64.s b/ff/element_mul_amd64.s index 38b3b6c..b51bc69 100644 --- a/ff/element_mul_amd64.s +++ b/ff/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // @@ -45,8 +45,7 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 // mul(res, x, y *Element) TEXT ·mul(SB), $24-24 - // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication + // the algorithm is described in the Element.Mul declaration (.go) // however, to benefit from the ADCX and ADOX carry chains // we split the inner loops in 2: // for i=0 to N-1 @@ -75,7 +74,7 @@ TEXT ·mul(SB), $24-24 // A -> BP // t[0] -> R14 - // t[1] -> R15 + // t[1] -> R13 // t[2] -> CX // t[3] -> BX // clear the flags @@ -83,11 +82,11 @@ TEXT ·mul(SB), $24-24 MOVQ 0(R11), DX // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R15 + MULXQ DI, R14, R13 // (A,t[1]) := x[1]*y[0] + A MULXQ R8, AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (A,t[2]) := x[2]*y[0] + A MULXQ R9, AX, BX @@ -114,14 +113,14 @@ TEXT ·mul(SB), $24-24 MOVQ R12, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -142,9 +141,9 @@ TEXT ·mul(SB), $24-24 ADOXQ AX, R14 // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 + ADCXQ BP, R13 MULXQ R8, AX, BP - ADOXQ AX, R15 + ADOXQ AX, R13 // (A,t[2]) := t[2] + x[2]*y[1] + A ADCXQ BP, CX @@ -174,14 +173,14 @@ TEXT ·mul(SB), $24-24 MOVQ R12, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -202,9 +201,9 @@ TEXT ·mul(SB), $24-24 ADOXQ AX, R14 // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 + ADCXQ BP, R13 MULXQ R8, AX, BP - ADOXQ AX, R15 + ADOXQ AX, R13 // (A,t[2]) := t[2] + x[2]*y[2] + A ADCXQ BP, CX @@ -234,14 +233,14 @@ TEXT ·mul(SB), $24-24 MOVQ R12, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -262,9 +261,9 @@ TEXT ·mul(SB), $24-24 ADOXQ AX, R14 // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 + ADCXQ BP, R13 MULXQ R8, AX, BP - ADOXQ AX, R15 + ADOXQ AX, R13 // (A,t[2]) := t[2] + x[2]*y[3] + A ADCXQ BP, CX @@ -294,14 +293,14 @@ TEXT ·mul(SB), $24-24 MOVQ R12, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -313,12 +312,12 @@ TEXT ·mul(SB), $24-24 ADCXQ AX, BX ADOXQ BP, BX - // reduce element(R14,R15,CX,BX) using temp registers (R13,SI,R12,R11) - REDUCE(R14,R15,CX,BX,R13,SI,R12,R11) + // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) + REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) MOVQ res+0(FP), AX MOVQ R14, 0(AX) - MOVQ R15, 8(AX) + MOVQ R13, 8(AX) MOVQ CX, 16(AX) MOVQ BX, 24(AX) RET @@ -337,7 +336,7 @@ TEXT ·fromMont(SB), $8-8 NO_LOCAL_POINTERS // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication + // https://hackmd.io/@gnark/modular_multiplication // when y = 1 we have: // for i=0 to N-1 // t[i] = x[i] @@ -351,7 +350,7 @@ TEXT ·fromMont(SB), $8-8 JNE l2 MOVQ res+0(FP), DX MOVQ 0(DX), R14 - MOVQ 8(DX), R15 + MOVQ 8(DX), R13 MOVQ 16(DX), CX MOVQ 24(DX), BX XORQ DX, DX @@ -367,14 +366,14 @@ TEXT ·fromMont(SB), $8-8 MOVQ BP, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -396,14 +395,14 @@ TEXT ·fromMont(SB), $8-8 MOVQ BP, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -425,14 +424,14 @@ TEXT ·fromMont(SB), $8-8 MOVQ BP, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -454,14 +453,14 @@ TEXT ·fromMont(SB), $8-8 MOVQ BP, R14 // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 ADOXQ AX, R14 // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 + ADCXQ CX, R13 MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 + ADOXQ AX, R13 // (C,t[2]) := t[3] + m*q[3] + C ADCXQ BX, CX @@ -471,12 +470,12 @@ TEXT ·fromMont(SB), $8-8 ADCXQ AX, BX ADOXQ AX, BX - // reduce element(R14,R15,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9) + // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) + REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) MOVQ res+0(FP), AX MOVQ R14, 0(AX) - MOVQ R15, 8(AX) + MOVQ R13, 8(AX) MOVQ CX, 16(AX) MOVQ BX, 24(AX) RET diff --git a/ff/element_ops_amd64.go b/ff/element_ops_amd64.go index 777ba01..153381c 100644 --- a/ff/element_ops_amd64.go +++ b/ff/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -46,5 +49,77 @@ func fromMont(res *Element) //go:noescape func reduce(res *Element) +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} + +// Double z = x + x mod q, aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + double(z, x) + return z +} diff --git a/ff/element_ops_amd64.s b/ff/element_ops_amd64.s index d5dca83..5627b15 100644 --- a/ff/element_ops_amd64.s +++ b/ff/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -40,61 +42,6 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 CMOVQCS rb2, ra2; \ CMOVQCS rb3, ra3; \ -// add(res, x, y *Element) -TEXT ·add(SB), NOSPLIT, $0-24 - MOVQ x+8(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ y+16(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), R12 - MOVQ CX, 0(R12) - MOVQ BX, 8(R12) - MOVQ SI, 16(R12) - MOVQ DI, 24(R12) - RET - -// sub(res, x, y *Element) -TEXT ·sub(SB), NOSPLIT, $0-24 - XORQ DI, DI - MOVQ x+8(FP), SI - MOVQ 0(SI), AX - MOVQ 8(SI), DX - MOVQ 16(SI), CX - MOVQ 24(SI), BX - MOVQ y+16(FP), SI - SUBQ 0(SI), AX - SBBQ 8(SI), DX - SBBQ 16(SI), CX - SBBQ 24(SI), BX - MOVQ $0x43e1f593f0000001, R8 - MOVQ $0x2833e84879b97091, R9 - MOVQ $0xb85045b68181585d, R10 - MOVQ $0x30644e72e131a029, R11 - CMOVQCC DI, R8 - CMOVQCC DI, R9 - CMOVQCC DI, R10 - CMOVQCC DI, R11 - ADDQ R8, AX - ADCQ R9, DX - ADCQ R10, CX - ADCQ R11, BX - MOVQ res+0(FP), R12 - MOVQ AX, 0(R12) - MOVQ DX, 8(R12) - MOVQ CX, 16(R12) - MOVQ BX, 24(R12) - RET - // double(res, x *Element) TEXT ·double(SB), NOSPLIT, $0-16 MOVQ x+8(FP), AX @@ -117,40 +64,6 @@ TEXT ·double(SB), NOSPLIT, $0-16 MOVQ SI, 24(R11) RET -// neg(res, x *Element) -TEXT ·neg(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DI - MOVQ x+8(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ DX, AX - ORQ CX, AX - ORQ BX, AX - ORQ SI, AX - TESTQ AX, AX - JEQ l1 - MOVQ $0x43e1f593f0000001, R8 - SUBQ DX, R8 - MOVQ R8, 0(DI) - MOVQ $0x2833e84879b97091, R8 - SBBQ CX, R8 - MOVQ R8, 8(DI) - MOVQ $0xb85045b68181585d, R8 - SBBQ BX, R8 - MOVQ R8, 16(DI) - MOVQ $0x30644e72e131a029, R8 - SBBQ SI, R8 - MOVQ R8, 24(DI) - RET - -l1: - MOVQ AX, 0(DI) - MOVQ AX, 8(DI) - MOVQ AX, 16(DI) - MOVQ AX, 24(DI) - RET TEXT ·reduce(SB), NOSPLIT, $0-8 MOVQ res+0(FP), AX diff --git a/ff/element_ops_noasm.go b/ff/element_ops_noasm.go deleted file mode 100644 index ca357bc..0000000 --- a/ff/element_ops_noasm.go +++ /dev/null @@ -1,78 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package ff - -// /!\ WARNING /!\ -// this code has not been audited and is provided as-is. In particular, -// there is no security guarantees such as constant time implementation -// or side-channel attack resistance -// /!\ WARNING /!\ - -// MulBy3 x *= 3 -func MulBy3(x *Element) { - mulByConstant(x, 3) -} - -// MulBy5 x *= 5 -func MulBy5(x *Element) { - mulByConstant(x, 5) -} - -// MulBy13 x *= 13 -func MulBy13(x *Element) { - mulByConstant(x, 13) -} - -// Butterfly sets -// a = a + b -// b = a - b -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation -// sets and returns z = z * 1 -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func add(z, x, y *Element) { - _addGeneric(z, x, y) -} - -func double(z, x *Element) { - _doubleGeneric(z, x) -} - -func sub(z, x, y *Element) { - _subGeneric(z, x, y) -} - -func neg(z, x *Element) { - _negGeneric(z, x) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ff/element_ops_purego.go b/ff/element_ops_purego.go new file mode 100644 index 0000000..c699e16 --- /dev/null +++ b/ff/element_ops_purego.go @@ -0,0 +1,463 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 17868810749992763324, + 5924006745939515753, + 769406925088786241, + 2691790815622165739, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ff/element_test.go b/ff/element_test.go index 6c43f79..ad8d95a 100644 --- a/ff/element_test.go +++ b/ff/element_test.go @@ -18,12 +18,18 @@ package ff import ( "crypto/rand" + "encoding/json" + "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" + ggen "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" + + "github.com/stretchr/testify/require" ) // ------------------------------------------------------------------------------------------------- @@ -33,6 +39,27 @@ import ( var benchResElement Element +func BenchmarkElementSelect(b *testing.B) { + var x, y Element + x.SetRandom() + y.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Select(i%3, &x, &y) + } +} + +func BenchmarkElementSetRandom(b *testing.B) { + var x Element + x.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = x.SetRandom() + } +} + func BenchmarkElementSetBytes(b *testing.B) { var x Element x.SetRandom() @@ -152,17 +179,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -240,7 +260,6 @@ func TestElementCmp(t *testing.T) { t.Fatal("x < y") } } - func TestElementIsRandom(t *testing.T) { for i := 0; i < 50; i++ { var x, y Element @@ -252,6 +271,46 @@ func TestElementIsRandom(t *testing.T) { } } +func TestElementIsUint64(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(v uint64) bool { + var e Element + e.SetUint64(v) + + if !e.IsUint64() { + return false + } + + return e.Uint64() == v + }, + ggen.UInt64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + // ------------------------------------------------------------------------------------------------- // Gopter tests // most of them are generated with a template @@ -267,7 +326,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -277,20 +336,21 @@ func init() { { a := qElement - a[3]-- + a[0]-- staticTestValues = append(staticTestValues, a) } + staticTestValues = append(staticTestValues, Element{0}) + staticTestValues = append(staticTestValues, Element{0, 0}) + staticTestValues = append(staticTestValues, Element{1}) + staticTestValues = append(staticTestValues, Element{0, 1}) + staticTestValues = append(staticTestValues, Element{2}) + staticTestValues = append(staticTestValues, Element{0, 2}) + { a := qElement - a[0]-- + a[3]-- staticTestValues = append(staticTestValues, a) } - - for i := 0; i <= 3; i++ { - staticTestValues = append(staticTestValues, Element{uint64(i)}) - staticTestValues = append(staticTestValues, Element{0, uint64(i)}) - } - { a := qElement a[3]-- @@ -298,25 +358,20 @@ func init() { staticTestValues = append(staticTestValues, a) } -} - -func TestElementNegZero(t *testing.T) { - var a, b Element - b.SetZero() - for a.IsZero() { - a.SetRandom() - } - a.Neg(&b) - if !a.IsZero() { - t.Fatal("neg(0) != 0") + { + a := qElement + a[3] = 0 + staticTestValues = append(staticTestValues, a) } + } func TestElementReduce(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, s := range testValues { + for i := range testValues { + s := testValues[i] expected := s reduce(&s) _reduceGeneric(&expected) @@ -325,6 +380,7 @@ func TestElementReduce(t *testing.T) { } } + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -341,23 +397,50 @@ func TestElementReduce(t *testing.T) { b := a reduce(&a) _reduceGeneric(&b) - return !a.biggerOrEqualModulus() && a.Equal(&b) + return a.smallerThanModulus() && a.Equal(&b) }, genA, )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true + +} + +func TestElementEqual(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz } + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll( + func(a testPairElement, b testPairElement) bool { + return a.element.Equal(&b.element) == (a.element == b.element) + }, + genA, + genB, + )) + + properties.Property("x.Equal(&y) if x == y", prop.ForAll( + func(a testPairElement) bool { + b := a.element + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) } func TestElementBytes(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -369,7 +452,7 @@ func TestElementBytes(t *testing.T) { genA := gen() - properties.Property("SetBytes(Bytes()) should stayt constant", prop.ForAll( + properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll( func(a testPairElement) bool { var b Element bytes := a.element.Bytes() @@ -387,41 +470,43 @@ func TestElementInverseExp(t *testing.T) { exp := Modulus() exp.Sub(exp, new(big.Int).SetUint64(2)) + invMatchExp := func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) + + return a.element.Equal(&b) + } + + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort } else { parameters.MinSuccessfulTests = nbFuzz } - properties := gopter.NewProperties(parameters) - genA := gen() + properties.Property("inv == exp^-2", prop.ForAll(invMatchExp, genA)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) - properties.Property("inv == exp^-2", prop.ForAll( - func(a testPairElement) bool { - var b Element - b.Set(&a.element) - a.element.Inverse(&a.element) - b.Exp(b, exp) + parameters.MinSuccessfulTests = 1 + properties = gopter.NewProperties(parameters) + properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{}))) + properties.TestingRun(t, gopter.ConsoleReporter(false)) - return a.element.Equal(&b) - }, - genA, - )) +} - properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } +func mulByConstant(z *Element, c uint8) { + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) } func TestElementMulByConstants(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -502,17 +587,11 @@ func TestElementMulByConstants(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } } func TestElementLegendre(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -532,18 +611,36 @@ func TestElementLegendre(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true + +} + +func TestElementBitLen(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz } + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( + func(a testPairElement) bool { + return a.element.fromMont().BitLen() == a.bigint.BitLen() + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } func TestElementButterflies(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -569,17 +666,11 @@ func TestElementButterflies(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } } func TestElementLexicographicallyLargest(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -611,17 +702,11 @@ func TestElementLexicographicallyLargest(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } } func TestElementAdd(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -659,7 +744,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -670,23 +755,16 @@ func TestElementAdd(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - // checking generic impl against asm path - var cGeneric Element - _addGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. - return false - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -702,18 +780,7 @@ func TestElementAdd(t *testing.T) { c.Add(&a.element, &b.element) - return !c.biggerOrEqualModulus() - }, - genA, - genB, - )) - - properties.Property("Add: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Add(&a.element, &b.element) - _addGeneric(&d, &a.element, &b.element) - return c.Equal(&d) + return c.smallerThanModulus() }, genA, genB, @@ -724,26 +791,20 @@ func TestElementAdd(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _addGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Add failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -752,17 +813,11 @@ func TestElementAdd(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementSub(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -800,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -811,23 +866,16 @@ func TestElementSub(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - // checking generic impl against asm path - var cGeneric Element - _subGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. - return false - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -843,18 +891,7 @@ func TestElementSub(t *testing.T) { c.Sub(&a.element, &b.element) - return !c.biggerOrEqualModulus() - }, - genA, - genB, - )) - - properties.Property("Sub: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Sub(&a.element, &b.element) - _subGeneric(&d, &a.element, &b.element) - return c.Equal(&d) + return c.smallerThanModulus() }, genA, genB, @@ -865,26 +902,20 @@ func TestElementSub(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _subGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Sub failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -893,17 +924,11 @@ func TestElementSub(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementMul(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -941,7 +966,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -952,9 +977,10 @@ func TestElementMul(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -968,7 +994,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +1010,7 @@ func TestElementMul(t *testing.T) { c.Mul(&a.element, &b.element) - return !c.biggerOrEqualModulus() + return c.smallerThanModulus() }, genA, genB, @@ -1006,13 +1032,14 @@ func TestElementMul(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1025,7 +1052,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1034,17 +1061,11 @@ func TestElementMul(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementDiv(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1083,7 +1104,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1094,16 +1115,17 @@ func TestElementDiv(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1119,7 +1141,7 @@ func TestElementDiv(t *testing.T) { c.Div(&a.element, &b.element) - return !c.biggerOrEqualModulus() + return c.smallerThanModulus() }, genA, genB, @@ -1130,20 +1152,21 @@ func TestElementDiv(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1152,17 +1175,11 @@ func TestElementDiv(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementExp(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1200,7 +1217,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1211,15 +1228,16 @@ func TestElementExp(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, r := range testValues { + for i := range testValues { + r := testValues[i] var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1235,7 +1253,7 @@ func TestElementExp(t *testing.T) { c.Exp(a.element, &b.bigint) - return !c.biggerOrEqualModulus() + return c.smallerThanModulus() }, genA, genB, @@ -1246,19 +1264,20 @@ func TestElementExp(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - + a.BigInt(&aBig) + for j := range testValues { + b := testValues[j] var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1267,17 +1286,11 @@ func TestElementExp(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementSquare(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1309,7 +1322,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1318,7 +1331,7 @@ func TestElementSquare(t *testing.T) { func(a testPairElement) bool { var c Element c.Square(&a.element) - return !c.biggerOrEqualModulus() + return c.smallerThanModulus() }, genA, )) @@ -1328,16 +1341,17 @@ func TestElementSquare(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1345,17 +1359,11 @@ func TestElementSquare(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementInverse(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1387,7 +1395,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1396,7 +1404,7 @@ func TestElementInverse(t *testing.T) { func(a testPairElement) bool { var c Element c.Inverse(&a.element) - return !c.biggerOrEqualModulus() + return c.smallerThanModulus() }, genA, )) @@ -1406,16 +1414,17 @@ func TestElementInverse(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1423,17 +1432,11 @@ func TestElementInverse(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementSqrt(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1465,7 +1468,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1474,7 +1477,7 @@ func TestElementSqrt(t *testing.T) { func(a testPairElement) bool { var c Element c.Sqrt(&a.element) - return !c.biggerOrEqualModulus() + return c.smallerThanModulus() }, genA, )) @@ -1484,16 +1487,17 @@ func TestElementSqrt(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1501,17 +1505,11 @@ func TestElementSqrt(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementDouble(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1543,7 +1541,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1552,17 +1550,7 @@ func TestElementDouble(t *testing.T) { func(a testPairElement) bool { var c Element c.Double(&a.element) - return !c.biggerOrEqualModulus() - }, - genA, - )) - - properties.Property("Double: assembly implementation must be consistent with generic one", prop.ForAll( - func(a testPairElement) bool { - var c, d Element - c.Double(&a.element) - _doubleGeneric(&d, &a.element) - return c.Equal(&d) + return c.smallerThanModulus() }, genA, )) @@ -1572,23 +1560,17 @@ func TestElementDouble(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _doubleGeneric(&cGeneric, &a) - if !cGeneric.Equal(&c) { - t.Fatal("Double failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1596,17 +1578,11 @@ func TestElementDouble(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementNeg(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1638,7 +1614,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1647,17 +1623,7 @@ func TestElementNeg(t *testing.T) { func(a testPairElement) bool { var c Element c.Neg(&a.element) - return !c.biggerOrEqualModulus() - }, - genA, - )) - - properties.Property("Neg: assembly implementation must be consistent with generic one", prop.ForAll( - func(a testPairElement) bool { - var c, d Element - c.Neg(&a.element) - _negGeneric(&d, &a.element) - return c.Equal(&d) + return c.smallerThanModulus() }, genA, )) @@ -1667,23 +1633,17 @@ func TestElementNeg(t *testing.T) { testValues := make([]Element, len(staticTestValues)) copy(testValues, staticTestValues) - for _, a := range testValues { + for i := range testValues { + a := testValues[i] var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - // checking asm against generic impl - var cGeneric Element - _negGeneric(&cGeneric, &a) - if !cGeneric.Equal(&c) { - t.Fatal("Neg failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -1691,18 +1651,12 @@ func TestElementNeg(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } + } func TestElementHalve(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1731,8 +1685,15 @@ func TestElementHalve(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } -func TestElementFromMont(t *testing.T) { +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} +func TestElementSelect(t *testing.T) { + t.Parallel() parameters := gopter.DefaultTestParameters() if testing.Short() { parameters.MinSuccessfulTests = nbFuzzShort @@ -1742,59 +1703,493 @@ func TestElementFromMont(t *testing.T) { properties := gopter.NewProperties(parameters) - genA := gen() + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely - properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( - func(a testPairElement) bool { - c := a.element - d := a.element - c.FromMont() - _fromMontGeneric(&d) - return c.Equal(&d) + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) }, genA, + genB, + genC, + genZ, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( - func(a testPairElement) bool { - c := a.element - c.FromMont().ToMont() - return c.Equal(&a.element) + properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) }, genA, + genB, + genC, + genZ, )) properties.TestingRun(t, gopter.ConsoleReporter(false)) } -type testPairElement struct { - element Element - bigint big.Int -} +func TestElementSetInt64(t *testing.T) { -func (z *Element) biggerOrEqualModulus() bool { - if z[3] > qElement[3] { - return true - } - if z[3] < qElement[3] { - return false + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + + properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz } - if z[2] > qElement[2] { - return true + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + } - if z[2] < qElement[2] { - return false + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, } - if z[1] > qElement[1] { - return true + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } } - if z[1] < qElement[1] { - return false + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz } - return z[0] >= qElement[0] + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) + s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + // we may need to adjust "42" and "8000" values for some moduli; see Text() method for more details. + formatValue := func(v int64) string { + var a big.Int + a.SetInt64(v) + a.Mod(&a, Modulus()) + const maxUint16 = 65535 + var aNeg big.Int + aNeg.Neg(&a).Mod(&aNeg, Modulus()) + if aNeg.Uint64() != 0 && aNeg.Uint64() <= maxUint16 { + return "-" + aNeg.Text(10) + } + return a.Text(10) + } + expected := fmt.Sprintf("{\"A\":%s,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(-1), formatValue(42), formatValue(8000)) + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + +} + +type testPairElement struct { + element Element + bigint big.Int } func gen() gopter.Gen { @@ -1811,7 +2206,7 @@ func gen() gopter.Gen { g.element[3] %= (qElement[3] + 1) } - for g.element.biggerOrEqualModulus() { + for !g.element.smallerThanModulus() { g.element = Element{ genParams.NextUint64(), genParams.NextUint64(), @@ -1823,7 +2218,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -1846,7 +2241,7 @@ func genFull() gopter.Gen { g[3] %= (qElement[3] + 1) } - for g.biggerOrEqualModulus() { + for !g.smallerThanModulus() { g = Element{ genParams.NextUint64(), genParams.NextUint64(), @@ -1872,3 +2267,573 @@ func genFull() gopter.Gen { return genResult } } + +func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + slice := append(z[:], aHi) + + return bigIntMatchUint64Slice(&aIntMod, slice) +} + +// TODO: Phase out in favor of property based testing +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + if err := z.matchVeryBigInt(aHi, aInt); err != nil { + t.Error(err) + } +} + +// bigIntMatchUint64Slice is a test helper to match big.Int words against a uint64 slice +func bigIntMatchUint64Slice(aInt *big.Int, a []uint64) error { + + words := aInt.Bits() + + const steps = 64 / bits.UintSize + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + for i := 0; i < len(a)*steps; i++ { + + var wI big.Word + + if i < len(words) { + wI = words[i] + } + + aI := a[i/steps] >> ((i * bits.UintSize) % 64) + aI &= filter + + if uint64(wI) != aI { + return fmt.Errorf("bignum mismatch: disagreement on word %d: %x ≠ %x; %d ≠ %d", i, uint64(wI), aI, uint64(wI), aI) + } + } + + return nil +} + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs //#nosec G404 weak rng is fine here + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.toBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) //#nosec G404 weak rng is fine here + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.BigInt(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := negL(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() //#nosec G404 weak rng is fine here + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() //#nosec G404 weak rng is fine here + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +type veryBigInt struct { + asInt big.Int + low Element + hi uint64 +} + +// genVeryBigIntSigned if sign == 0, no sign is forced +func genVeryBigIntSigned(sign int) gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g veryBigInt + + g.low = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + g.hi = genParams.NextUint64() + + if sign < 0 { + g.hi |= signBitSelector + } else if sign > 0 { + g.hi &= ^signBitSelector + } + + g.low.toVeryBigIntSigned(&g.asInt, g.hi) + + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func TestElementMontReduce(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + gen := genVeryBigIntSigned(0) + + properties.Property("Montgomery reduction is correct", prop.ForAll( + func(g veryBigInt) bool { + var res Element + var resInt big.Int + + montReduce(&resInt, &g.asInt) + res.montReduceSigned(&g.low, g.hi) + + return res.matchVeryBigInt(0, &resInt) == nil + }, + gen, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementMontReduceMultipleOfR(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + gen := ggen.UInt64() + + properties.Property("Montgomery reduction is correct", prop.ForAll( + func(hi uint64) bool { + var zero, res Element + var asInt, resInt big.Int + + zero.toVeryBigIntSigned(&asInt, hi) + + montReduce(&resInt, &asInt) + res.montReduceSigned(&zero, hi) + + return res.matchVeryBigInt(0, &resInt) == nil + }, + gen, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElement0Inverse(t *testing.T) { + var x Element + x.Inverse(&x) + if !x.IsZero() { + t.Fail() + } +} + +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) //#nosec G404 weak rng is fine here + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 //#nosec G404 weak rng is fine here + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 //#nosec G404 weak rng is fine here + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 //#nosec G404 weak rng is fine here + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.toBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.toBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearComb(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWNonModular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.toBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.toBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.toBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ff/vector.go b/ff/vector.go new file mode 100644 index 0000000..ac745cc --- /dev/null +++ b/ff/vector.go @@ -0,0 +1,253 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "runtime" + "strings" + "sync" + "sync/atomic" + "unsafe" +) + +// Vector represents a slice of Element. +// +// It implements the following interfaces: +// - Stringer +// - io.WriterTo +// - io.ReaderFrom +// - encoding.BinaryMarshaler +// - encoding.BinaryUnmarshaler +// - sort.Interface +type Vector []Element + +// MarshalBinary implements encoding.BinaryMarshaler +func (vector *Vector) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + if _, err = vector.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (vector *Vector) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := vector.ReadFrom(r) + return err +} + +// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. +// Length of the vector is encoded as a uint32 on the first 4 bytes. +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { + // encode slice length + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { + return 0, err + } + + n := int64(4) + + var buf [Bytes]byte + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) + m, err := w.Write(buf[:]) + n += int64(m) + if err != nil { + return n, err + } + } + return n, nil +} + +// AsyncReadFrom reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. +// It also returns a channel that will be closed when the validation is done. +// The validation consist of checking that the elements are smaller than the modulus, and +// converting them to montgomery form. +func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { + chErr := make(chan error, 1) + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + close(chErr) + return int64(read), err, chErr + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + if sliceLen == 0 { + close(chErr) + return n, nil, chErr + } + + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) + read, err := io.ReadFull(r, bSlice) + n += int64(read) + if err != nil { + close(chErr) + return n, err, chErr + } + + go func() { + var cptErrors uint64 + // process the elements in parallel + execute(int(sliceLen), func(start, end int) { + + var z Element + for i := start; i < end; i++ { + // we have to set vector[i] + bstart := i * Bytes + bend := bstart + Bytes + b := bSlice[bstart:bend] + z[0] = binary.BigEndian.Uint64(b[24:32]) + z[1] = binary.BigEndian.Uint64(b[16:24]) + z[2] = binary.BigEndian.Uint64(b[8:16]) + z[3] = binary.BigEndian.Uint64(b[0:8]) + + if !z.smallerThanModulus() { + atomic.AddUint64(&cptErrors, 1) + return + } + z.toMont() + (*vector)[i] = z + } + }) + + if cptErrors > 0 { + chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors) + } + close(chErr) + }() + return n, nil, chErr +} + +// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { + + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + return int64(read), err + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + + for i := 0; i < int(sliceLen); i++ { + read, err := io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + (*vector)[i], err = BigEndian.Element(&buf) + if err != nil { + return n, err + } + } + + return n, nil +} + +// String implements fmt.Stringer interface +func (vector Vector) String() string { + var sbb strings.Builder + sbb.WriteByte('[') + for i := 0; i < len(vector); i++ { + sbb.WriteString(vector[i].String()) + if i != len(vector)-1 { + sbb.WriteByte(',') + } + } + sbb.WriteByte(']') + return sbb.String() +} + +// Len is the number of elements in the collection. +func (vector Vector) Len() int { + return len(vector) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (vector Vector) Less(i, j int) bool { + return vector[i].Cmp(&vector[j]) == -1 +} + +// Swap swaps the elements with indexes i and j. +func (vector Vector) Swap(i, j int) { + vector[i], vector[j] = vector[j], vector[i] +} + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. +// this is copy paste from internal/parallel/parallel.go +// as we don't want to generate code importing internal/ +func execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/ff/vector_test.go b/ff/vector_test.go new file mode 100644 index 0000000..ec26477 --- /dev/null +++ b/ff/vector_test.go @@ -0,0 +1,91 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +import ( + "bytes" + "reflect" + "sort" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVectorSort(t *testing.T) { + assert := require.New(t) + + v := make(Vector, 3) + v[0].SetUint64(2) + v[1].SetUint64(3) + v[2].SetUint64(1) + + sort.Sort(v) + + assert.Equal("[1,2,3]", v.String()) +} + +func TestVectorRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 3) + v1[0].SetUint64(2) + v1[1].SetUint64(3) + v1[2].SetUint64(1) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func TestVectorEmptyRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 0) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2, v3 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func (vector *Vector) unmarshalBinaryAsync(data []byte) error { + r := bytes.NewReader(data) + _, err, chErr := vector.AsyncReadFrom(r) + if err != nil { + return err + } + return <-chErr +} diff --git a/go.mod b/go.mod index 311ddb2..e7655d5 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,13 @@ require ( github.com/dchest/blake512 v1.0.0 github.com/leanovate/gopter v0.2.9 github.com/stretchr/testify v1.8.2 - golang.org/x/crypto v0.7.0 - golang.org/x/sys v0.6.0 + golang.org/x/crypto v0.10.0 + golang.org/x/sys v0.9.0 ) require ( + github.com/bits-and-blooms/bitset v1.8.0 // indirect + github.com/consensys/gnark-crypto v0.11.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 810c80c..82c77ef 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/bits-and-blooms/bitset v1.8.0 h1:FD+XqgOZDUxxZ8hzoBFuV9+cGWY9CslN6d5MS5JVb4c= +github.com/bits-and-blooms/bitset v1.8.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/consensys/gnark-crypto v0.11.2 h1:GJjjtWJ+db1xGao7vTsOgAOGgjfPe7eRGPL+xxMX0qE= +github.com/consensys/gnark-crypto v0.11.2/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -16,8 +20,12 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mimc7/mimc7.go b/mimc7/mimc7.go index ddf3ec0..f671858 100644 --- a/mimc7/mimc7.go +++ b/mimc7/mimc7.go @@ -37,13 +37,13 @@ func generateConstantsData() constantsData { func getConstants(seed string, nRounds int) []*ff.Element { cts := make([]*ff.Element, nRounds) - cts[0] = ff.NewElement() + cts[0] = new(ff.Element) c := new(big.Int).SetBytes(keccak256.Hash([]byte(seed))) for i := 1; i < nRounds; i++ { c = new(big.Int).SetBytes(keccak256.Hash(c.Bytes())) n := new(big.Int).Mod(c, _constants.Q) - cts[i] = ff.NewElement().SetBigInt(n) + cts[i] = new(ff.Element).SetBigInt(n) } return cts } @@ -51,23 +51,23 @@ func getConstants(seed string, nRounds int) []*ff.Element { // MIMC7HashGeneric performs the MIMC7 hash over a *big.Int, in a generic way, // where it can be specified the Finite Field over R, and the number of rounds func MIMC7HashGeneric(xInBI, kBI *big.Int, nRounds int) *big.Int { //nolint:golint - xIn := ff.NewElement().SetBigInt(xInBI) - k := ff.NewElement().SetBigInt(kBI) + xIn := new(ff.Element).SetBigInt(xInBI) + k := new(ff.Element).SetBigInt(kBI) cts := getConstants(SEED, nRounds) var r *ff.Element for i := 0; i < nRounds; i++ { var t *ff.Element if i == 0 { - t = ff.NewElement().Add(xIn, k) + t = new(ff.Element).Add(xIn, k) } else { - t = ff.NewElement().Add(ff.NewElement().Add(r, k), cts[i]) + t = new(ff.Element).Add(new(ff.Element).Add(r, k), cts[i]) } - t2 := ff.NewElement().Square(t) - t4 := ff.NewElement().Square(t2) - r = ff.NewElement().Mul(ff.NewElement().Mul(t4, t2), t) + t2 := new(ff.Element).Square(t) + t4 := new(ff.Element).Square(t2) + r = new(ff.Element).Mul(new(ff.Element).Mul(t4, t2), t) } - rE := ff.NewElement().Add(r, k) + rE := new(ff.Element).Add(r, k) res := big.NewInt(0) rE.ToBigIntRegular(res) @@ -94,22 +94,22 @@ func HashGeneric(iv *big.Int, arr []*big.Int, nRounds int) (*big.Int, error) { // MIMC7Hash performs the MIMC7 hash over a *big.Int, using the Finite Field // over R and the number of rounds setted in the `constants` variable func MIMC7Hash(xInBI, kBI *big.Int) *big.Int { //nolint:golint - xIn := ff.NewElement().SetBigInt(xInBI) - k := ff.NewElement().SetBigInt(kBI) + xIn := new(ff.Element).SetBigInt(xInBI) + k := new(ff.Element).SetBigInt(kBI) var r *ff.Element for i := 0; i < constants.nRounds; i++ { var t *ff.Element if i == 0 { - t = ff.NewElement().Add(xIn, k) + t = new(ff.Element).Add(xIn, k) } else { - t = ff.NewElement().Add(ff.NewElement().Add(r, k), constants.cts[i]) + t = new(ff.Element).Add(new(ff.Element).Add(r, k), constants.cts[i]) } - t2 := ff.NewElement().Square(t) - t4 := ff.NewElement().Square(t2) - r = ff.NewElement().Mul(ff.NewElement().Mul(t4, t2), t) + t2 := new(ff.Element).Square(t) + t4 := new(ff.Element).Square(t2) + r = new(ff.Element).Mul(new(ff.Element).Mul(t4, t2), t) } - rE := ff.NewElement().Add(r, k) + rE := new(ff.Element).Add(r, k) res := big.NewInt(0) rE.ToBigIntRegular(res) diff --git a/poseidon/constants.go b/poseidon/constants.go index cc0141b..16821d5 100644 --- a/poseidon/constants.go +++ b/poseidon/constants.go @@ -46,7 +46,7 @@ func init() { if !ok { panic(fmt.Errorf("error parsing constants")) } - cci[j] = ff.NewElement().SetBigInt(b) + cci[j] = new(ff.Element).SetBigInt(b) } c.c[i] = cci } @@ -58,7 +58,7 @@ func init() { if !ok { panic(fmt.Errorf("error parsing constants")) } - csi[j] = ff.NewElement().SetBigInt(b) + csi[j] = new(ff.Element).SetBigInt(b) } c.s[i] = csi } @@ -72,7 +72,7 @@ func init() { if !ok { panic(fmt.Errorf("error parsing constants")) } - cmij[k] = ff.NewElement().SetBigInt(b) + cmij[k] = new(ff.Element).SetBigInt(b) } cmi[j] = cmij } @@ -88,7 +88,7 @@ func init() { if !ok { panic(fmt.Errorf("error parsing constants")) } - cpij[k] = ff.NewElement().SetBigInt(b) + cpij[k] = new(ff.Element).SetBigInt(b) } cpi[j] = cpij } diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index 9edd26c..c035119 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -19,7 +19,7 @@ const spongeChunkSize = 31 const spongeInputs = 16 func zero() *ff.Element { - return ff.NewElement() + return &ff.Element{} } var big5 = big.NewInt(5) @@ -123,7 +123,7 @@ func Hash(inpBI []*big.Int) (*big.Int, error) { rE := state[0] r := big.NewInt(0) - rE.ToBigIntRegular(r) + rE.BigInt(r) return r, nil } diff --git a/utils/utils.go b/utils/utils.go index b126da0..d506c92 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -107,7 +107,7 @@ func CheckBigIntArrayInField(arr []*big.Int) bool { func BigIntArrayToElementArray(bi []*big.Int) []*ff.Element { o := make([]*ff.Element, len(bi)) for i := range bi { - o[i] = ff.NewElement().SetBigInt(bi[i]) + o[i] = new(ff.Element).SetBigInt(bi[i]) } return o } @@ -118,7 +118,7 @@ func ElementArrayToBigIntArray(e []*ff.Element) []*big.Int { for i := range e { ei := e[i] bi := big.NewInt(0) - ei.ToBigIntRegular(bi) + ei.BigInt(bi) o[i] = bi } return o