From 575ee9d2d9ba2fe90942e8fce0c63739a90f6150 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Thu, 26 Sep 2024 10:32:49 +0900 Subject: [PATCH] add missing files --- arithmetic.go | 472 ++++++++++++++++++ arithmetic_test.go | 326 ++++++++++++ bits.go | 12 +- bits_errors.go | 7 +- bits_table.go | 2 +- bitwise.go | 264 ++++++++++ bitwise_test.go | 344 +++++++++++++ cmp.go | 125 +++++ cmp_test.go | 163 ++++++ coversion.go | 570 +++++++++++++++++++++ coversion_test.go | 58 +++ errors.go | 75 +++ i256.go | 253 ---------- mod.go | 605 ++++++++++++++++++++++ u256.go | 1192 +++++--------------------------------------- utils.go | 180 +++++++ 16 files changed, 3311 insertions(+), 1337 deletions(-) create mode 100644 arithmetic.go create mode 100644 arithmetic_test.go create mode 100644 bitwise.go create mode 100644 bitwise_test.go create mode 100644 cmp.go create mode 100644 cmp_test.go create mode 100644 coversion.go create mode 100644 coversion_test.go create mode 100644 errors.go delete mode 100644 i256.go create mode 100644 mod.go create mode 100644 utils.go diff --git a/arithmetic.go b/arithmetic.go new file mode 100644 index 0000000..7c9cf2c --- /dev/null +++ b/arithmetic.go @@ -0,0 +1,472 @@ +// arithmetic provides arithmetic operations for Uint objects. +// This includes basic binary operations such as addition, subtraction, multiplication, division, and modulo operations +// as well as overflow checks, and negation. These functions are essential for numeric +// calculations using 256-bit unsigned integers. +package u256 + +import ( + "math/bits" +) + +// Add sets z to the sum x+y +func (z *Uint) Add(x, y *Uint) *Uint { + var carry uint64 + z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry) + z.arr[3], _ = bits.Add64(x.arr[3], y.arr[3], carry) + return z +} + +// AddOverflow sets z to the sum x+y, and returns z and whether overflow occurred +func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) { + var carry uint64 + z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry) + z.arr[3], carry = bits.Add64(x.arr[3], y.arr[3], carry) + return z, carry != 0 +} + +// Sub sets z to the difference x-y +func (z *Uint) Sub(x, y *Uint) *Uint { + var carry uint64 + z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry) + z.arr[3], _ = bits.Sub64(x.arr[3], y.arr[3], carry) + return z +} + +// SubOverflow sets z to the difference x-y and returns z and true if the operation underflowed +func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) { + var carry uint64 + z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry) + z.arr[3], carry = bits.Sub64(x.arr[3], y.arr[3], carry) + return z, carry != 0 +} + +// Neg returns -x mod 2^256. +func (z *Uint) Neg(x *Uint) *Uint { + return z.Sub(new(Uint), x) +} + +// commented out for possible overflow +// Mul sets z to the product x*y +func (z *Uint) Mul(x, y *Uint) *Uint { + var ( + res Uint + carry uint64 + res1, res2, res3 uint64 + ) + + carry, res.arr[0] = bits.Mul64(x.arr[0], y.arr[0]) + carry, res1 = umulHop(carry, x.arr[1], y.arr[0]) + carry, res2 = umulHop(carry, x.arr[2], y.arr[0]) + res3 = x.arr[3]*y.arr[0] + carry + + carry, res.arr[1] = umulHop(res1, x.arr[0], y.arr[1]) + carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry) + res3 = res3 + x.arr[2]*y.arr[1] + carry + + carry, res.arr[2] = umulHop(res2, x.arr[0], y.arr[2]) + res3 = res3 + x.arr[1]*y.arr[2] + carry + + res.arr[3] = res3 + x.arr[0]*y.arr[3] + + return z.Set(&res) +} + +// MulOverflow sets z to the product x*y, and returns z and whether overflow occurred +func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) { + p := umul(x, y) + copy(z.arr[:], p[:4]) + return z, (p[4] | p[5] | p[6] | p[7]) != 0 +} + +// commented out for possible overflow +// Div sets z to the quotient x/y for returns z. +// If y == 0, z is set to 0 +func (z *Uint) Div(x, y *Uint) *Uint { + if y.IsZero() || y.Gt(x) { + return z.Clear() + } + if x.Eq(y) { + return z.SetOne() + } + // Shortcut some cases + if x.IsUint64() { + return z.SetUint64(x.Uint64() / y.Uint64()) + } + + // At this point, we know + // x/y ; x > y > 0 + + var quot Uint + udivrem(quot.arr[:], x.arr[:], y) + return z.Set(") +} + +// MulMod calculates the modulo-m multiplication of x and y and +// returns z. +// If m == 0, z is set to 0 (OBS: differs from the big.Int) +func (z *Uint) MulMod(x, y, m *Uint) *Uint { + if x.IsZero() || y.IsZero() || m.IsZero() { + return z.Clear() + } + p := umul(x, y) + + if m.arr[3] != 0 { + mu := Reciprocal(m) + r := reduce4(p, m, mu) + return z.Set(&r) + } + + var ( + pl Uint + ph Uint + ) + + pl = Uint{arr: [4]uint64{p[0], p[1], p[2], p[3]}} + ph = Uint{arr: [4]uint64{p[4], p[5], p[6], p[7]}} + + // If the multiplication is within 256 bits use Mod(). + if ph.IsZero() { + return z.Mod(&pl, m) + } + + var quot [8]uint64 + rem := udivrem(quot[:], p[:], m) + return z.Set(&rem) +} + +// Mod sets z to the modulus x%y for y != 0 and returns z. +// If y == 0, z is set to 0 (OBS: differs from the big.Uint) +func (z *Uint) Mod(x, y *Uint) *Uint { + if x.IsZero() || y.IsZero() { + return z.Clear() + } + switch x.Cmp(y) { + case -1: + // x < y + copy(z.arr[:], x.arr[:]) + return z + case 0: + // x == y + return z.Clear() // They are equal + } + + // At this point: + // x != 0 + // y != 0 + // x > y + + // Shortcut trivial case + if x.IsUint64() { + return z.SetUint64(x.Uint64() % y.Uint64()) + } + + var quot Uint + *z = udivrem(quot.arr[:], x.arr[:], y) + return z +} + +// DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0. +// If y == 0, both z and m are set to 0 (OBS: differs from the big.Int) +func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) { + if y.IsZero() { + return z.Clear(), m.Clear() + } + var quot Uint + *m = udivrem(quot.arr[:], x.arr[:], y) + *z = quot + return z, m +} + +// Exp sets z = base**exponent mod 2**256, and returns z. +func (z *Uint) Exp(base, exponent *Uint) *Uint { + res := Uint{arr: [4]uint64{1, 0, 0, 0}} + multiplier := *base + expBitLen := exponent.BitLen() + + curBit := 0 + word := exponent.arr[0] + for ; curBit < expBitLen && curBit < 64; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + + word = exponent.arr[1] + for ; curBit < expBitLen && curBit < 128; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + + word = exponent.arr[2] + for ; curBit < expBitLen && curBit < 192; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + + word = exponent.arr[3] + for ; curBit < expBitLen && curBit < 256; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + return z.Set(&res) +} + +func (z *Uint) squared() { + var ( + res Uint + carry0, carry1, carry2 uint64 + res1, res2 uint64 + ) + + carry0, res.arr[0] = bits.Mul64(z.arr[0], z.arr[0]) + carry0, res1 = umulHop(carry0, z.arr[0], z.arr[1]) + carry0, res2 = umulHop(carry0, z.arr[0], z.arr[2]) + + carry1, res.arr[1] = umulHop(res1, z.arr[0], z.arr[1]) + carry1, res2 = umulStep(res2, z.arr[1], z.arr[1], carry1) + + carry2, res.arr[2] = umulHop(res2, z.arr[0], z.arr[2]) + + res.arr[3] = 2*(z.arr[0]*z.arr[3]+z.arr[1]*z.arr[2]) + carry0 + carry1 + carry2 + + z.Set(&res) +} + +// udivrem divides u by d and produces both quotient and remainder. +// The quotient is stored in provided quot - len(u)-len(d)+1 words. +// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words. +// See Knuth, Volume 2, section 4.3.1, Algorithm D. +func udivrem(quot, u []uint64, d *Uint) (rem Uint) { + var dLen int + for i := len(d.arr) - 1; i >= 0; i-- { + if d.arr[i] != 0 { + dLen = i + 1 + break + } + } + + shift := uint(bits.LeadingZeros64(d.arr[dLen-1])) + + var dnStorage Uint + dn := dnStorage.arr[:dLen] + for i := dLen - 1; i > 0; i-- { + dn[i] = (d.arr[i] << shift) | (d.arr[i-1] >> (64 - shift)) + } + dn[0] = d.arr[0] << shift + + var uLen int + for i := len(u) - 1; i >= 0; i-- { + if u[i] != 0 { + uLen = i + 1 + break + } + } + + if uLen < dLen { + copy(rem.arr[:], u) + return rem + } + + var unStorage [9]uint64 + un := unStorage[:uLen+1] + un[uLen] = u[uLen-1] >> (64 - shift) + for i := uLen - 1; i > 0; i-- { + un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift)) + } + un[0] = u[0] << shift + + // TODO: Skip the highest word of numerator if not significant. + + if dLen == 1 { + r := udivremBy1(quot, un, dn[0]) + rem.SetUint64(r >> shift) + return rem + } + + udivremKnuth(quot, un, dn) + + for i := 0; i < dLen-1; i++ { + rem.arr[i] = (un[i] >> shift) | (un[i+1] << (64 - shift)) + } + rem.arr[dLen-1] = un[dLen-1] >> shift + + return rem +} + +// umul computes full 256 x 256 -> 512 multiplication. +func umul(x, y *Uint) [8]uint64 { + var ( + res [8]uint64 + carry, carry4, carry5, carry6 uint64 + res1, res2, res3, res4, res5 uint64 + ) + + carry, res[0] = bits.Mul64(x.arr[0], y.arr[0]) + carry, res1 = umulHop(carry, x.arr[1], y.arr[0]) + carry, res2 = umulHop(carry, x.arr[2], y.arr[0]) + carry4, res3 = umulHop(carry, x.arr[3], y.arr[0]) + + carry, res[1] = umulHop(res1, x.arr[0], y.arr[1]) + carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry) + carry, res3 = umulStep(res3, x.arr[2], y.arr[1], carry) + carry5, res4 = umulStep(carry4, x.arr[3], y.arr[1], carry) + + carry, res[2] = umulHop(res2, x.arr[0], y.arr[2]) + carry, res3 = umulStep(res3, x.arr[1], y.arr[2], carry) + carry, res4 = umulStep(res4, x.arr[2], y.arr[2], carry) + carry6, res5 = umulStep(carry5, x.arr[3], y.arr[2], carry) + + carry, res[3] = umulHop(res3, x.arr[0], y.arr[3]) + carry, res[4] = umulStep(res4, x.arr[1], y.arr[3], carry) + carry, res[5] = umulStep(res5, x.arr[2], y.arr[3], carry) + res[7], res[6] = umulStep(carry6, x.arr[3], y.arr[3], carry) + + return res +} + +// umulStep computes (hi * 2^64 + lo) = z + (x * y) + carry. +func umulStep(z, x, y, carry uint64) (hi, lo uint64) { + hi, lo = bits.Mul64(x, y) + lo, carry = bits.Add64(lo, carry, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, z, 0) + hi, _ = bits.Add64(hi, 0, carry) + return hi, lo +} + +// umulHop computes (hi * 2^64 + lo) = z + (x * y) +func umulHop(z, x, y uint64) (hi, lo uint64) { + hi, lo = bits.Mul64(x, y) + lo, carry := bits.Add64(lo, z, 0) + hi, _ = bits.Add64(hi, 0, carry) + return hi, lo +} + +// udivremBy1 divides u by single normalized word d and produces both quotient and remainder. +// The quotient is stored in provided quot. +func udivremBy1(quot, u []uint64, d uint64) (rem uint64) { + reciprocal := reciprocal2by1(d) + rem = u[len(u)-1] // Set the top word as remainder. + for j := len(u) - 2; j >= 0; j-- { + quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal) + } + return rem +} + +// udivremKnuth implements the division of u by normalized multiple word d from the Knuth's division algorithm. +// The quotient is stored in provided quot - len(u)-len(d) words. +// Updates u to contain the remainder - len(d) words. +func udivremKnuth(quot, u, d []uint64) { + dh := d[len(d)-1] + dl := d[len(d)-2] + reciprocal := reciprocal2by1(dh) + + for j := len(u) - len(d) - 1; j >= 0; j-- { + u2 := u[j+len(d)] + u1 := u[j+len(d)-1] + u0 := u[j+len(d)-2] + + var qhat, rhat uint64 + if u2 >= dh { // Division overflows. + qhat = ^uint64(0) + // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). + } else { + qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal) + ph, pl := bits.Mul64(qhat, dl) + if ph > rhat || (ph == rhat && pl > u0) { + qhat-- + // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). + } + } + + // Multiply and subtract. + borrow := subMulTo(u[j:], d, qhat) + u[j+len(d)] = u2 - borrow + if u2 < borrow { // Too much subtracted, add back. + qhat-- + u[j+len(d)] += addTo(u[j:], d) + } + + quot[j] = qhat // Store quotient digit. + } +} + +// isBitSet returns true if bit n-th is set, where n = 0 is LSB. +// The n must be <= 255. +func (z *Uint) isBitSet(n uint) bool { + return (z.arr[n/64] & (1 << (n % 64))) != 0 +} + +// addTo computes x += y. +// Requires len(x) >= len(y). +func addTo(x, y []uint64) uint64 { + var carry uint64 + for i := 0; i < len(y); i++ { + x[i], carry = bits.Add64(x[i], y[i], carry) + } + return carry +} + +// subMulTo computes x -= y * multiplier. +// Requires len(x) >= len(y). +func subMulTo(x, y []uint64, multiplier uint64) uint64 { + var borrow uint64 + for i := 0; i < len(y); i++ { + s, carry1 := bits.Sub64(x[i], borrow, 0) + ph, pl := bits.Mul64(y[i], multiplier) + t, carry2 := bits.Sub64(s, pl, 0) + x[i] = t + borrow = ph + carry1 + carry2 + } + return borrow +} + +// reciprocal2by1 computes <^d, ^0> / d. +func reciprocal2by1(d uint64) uint64 { + reciprocal, _ := bits.Div64(^d, ^uint64(0), d) + return reciprocal +} + +// udivrem2by1 divides / d and produces both quotient and remainder. +// It uses the provided d's reciprocal. +// Implementation ported from https://github.com/chfast/intx and is based on +// "Improved division by invariant integers", Algorithm 4. +func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) { + qh, ql := bits.Mul64(reciprocal, uh) + ql, carry := bits.Add64(ql, ul, 0) + qh, _ = bits.Add64(qh, uh, carry) + qh++ + + r := ul - qh*d + + if r > ql { + qh-- + r += d + } + + if r >= d { + qh++ + r -= d + } + + return qh, r +} diff --git a/arithmetic_test.go b/arithmetic_test.go new file mode 100644 index 0000000..ca0e45e --- /dev/null +++ b/arithmetic_test.go @@ -0,0 +1,326 @@ +package u256 + +import "testing" + +type binOp2Test struct { + x, y, want string +} + +func TestAdd(t *testing.T) { + tests := []binOp2Test{ + {"0", "1", "1"}, + {"1", "0", "1"}, + {"1", "1", "2"}, + {"1", "3", "4"}, + {"10", "10", "20"}, + {"18446744073709551615", "18446744073709551615", "36893488147419103230"}, // uint64 overflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Add(x, y) + + if got.Neq(want) { + t.Errorf("Add(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestSub(t *testing.T) { + tests := []binOp2Test{ + {"1", "0", "1"}, + {"1", "1", "0"}, + {"10", "10", "0"}, + {"31337", "1337", "30000"}, + {"2", "3", "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, // underflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Sub(x, y) + + if got.Neq(want) { + t.Errorf("Sub(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestMul(t *testing.T) { + tests := []binOp2Test{ + {"1", "0", "0"}, + {"1", "1", "1"}, + {"10", "10", "100"}, + {"18446744073709551615", "2", "36893488147419103230"}, // uint64 overflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Mul(x, y) + + if got.Neq(want) { + t.Errorf("Mul(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestDiv(t *testing.T) { + tests := []binOp2Test{ + {"31337", "3", "10445"}, + {"31337", "0", "0"}, + {"0", "31337", "0"}, + {"1", "1", "1"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Div(x, y) + + if got.Neq(want) { + t.Errorf("Div(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestMod(t *testing.T) { + tests := []binOp2Test{ + {"31337", "3", "2"}, + {"31337", "0", "0"}, + {"0", "31337", "0"}, + {"2", "31337", "2"}, + {"1", "1", "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Mod(x, y) + + if got.Neq(want) { + t.Errorf("Mod(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestDivMod(t *testing.T) { + tests := []struct { + x string + y string + wantDiv string + wantMod string + }{ + {"1", "1", "1", "0"}, + {"10", "10", "1", "0"}, + {"100", "10", "10", "0"}, + {"31337", "3", "10445", "2"}, + {"31337", "0", "0", "0"}, + {"0", "31337", "0", "0"}, + {"2", "31337", "0", "2"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + wantDiv, err := FromDecimal(tc.wantDiv) + if err != nil { + t.Error(err) + continue + } + + wantMod, err := FromDecimal(tc.wantMod) + if err != nil { + t.Error(err) + continue + } + + gotDiv := new(Uint) + gotMod := new(Uint) + gotDiv.DivMod(x, y, gotMod) + + for i := range gotDiv.arr { + if gotDiv.arr[i] != wantDiv.arr[i] { + t.Errorf("DivMod(%s, %s) got Div %v, want Div %v", tc.x, tc.y, gotDiv, wantDiv) + break + } + } + for i := range gotMod.arr { + if gotMod.arr[i] != wantMod.arr[i] { + t.Errorf("DivMod(%s, %s) got Mod %v, want Mod %v", tc.x, tc.y, gotMod, wantMod) + break + } + } + } +} + +func TestNeg(t *testing.T) { + tests := []struct { + x string + want string + }{ + {"31337", "115792089237316195423570985008687907853269984665640564039457584007913129608599"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129608599", "31337"}, + {"0", "0"}, + {"2", "115792089237316195423570985008687907853269984665640564039457584007913129639934"}, + {"1", "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Neg(x) + + if got.Neq(want) { + t.Errorf("Neg(%s) = %v, want %v", tc.x, got.ToString(), want.ToString()) + } + } +} + +func TestExp(t *testing.T) { + tests := []binOp2Test{ + {"31337", "3", "30773171189753"}, + {"31337", "0", "1"}, + {"0", "31337", "0"}, + {"1", "1", "1"}, + {"2", "3", "8"}, + {"2", "64", "18446744073709551616"}, + {"2", "128", "340282366920938463463374607431768211456"}, + {"2", "255", "57896044618658097711785492504343953926634992332820282019728792003956564819968"}, + {"2", "256", "0"}, // overflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Exp(x, y) + + if got.Neq(want) { + t.Errorf("Exp(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} diff --git a/bits.go b/bits.go index 7195d38..4fc7670 100644 --- a/bits.go +++ b/bits.go @@ -1,5 +1,3 @@ - - // Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -16,7 +14,7 @@ // architecture and the Go release. package u256 -const uintSize = 32 << (^uint(0) >> 63) // 32 or 64 +// const uintSize = 32 << (^uint(0) >> 63) // 32 or 64 // UintSize is the size of a uint in bits. const UintSize = uintSize @@ -506,7 +504,7 @@ func Div(hi, lo, y uint) (quo, rem uint) { // Div32 panics for y == 0 (division by zero) or y <= hi (quotient overflow). func Div32(hi, lo, y uint32) (quo, rem uint32) { if y != 0 && y <= hi { - panic(overflowError) + panic(errOverflow) } z := uint64(hi)<<32 | uint64(lo) quo, rem = uint32(z/uint64(y)), uint32(z%uint64(y)) @@ -519,10 +517,10 @@ func Div32(hi, lo, y uint32) (quo, rem uint32) { // Div64 panics for y == 0 (division by zero) or y <= hi (quotient overflow). func Div64(hi, lo, y uint64) (quo, rem uint64) { if y == 0 { - panic(divideError) + panic(errDivide) } if y <= hi { - panic(overflowError) + panic(errOverflow) } // If high part is zero, we can directly return the results. @@ -598,4 +596,4 @@ func Rem64(hi, lo, y uint64) uint64 { // hi<<64 + lo ≡ (hi%y)<<64 + lo (mod y) _, rem := Div64(hi%y, lo, y) return rem -} \ No newline at end of file +} diff --git a/bits_errors.go b/bits_errors.go index 30b8ae3..ad2a847 100644 --- a/bits_errors.go +++ b/bits_errors.go @@ -10,8 +10,5 @@ import ( "errors" ) -//go:linkname overflowError runtime.overflowError -var overflowError error = errors.New("u256: integer overflow") - -//go:linkname divideError runtime.divideError -var divideError error = errors.New("u256: integer divide by zero") \ No newline at end of file +var errOverflow error = errors.New("u256: integer overflow") +var errDivide error = errors.New("u256: integer divide by zero") diff --git a/bits_table.go b/bits_table.go index 4d7c539..98eeeb1 100644 --- a/bits_table.go +++ b/bits_table.go @@ -76,4 +76,4 @@ const len8tab = "" + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + - "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" \ No newline at end of file + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" diff --git a/bitwise.go b/bitwise.go new file mode 100644 index 0000000..746d4a8 --- /dev/null +++ b/bitwise.go @@ -0,0 +1,264 @@ +// bitwise contains bitwise operations for Uint instances. +// This file includes functions to perform bitwise AND, OR, XOR, and NOT operations, as well as bit shifting. +// These operations are crucial for manipulating individual bits within a 256-bit unsigned integer. +package u256 + +// Or sets z = x | y and returns z. +func (z *Uint) Or(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] | y.arr[0] + z.arr[1] = x.arr[1] | y.arr[1] + z.arr[2] = x.arr[2] | y.arr[2] + z.arr[3] = x.arr[3] | y.arr[3] + return z +} + +// And sets z = x & y and returns z. +func (z *Uint) And(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] & y.arr[0] + z.arr[1] = x.arr[1] & y.arr[1] + z.arr[2] = x.arr[2] & y.arr[2] + z.arr[3] = x.arr[3] & y.arr[3] + return z +} + +// Not sets z = ^x and returns z. +func (z *Uint) Not(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = ^x.arr[3], ^x.arr[2], ^x.arr[1], ^x.arr[0] + return z +} + +// AndNot sets z = x &^ y and returns z. +func (z *Uint) AndNot(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] &^ y.arr[0] + z.arr[1] = x.arr[1] &^ y.arr[1] + z.arr[2] = x.arr[2] &^ y.arr[2] + z.arr[3] = x.arr[3] &^ y.arr[3] + return z +} + +// Xor sets z = x ^ y and returns z. +func (z *Uint) Xor(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] ^ y.arr[0] + z.arr[1] = x.arr[1] ^ y.arr[1] + z.arr[2] = x.arr[2] ^ y.arr[2] + z.arr[3] = x.arr[3] ^ y.arr[3] + return z +} + +// Lsh sets z = x << n and returns z. +func (z *Uint) Lsh(x *Uint, n uint) *Uint { + // n % 64 == 0 + if n&0x3f == 0 { + switch n { + case 0: + return z.Set(x) + case 64: + return z.lsh64(x) + case 128: + return z.lsh128(x) + case 192: + return z.lsh192(x) + default: + return z.Clear() + } + } + var a, b uint64 + // Big swaps first + switch { + case n > 192: + if n > 256 { + return z.Clear() + } + z.lsh192(x) + n -= 192 + goto sh192 + case n > 128: + z.lsh128(x) + n -= 128 + goto sh128 + case n > 64: + z.lsh64(x) + n -= 64 + goto sh64 + default: + z.Set(x) + } + + // remaining shifts + a = z.arr[0] >> (64 - n) + z.arr[0] = z.arr[0] << n + +sh64: + b = z.arr[1] >> (64 - n) + z.arr[1] = (z.arr[1] << n) | a + +sh128: + a = z.arr[2] >> (64 - n) + z.arr[2] = (z.arr[2] << n) | b + +sh192: + z.arr[3] = (z.arr[3] << n) | a + + return z +} + +// Rsh sets z = x >> n and returns z. +func (z *Uint) Rsh(x *Uint, n uint) *Uint { + // n % 64 == 0 + if n&0x3f == 0 { + switch n { + case 0: + return z.Set(x) + case 64: + return z.rsh64(x) + case 128: + return z.rsh128(x) + case 192: + return z.rsh192(x) + default: + return z.Clear() + } + } + var a, b uint64 + // Big swaps first + switch { + case n > 192: + if n > 256 { + return z.Clear() + } + z.rsh192(x) + n -= 192 + goto sh192 + case n > 128: + z.rsh128(x) + n -= 128 + goto sh128 + case n > 64: + z.rsh64(x) + n -= 64 + goto sh64 + default: + z.Set(x) + } + + // remaining shifts + a = z.arr[3] << (64 - n) + z.arr[3] = z.arr[3] >> n + +sh64: + b = z.arr[2] << (64 - n) + z.arr[2] = (z.arr[2] >> n) | a + +sh128: + a = z.arr[1] << (64 - n) + z.arr[1] = (z.arr[1] >> n) | b + +sh192: + z.arr[0] = (z.arr[0] >> n) | a + + return z +} + +// SRsh (Signed/Arithmetic right shift) +// considers z to be a signed integer, during right-shift +// and sets z = x >> n and returns z. +func (z *Uint) SRsh(x *Uint, n uint) *Uint { + // If the MSB is 0, SRsh is same as Rsh. + if !x.isBitSet(255) { + return z.Rsh(x, n) + } + if n%64 == 0 { + switch n { + case 0: + return z.Set(x) + case 64: + return z.srsh64(x) + case 128: + return z.srsh128(x) + case 192: + return z.srsh192(x) + default: + return z.SetAllOne() + } + } + var a uint64 = MaxUint64 << (64 - n%64) + // Big swaps first + switch { + case n > 192: + if n > 256 { + return z.SetAllOne() + } + z.srsh192(x) + n -= 192 + goto sh192 + case n > 128: + z.srsh128(x) + n -= 128 + goto sh128 + case n > 64: + z.srsh64(x) + n -= 64 + goto sh64 + default: + z.Set(x) + } + + // remaining shifts + z.arr[3], a = (z.arr[3]>>n)|a, z.arr[3]<<(64-n) + +sh64: + z.arr[2], a = (z.arr[2]>>n)|a, z.arr[2]<<(64-n) + +sh128: + z.arr[1], a = (z.arr[1]>>n)|a, z.arr[1]<<(64-n) + +sh192: + z.arr[0] = (z.arr[0] >> n) | a + + return z +} + +func (z *Uint) lsh64(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[2], x.arr[1], x.arr[0], 0 + return z +} + +func (z *Uint) lsh128(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[1], x.arr[0], 0, 0 + return z +} + +func (z *Uint) lsh192(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[0], 0, 0, 0 + return z +} + +func (z *Uint) rsh64(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, x.arr[3], x.arr[2], x.arr[1] + return z +} + +func (z *Uint) rsh128(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, x.arr[3], x.arr[2] + return z +} + +func (z *Uint) rsh192(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, x.arr[3] + return z +} + +func (z *Uint) srsh64(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, x.arr[3], x.arr[2], x.arr[1] + return z +} + +func (z *Uint) srsh128(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, MaxUint64, x.arr[3], x.arr[2] + return z +} + +func (z *Uint) srsh192(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, MaxUint64, MaxUint64, x.arr[3] + return z +} diff --git a/bitwise_test.go b/bitwise_test.go new file mode 100644 index 0000000..e94d5fb --- /dev/null +++ b/bitwise_test.go @@ -0,0 +1,344 @@ +package u256 + +import "testing" + +type logicOpTest struct { + name string + x Uint + y Uint + want Uint +} + +func TestOr(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).Or(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("Or(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestAnd(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed 2", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed 3", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "one operand zero", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).And(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("And(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestNot(t *testing.T) { + tests := []struct { + name string + x Uint + want Uint + }{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).Not(&tc.x) + if *res != tc.want { + t.Errorf("Not(%s) = %s, want %s", tc.x.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestAndNot(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed 2", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed 3", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + }, + { + name: "one operand zero", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0x0000000000000000, ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).AndNot(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("AndNot(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestXor(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed 2", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed 3", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "one operand zero", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + want: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0x0000000000000000, ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).Xor(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("Xor(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + x string + y uint + want string + }{ + {"0", 0, "0"}, + {"0", 1, "0"}, + {"0", 64, "0"}, + {"1", 0, "1"}, + {"1", 1, "2"}, + {"1", 64, "18446744073709551616"}, + {"1", 128, "340282366920938463463374607431768211456"}, + {"1", 192, "6277101735386680763835789423207666416102355444464034512896"}, + {"1", 255, "57896044618658097711785492504343953926634992332820282019728792003956564819968"}, + {"1", 256, "0"}, + {"31337", 0, "31337"}, + {"31337", 1, "62674"}, + {"31337", 64, "578065619037836218990592"}, + {"31337", 128, "10663428532201448629551770073089320442396672"}, + {"31337", 192, "196705537081812415096322133155058642481399512563169449530621952"}, + {"31337", 193, "393411074163624830192644266310117284962799025126338899061243904"}, + {"31337", 255, "57896044618658097711785492504343953926634992332820282019728792003956564819968"}, + {"31337", 256, "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Lsh(x, tc.y) + + if got.Neq(want) { + t.Errorf("Lsh(%s, %d) = %s, want %s", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + x string + y uint + want string + }{ + {"0", 0, "0"}, + {"0", 1, "0"}, + {"0", 64, "0"}, + {"1", 0, "1"}, + {"1", 1, "0"}, + {"1", 64, "0"}, + {"1", 128, "0"}, + {"1", 192, "0"}, + {"1", 255, "0"}, + {"57896044618658097711785492504343953926634992332820282019728792003956564819968", 255, "1"}, + {"6277101735386680763835789423207666416102355444464034512896", 192, "1"}, + {"340282366920938463463374607431768211456", 128, "1"}, + {"18446744073709551616", 64, "1"}, + {"393411074163624830192644266310117284962799025126338899061243904", 193, "31337"}, + {"196705537081812415096322133155058642481399512563169449530621952", 192, "31337"}, + {"10663428532201448629551770073089320442396672", 128, "31337"}, + {"578065619037836218990592", 64, "31337"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Rsh(x, tc.y) + + if got.Neq(want) { + t.Errorf("Rsh(%s, %d) = %s, want %s", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} diff --git a/cmp.go b/cmp.go new file mode 100644 index 0000000..dc309fb --- /dev/null +++ b/cmp.go @@ -0,0 +1,125 @@ +// cmp (or, comparisons) includes methods for comparing Uint instances. +// These comparison functions cover a range of operations including equality checks, less than/greater than +// evaluations, and specialized comparisons such as signed greater than. These are fundamental for logical +// decision making based on Uint values. +package u256 + +import ( + "math/bits" +) + +// Cmp compares z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Uint) Cmp(x *Uint) (r int) { + // z < x <=> z - x < 0 i.e. when subtraction overflows. + d0, carry := bits.Sub64(z.arr[0], x.arr[0], 0) + d1, carry := bits.Sub64(z.arr[1], x.arr[1], carry) + d2, carry := bits.Sub64(z.arr[2], x.arr[2], carry) + d3, carry := bits.Sub64(z.arr[3], x.arr[3], carry) + if carry == 1 { + return -1 + } + if d0|d1|d2|d3 == 0 { + return 0 + } + return 1 +} + +// IsZero returns true if z == 0 +func (z *Uint) IsZero() bool { + return (z.arr[0] | z.arr[1] | z.arr[2] | z.arr[3]) == 0 +} + +// Sign returns: +// +// -1 if z < 0 +// 0 if z == 0 +// +1 if z > 0 +// +// Where z is interpreted as a two's complement signed number +func (z *Uint) Sign() int { + if z.IsZero() { + return 0 + } + if z.arr[3] < 0x8000000000000000 { + return 1 + } + return -1 +} + +// LtUint64 returns true if z is smaller than n +func (z *Uint) LtUint64(n uint64) bool { + return z.arr[0] < n && (z.arr[1]|z.arr[2]|z.arr[3]) == 0 +} + +// GtUint64 returns true if z is larger than n +func (z *Uint) GtUint64(n uint64) bool { + return z.arr[0] > n || (z.arr[1]|z.arr[2]|z.arr[3]) != 0 +} + +// Lt returns true if z < x +func (z *Uint) Lt(x *Uint) bool { + // z < x <=> z - x < 0 i.e. when subtraction overflows. + _, carry := bits.Sub64(z.arr[0], x.arr[0], 0) + _, carry = bits.Sub64(z.arr[1], x.arr[1], carry) + _, carry = bits.Sub64(z.arr[2], x.arr[2], carry) + _, carry = bits.Sub64(z.arr[3], x.arr[3], carry) + + return carry != 0 +} + +// Gt returns true if z > x +func (z *Uint) Gt(x *Uint) bool { + return x.Lt(z) +} + +// Lte returns true if z <= x +func (z *Uint) Lte(x *Uint) bool { + cond1 := z.Lt(x) + cond2 := z.Eq(x) + + if cond1 || cond2 { + return true + } + return false +} + +// Gte returns true if z >= x +func (z *Uint) Gte(x *Uint) bool { + cond1 := z.Gt(x) + cond2 := z.Eq(x) + + if cond1 || cond2 { + return true + } + return false +} + +// Eq returns true if z == x +func (z *Uint) Eq(x *Uint) bool { + return (z.arr[0] == x.arr[0]) && (z.arr[1] == x.arr[1]) && (z.arr[2] == x.arr[2]) && (z.arr[3] == x.arr[3]) +} + +// Neq returns true if z != x +func (z *Uint) Neq(x *Uint) bool { + return !z.Eq(x) +} + +// Sgt interprets z and x as signed integers, and returns +// true if z > x +func (z *Uint) Sgt(x *Uint) bool { + zSign := z.Sign() + xSign := x.Sign() + + switch { + case zSign >= 0 && xSign < 0: + return true + case zSign < 0 && xSign >= 0: + return false + default: + return z.Gt(x) + } +} diff --git a/cmp_test.go b/cmp_test.go new file mode 100644 index 0000000..e54b223 --- /dev/null +++ b/cmp_test.go @@ -0,0 +1,163 @@ +package u256 + +import ( + "strings" + "testing" +) + +func TestCmp(t *testing.T) { + tests := []struct { + x, y string + want int + }{ + {"0", "0", 0}, + {"0", "1", -1}, + {"1", "0", 1}, + {"1", "1", 0}, + {"10", "10", 0}, + {"10", "11", -1}, + {"11", "10", 1}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Cmp(y) + if got != tc.want { + t.Errorf("Cmp(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestIsZero(t *testing.T) { + tests := []struct { + x string + want bool + }{ + {"0", true}, + {"1", false}, + {"10", false}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + got := x.IsZero() + if got != tc.want { + t.Errorf("IsZero(%s) = %v, want %v", tc.x, got, tc.want) + } + } +} + +func TestLtUint64(t *testing.T) { + tests := []struct { + x string + y uint64 + want bool + }{ + {"0", 1, true}, + {"1", 0, false}, + {"10", 10, false}, + {"0xffffffffffffffff", 0, false}, + {"0x10000000000000000", 10000000000000000, false}, + } + + for _, tc := range tests { + var x *Uint + var err error + + if strings.HasPrefix(tc.x, "0x") { + x, err = FromHex(tc.x) + if err != nil { + t.Error(err) + continue + } + } else { + x, err = FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + } + + got := x.LtUint64(tc.y) + + if got != tc.want { + t.Errorf("LtUint64(%s, %d) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestSGT(t *testing.T) { + x := MustFromHex("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe") + y := MustFromHex("0x0") + actual := x.Sgt(y) + if actual { + t.Fatalf("Expected %v false", actual) + } + + x = MustFromHex("0x0") + y = MustFromHex("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe") + actual = x.Sgt(y) + if !actual { + t.Fatalf("Expected %v true", actual) + } +} + +func TestEq(t *testing.T) { + tests := []struct { + x string + y string + want bool + }{ + {"0xffffffffffffffff", "18446744073709551615", true}, + {"0x10000000000000000", "18446744073709551616", true}, + {"0", "0", true}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "115792089237316195423570985008687907853269984665640564039457584007913129639935", true}, + } + + for _, tc := range tests { + var x *Uint + var err error + + if strings.HasPrefix(tc.x, "0x") { + x, err = FromHex(tc.x) + if err != nil { + t.Error(err) + continue + } + } else { + x, err = FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Eq(y) + + if got != tc.want { + t.Errorf("Eq(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} diff --git a/coversion.go b/coversion.go new file mode 100644 index 0000000..26b4b6c --- /dev/null +++ b/coversion.go @@ -0,0 +1,570 @@ +// conversions contains methods for converting Uint instances to other types and vice versa. +// This includes conversions to and from basic types such as uint64 and int32, as well as string representations +// and byte slices. Additionally, it covers marshaling and unmarshaling for JSON and other text formats. +package u256 + +import ( + "encoding/binary" + "errors" + "strconv" + "strings" +) + +// Uint64 returns the lower 64-bits of z +func (z *Uint) Uint64() uint64 { + return z.arr[0] +} + +// Uint64WithOverflow returns the lower 64-bits of z and bool whether overflow occurred +func (z *Uint) Uint64WithOverflow() (uint64, bool) { + return z.arr[0], (z.arr[1] | z.arr[2] | z.arr[3]) != 0 +} + +// SetUint64 sets z to the value x +func (z *Uint) SetUint64(x uint64) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, x + return z +} + +// IsUint64 reports whether z can be represented as a uint64. +func (z *Uint) IsUint64() bool { + return (z.arr[1] | z.arr[2] | z.arr[3]) == 0 +} + +// Dec returns the decimal representation of z. +func (z *Uint) Dec() string { + if z.IsZero() { + return "0" + } + if z.IsUint64() { + return strconv.FormatUint(z.Uint64(), 10) + } + + // The max uint64 value being 18446744073709551615, the largest + // power-of-ten below that is 10000000000000000000. + // When we do a DivMod using that number, the remainder that we + // get back is the lower part of the output. + // + // The ascii-output of remainder will never exceed 19 bytes (since it will be + // below 10000000000000000000). + // + // Algorithm example using 100 as divisor + // + // 12345 % 100 = 45 (rem) + // 12345 / 100 = 123 (quo) + // -> output '45', continue iterate on 123 + var ( + // out is 98 bytes long: 78 (max size of a string without leading zeroes, + // plus slack so we can copy 19 bytes every iteration). + // We init it with zeroes, because when strconv appends the ascii representations, + // it will omit leading zeroes. + out = []byte("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + divisor = NewUint(10000000000000000000) // 20 digits + y = new(Uint).Set(z) // copy to avoid modifying z + pos = len(out) // position to write to + buf = make([]byte, 0, 19) // buffer to write uint64:s to + ) + for { + // Obtain Q and R for divisor + var quot Uint + rem := udivrem(quot.arr[:], y.arr[:], divisor) + y.Set(") // Set Q for next loop + // Convert the R to ascii representation + buf = strconv.AppendUint(buf[:0], rem.Uint64(), 10) + // Copy in the ascii digits + copy(out[pos-len(buf):], buf) + if y.IsZero() { + break + } + // Move 19 digits left + pos -= 19 + } + // skip leading zeroes by only using the 'used size' of buf + return string(out[pos-len(buf):]) +} + +func (z *Uint) Scan(src interface{}) error { + if src == nil { + z.Clear() + return nil + } + + switch src := src.(type) { + case string: + return z.scanScientificFromString(src) + case []byte: + return z.scanScientificFromString(string(src)) + } + return errors.New("default // unsupported type: can't convert to uint256.Uint") +} + +func (z *Uint) scanScientificFromString(src string) error { + if len(src) == 0 { + z.Clear() + return nil + } + + idx := strings.IndexByte(src, 'e') + if idx == -1 { + return z.SetFromDecimal(src) + } + if err := z.SetFromDecimal(src[:idx]); err != nil { + return err + } + if src[(idx+1):] == "0" { + return nil + } + exp := new(Uint) + if err := exp.SetFromDecimal(src[(idx + 1):]); err != nil { + return err + } + if exp.GtUint64(77) { // 10**78 is larger than 2**256 + return ErrBig256Range + } + exp.Exp(NewUint(10), exp) + if _, overflow := z.MulOverflow(z, exp); overflow { + return ErrBig256Range + } + return nil +} + +// ToString returns the decimal string representation of z. It returns an empty string if z is nil. +// OBS: doesn't exist from holiman's uint256 +func (z *Uint) ToString() string { + if z == nil { + return "" + } + + return z.Dec() +} + +// MarshalJSON implements json.Marshaler. +// MarshalJSON marshals using the 'decimal string' representation. This is _not_ compatible +// with big.Uint: big.Uint marshals into JSON 'native' numeric format. +// +// The JSON native format is, on some platforms, (e.g. javascript), limited to 53-bit large +// integer space. Thus, U256 uses string-format, which is not compatible with +// big.int (big.Uint refuses to unmarshal a string representation). +func (z *Uint) MarshalJSON() ([]byte, error) { + return []byte(`"` + z.Dec() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. UnmarshalJSON accepts either +// - Quoted string: either hexadecimal OR decimal +// - Not quoted string: only decimal +func (z *Uint) UnmarshalJSON(input []byte) error { + if len(input) < 2 || input[0] != '"' || input[len(input)-1] != '"' { + // if not quoted, it must be decimal + return z.fromDecimal(string(input)) + } + return z.UnmarshalText(input[1 : len(input)-1]) +} + +// MarshalText implements encoding.TextMarshaler +// MarshalText marshals using the decimal representation (compatible with big.Uint) +func (z *Uint) MarshalText() ([]byte, error) { + return []byte(z.Dec()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. This method +// can unmarshal either hexadecimal or decimal. +// - For hexadecimal, the input _must_ be prefixed with 0x or 0X +func (z *Uint) UnmarshalText(input []byte) error { + if len(input) >= 2 && input[0] == '0' && (input[1] == 'x' || input[1] == 'X') { + return z.fromHex(string(input)) + } + return z.fromDecimal(string(input)) +} + +// SetBytes interprets buf as the bytes of a big-endian unsigned +// integer, sets z to that value, and returns z. +// If buf is larger than 32 bytes, the last 32 bytes is used. +func (z *Uint) SetBytes(buf []byte) *Uint { + switch l := len(buf); l { + case 0: + z.Clear() + case 1: + z.SetBytes1(buf) + case 2: + z.SetBytes2(buf) + case 3: + z.SetBytes3(buf) + case 4: + z.SetBytes4(buf) + case 5: + z.SetBytes5(buf) + case 6: + z.SetBytes6(buf) + case 7: + z.SetBytes7(buf) + case 8: + z.SetBytes8(buf) + case 9: + z.SetBytes9(buf) + case 10: + z.SetBytes10(buf) + case 11: + z.SetBytes11(buf) + case 12: + z.SetBytes12(buf) + case 13: + z.SetBytes13(buf) + case 14: + z.SetBytes14(buf) + case 15: + z.SetBytes15(buf) + case 16: + z.SetBytes16(buf) + case 17: + z.SetBytes17(buf) + case 18: + z.SetBytes18(buf) + case 19: + z.SetBytes19(buf) + case 20: + z.SetBytes20(buf) + case 21: + z.SetBytes21(buf) + case 22: + z.SetBytes22(buf) + case 23: + z.SetBytes23(buf) + case 24: + z.SetBytes24(buf) + case 25: + z.SetBytes25(buf) + case 26: + z.SetBytes26(buf) + case 27: + z.SetBytes27(buf) + case 28: + z.SetBytes28(buf) + case 29: + z.SetBytes29(buf) + case 30: + z.SetBytes30(buf) + case 31: + z.SetBytes31(buf) + default: + z.SetBytes32(buf[l-32:]) + } + return z +} + +// SetBytes1 is identical to SetBytes(in[:1]), but panics is input is too short +func (z *Uint) SetBytes1(in []byte) *Uint { + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(in[0]) + return z +} + +// SetBytes2 is identical to SetBytes(in[:2]), but panics is input is too short +func (z *Uint) SetBytes2(in []byte) *Uint { + _ = in[1] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(binary.BigEndian.Uint16(in[0:2])) + return z +} + +// SetBytes3 is identical to SetBytes(in[:3]), but panics is input is too short +func (z *Uint) SetBytes3(in []byte) *Uint { + _ = in[2] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + return z +} + +// SetBytes4 is identical to SetBytes(in[:4]), but panics is input is too short +func (z *Uint) SetBytes4(in []byte) *Uint { + _ = in[3] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(binary.BigEndian.Uint32(in[0:4])) + return z +} + +// SetBytes5 is identical to SetBytes(in[:5]), but panics is input is too short +func (z *Uint) SetBytes5(in []byte) *Uint { + _ = in[4] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = bigEndianUint40(in[0:5]) + return z +} + +// SetBytes6 is identical to SetBytes(in[:6]), but panics is input is too short +func (z *Uint) SetBytes6(in []byte) *Uint { + _ = in[5] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = bigEndianUint48(in[0:6]) + return z +} + +// SetBytes7 is identical to SetBytes(in[:7]), but panics is input is too short +func (z *Uint) SetBytes7(in []byte) *Uint { + _ = in[6] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = bigEndianUint56(in[0:7]) + return z +} + +// SetBytes8 is identical to SetBytes(in[:8]), but panics is input is too short +func (z *Uint) SetBytes8(in []byte) *Uint { + _ = in[7] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = binary.BigEndian.Uint64(in[0:8]) + return z +} + +// SetBytes9 is identical to SetBytes(in[:9]), but panics is input is too short +func (z *Uint) SetBytes9(in []byte) *Uint { + _ = in[8] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(in[0]) + z.arr[0] = binary.BigEndian.Uint64(in[1:9]) + return z +} + +// SetBytes10 is identical to SetBytes(in[:10]), but panics is input is too short +func (z *Uint) SetBytes10(in []byte) *Uint { + _ = in[9] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(binary.BigEndian.Uint16(in[0:2])) + z.arr[0] = binary.BigEndian.Uint64(in[2:10]) + return z +} + +// SetBytes11 is identical to SetBytes(in[:11]), but panics is input is too short +func (z *Uint) SetBytes11(in []byte) *Uint { + _ = in[10] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + z.arr[0] = binary.BigEndian.Uint64(in[3:11]) + return z +} + +// SetBytes12 is identical to SetBytes(in[:12]), but panics is input is too short +func (z *Uint) SetBytes12(in []byte) *Uint { + _ = in[11] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(binary.BigEndian.Uint32(in[0:4])) + z.arr[0] = binary.BigEndian.Uint64(in[4:12]) + return z +} + +// SetBytes13 is identical to SetBytes(in[:13]), but panics is input is too short +func (z *Uint) SetBytes13(in []byte) *Uint { + _ = in[12] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = bigEndianUint40(in[0:5]) + z.arr[0] = binary.BigEndian.Uint64(in[5:13]) + return z +} + +// SetBytes14 is identical to SetBytes(in[:14]), but panics is input is too short +func (z *Uint) SetBytes14(in []byte) *Uint { + _ = in[13] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = bigEndianUint48(in[0:6]) + z.arr[0] = binary.BigEndian.Uint64(in[6:14]) + return z +} + +// SetBytes15 is identical to SetBytes(in[:15]), but panics is input is too short +func (z *Uint) SetBytes15(in []byte) *Uint { + _ = in[14] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = bigEndianUint56(in[0:7]) + z.arr[0] = binary.BigEndian.Uint64(in[7:15]) + return z +} + +// SetBytes16 is identical to SetBytes(in[:16]), but panics is input is too short +func (z *Uint) SetBytes16(in []byte) *Uint { + _ = in[15] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = binary.BigEndian.Uint64(in[0:8]) + z.arr[0] = binary.BigEndian.Uint64(in[8:16]) + return z +} + +// SetBytes17 is identical to SetBytes(in[:17]), but panics is input is too short +func (z *Uint) SetBytes17(in []byte) *Uint { + _ = in[16] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(in[0]) + z.arr[1] = binary.BigEndian.Uint64(in[1:9]) + z.arr[0] = binary.BigEndian.Uint64(in[9:17]) + return z +} + +// SetBytes18 is identical to SetBytes(in[:18]), but panics is input is too short +func (z *Uint) SetBytes18(in []byte) *Uint { + _ = in[17] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(binary.BigEndian.Uint16(in[0:2])) + z.arr[1] = binary.BigEndian.Uint64(in[2:10]) + z.arr[0] = binary.BigEndian.Uint64(in[10:18]) + return z +} + +// SetBytes19 is identical to SetBytes(in[:19]), but panics is input is too short +func (z *Uint) SetBytes19(in []byte) *Uint { + _ = in[18] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + z.arr[1] = binary.BigEndian.Uint64(in[3:11]) + z.arr[0] = binary.BigEndian.Uint64(in[11:19]) + return z +} + +// SetBytes20 is identical to SetBytes(in[:20]), but panics is input is too short +func (z *Uint) SetBytes20(in []byte) *Uint { + _ = in[19] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(binary.BigEndian.Uint32(in[0:4])) + z.arr[1] = binary.BigEndian.Uint64(in[4:12]) + z.arr[0] = binary.BigEndian.Uint64(in[12:20]) + return z +} + +// SetBytes21 is identical to SetBytes(in[:21]), but panics is input is too short +func (z *Uint) SetBytes21(in []byte) *Uint { + _ = in[20] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = bigEndianUint40(in[0:5]) + z.arr[1] = binary.BigEndian.Uint64(in[5:13]) + z.arr[0] = binary.BigEndian.Uint64(in[13:21]) + return z +} + +// SetBytes22 is identical to SetBytes(in[:22]), but panics is input is too short +func (z *Uint) SetBytes22(in []byte) *Uint { + _ = in[21] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = bigEndianUint48(in[0:6]) + z.arr[1] = binary.BigEndian.Uint64(in[6:14]) + z.arr[0] = binary.BigEndian.Uint64(in[14:22]) + return z +} + +// SetBytes23 is identical to SetBytes(in[:23]), but panics is input is too short +func (z *Uint) SetBytes23(in []byte) *Uint { + _ = in[22] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = bigEndianUint56(in[0:7]) + z.arr[1] = binary.BigEndian.Uint64(in[7:15]) + z.arr[0] = binary.BigEndian.Uint64(in[15:23]) + return z +} + +// SetBytes24 is identical to SetBytes(in[:24]), but panics is input is too short +func (z *Uint) SetBytes24(in []byte) *Uint { + _ = in[23] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = binary.BigEndian.Uint64(in[0:8]) + z.arr[1] = binary.BigEndian.Uint64(in[8:16]) + z.arr[0] = binary.BigEndian.Uint64(in[16:24]) + return z +} + +// SetBytes25 is identical to SetBytes(in[:25]), but panics is input is too short +func (z *Uint) SetBytes25(in []byte) *Uint { + _ = in[24] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(in[0]) + z.arr[2] = binary.BigEndian.Uint64(in[1:9]) + z.arr[1] = binary.BigEndian.Uint64(in[9:17]) + z.arr[0] = binary.BigEndian.Uint64(in[17:25]) + return z +} + +// SetBytes26 is identical to SetBytes(in[:26]), but panics is input is too short +func (z *Uint) SetBytes26(in []byte) *Uint { + _ = in[25] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(binary.BigEndian.Uint16(in[0:2])) + z.arr[2] = binary.BigEndian.Uint64(in[2:10]) + z.arr[1] = binary.BigEndian.Uint64(in[10:18]) + z.arr[0] = binary.BigEndian.Uint64(in[18:26]) + return z +} + +// SetBytes27 is identical to SetBytes(in[:27]), but panics is input is too short +func (z *Uint) SetBytes27(in []byte) *Uint { + _ = in[26] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + z.arr[2] = binary.BigEndian.Uint64(in[3:11]) + z.arr[1] = binary.BigEndian.Uint64(in[11:19]) + z.arr[0] = binary.BigEndian.Uint64(in[19:27]) + return z +} + +// SetBytes28 is identical to SetBytes(in[:28]), but panics is input is too short +func (z *Uint) SetBytes28(in []byte) *Uint { + _ = in[27] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(binary.BigEndian.Uint32(in[0:4])) + z.arr[2] = binary.BigEndian.Uint64(in[4:12]) + z.arr[1] = binary.BigEndian.Uint64(in[12:20]) + z.arr[0] = binary.BigEndian.Uint64(in[20:28]) + return z +} + +// SetBytes29 is identical to SetBytes(in[:29]), but panics is input is too short +func (z *Uint) SetBytes29(in []byte) *Uint { + _ = in[23] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = bigEndianUint40(in[0:5]) + z.arr[2] = binary.BigEndian.Uint64(in[5:13]) + z.arr[1] = binary.BigEndian.Uint64(in[13:21]) + z.arr[0] = binary.BigEndian.Uint64(in[21:29]) + return z +} + +// SetBytes30 is identical to SetBytes(in[:30]), but panics is input is too short +func (z *Uint) SetBytes30(in []byte) *Uint { + _ = in[29] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = bigEndianUint48(in[0:6]) + z.arr[2] = binary.BigEndian.Uint64(in[6:14]) + z.arr[1] = binary.BigEndian.Uint64(in[14:22]) + z.arr[0] = binary.BigEndian.Uint64(in[22:30]) + return z +} + +// SetBytes31 is identical to SetBytes(in[:31]), but panics is input is too short +func (z *Uint) SetBytes31(in []byte) *Uint { + _ = in[30] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = bigEndianUint56(in[0:7]) + z.arr[2] = binary.BigEndian.Uint64(in[7:15]) + z.arr[1] = binary.BigEndian.Uint64(in[15:23]) + z.arr[0] = binary.BigEndian.Uint64(in[23:31]) + return z +} + +// SetBytes32 sets z to the value of the big-endian 256-bit unsigned integer in. +func (z *Uint) SetBytes32(in []byte) *Uint { + _ = in[31] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = binary.BigEndian.Uint64(in[0:8]) + z.arr[2] = binary.BigEndian.Uint64(in[8:16]) + z.arr[1] = binary.BigEndian.Uint64(in[16:24]) + z.arr[0] = binary.BigEndian.Uint64(in[24:32]) + return z +} + +// Utility methods that are "missing" among the bigEndian.UintXX methods. + +// bigEndianUint40 returns the uint64 value represented by the 5 bytes in big-endian order. +func bigEndianUint40(b []byte) uint64 { + _ = b[4] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[4]) | uint64(b[3])<<8 | uint64(b[2])<<16 | uint64(b[1])<<24 | + uint64(b[0])<<32 +} + +// bigEndianUint56 returns the uint64 value represented by the 7 bytes in big-endian order. +func bigEndianUint56(b []byte) uint64 { + _ = b[6] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[6]) | uint64(b[5])<<8 | uint64(b[4])<<16 | uint64(b[3])<<24 | + uint64(b[2])<<32 | uint64(b[1])<<40 | uint64(b[0])<<48 +} + +// bigEndianUint48 returns the uint64 value represented by the 6 bytes in big-endian order. +func bigEndianUint48(b []byte) uint64 { + _ = b[5] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[5]) | uint64(b[4])<<8 | uint64(b[3])<<16 | uint64(b[2])<<24 | + uint64(b[1])<<32 | uint64(b[0])<<40 +} diff --git a/coversion_test.go b/coversion_test.go new file mode 100644 index 0000000..b619b1d --- /dev/null +++ b/coversion_test.go @@ -0,0 +1,58 @@ +package u256 + +import "testing" + +func TestIsUint64(t *testing.T) { + tests := []struct { + x string + want bool + }{ + {"0x0", true}, + {"0x1", true}, + {"0x10", true}, + {"0xffffffffffffffff", true}, + {"0x10000000000000000", false}, + } + + for _, tc := range tests { + x := MustFromHex(tc.x) + got := x.IsUint64() + + if got != tc.want { + t.Errorf("IsUint64(%s) = %v, want %v", tc.x, got, tc.want) + } + } +} + +func TestDec(t *testing.T) { + testCases := []struct { + name string + z Uint + want string + }{ + { + name: "zero", + z: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: "0", + }, + { + name: "less than 20 digits", + z: Uint{arr: [4]uint64{1234567890, 0, 0, 0}}, + want: "1234567890", + }, + { + name: "max possible value", + z: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: "115792089237316195423570985008687907853269984665640564039457584007913129639935", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.z.Dec() + if result != tc.want { + t.Errorf("Dec(%v) = %s, want %s", tc.z, result, tc.want) + } + }) + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..61d6328 --- /dev/null +++ b/errors.go @@ -0,0 +1,75 @@ +package u256 + +import ( + "errors" + + "strconv" +) + +var ( + ErrEmptyString = errors.New("empty hex string") + ErrSyntax = errors.New("invalid hex string") + ErrRange = errors.New("number out of range") + ErrMissingPrefix = errors.New("hex string without 0x prefix") + ErrEmptyNumber = errors.New("hex string \"0x\"") + ErrLeadingZero = errors.New("hex number with leading zero digits") + ErrBig256Range = errors.New("hex number > 256 bits") + ErrBadBufferLength = errors.New("bad ssz buffer length") + ErrBadEncodedLength = errors.New("bad ssz encoded length") + ErrInvalidBase = errors.New("invalid base") + ErrInvalidBitSize = errors.New("invalid bit size") +) + +type u256Error struct { + fn string // function name + input string + err error +} + +func (e *u256Error) Error() string { + return e.fn + ": " + e.input + ": " + e.err.Error() +} + +func (e *u256Error) Unwrap() error { + return e.err +} + +func errEmptyString(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrEmptyString} +} + +func errSyntax(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrSyntax} +} + +func errMissingPrefix(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrMissingPrefix} +} + +func errEmptyNumber(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrEmptyNumber} +} + +func errLeadingZero(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrLeadingZero} +} + +func errRange(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrRange} +} + +func errBig256Range(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrBig256Range} +} + +func errBadBufferLength(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrBadBufferLength} +} + +func errInvalidBase(fn string, base int) error { + return &u256Error{fn: fn, input: strconv.Itoa(base), err: ErrInvalidBase} +} + +func errInvalidBitSize(fn string, bitSize int) error { + return &u256Error{fn: fn, input: strconv.Itoa(bitSize), err: ErrInvalidBitSize} +} diff --git a/i256.go b/i256.go deleted file mode 100644 index 89c1f7b..0000000 --- a/i256.go +++ /dev/null @@ -1,253 +0,0 @@ -package u256 - -// signed integer wrapper - -type Int struct { - v Uint -} - -func NewInt(v int64) *Int { - if v >= 0 { - return &Int{v: *NewUint(uint64(v))} - } - return &Int{ - v: Uint{ - arr: [4]uint64{ - uint64(v), // bit preserving cast, little endian - 0xffffffffffffffff, - 0xffffffffffffffff, - 0xffffffffffffffff, - }, - }, - } -} - -// func IntFromBigint(v bigint) *Int { -// if v > MaxUint256/2-1 { -// panic("I256 IntFromBigint overflow") -// } -// if v < -MaxUint256/2 { -// panic("I256 IntFromBigint underflow") -// } - -// if v >= 0 { -// return &Int{v: *FromBigint(v)} -// } else { -// var tmp Int -// tmp.v = *FromBigint(-v) -// tmp.Neg() -// return &tmp -// } - -// panic("I256 IntFromBigint not implemented") -// } - -// func (x *Int) Bigint() bigint { -// if x.Signum() < 0 { -// return -x.Neg().v.Bigint() -// } -// return x.v.Bigint() - -// } - -func (x *Int) IsNeg() bool { - return x.Signum() < 0 -} - -func (x *Int) Add(y *Int, z *Int) *Int { - x.v.Add(&y.v, &z.v) - - ys := y.Signum() - zs := z.Signum() - - if ys > 0 && zs > 0 && x.Signum() < 0 { - panic("I256 Add overflow") - } - - if ys < 0 && zs < 0 && x.Signum() > 0 { - panic("I256 Add underflow") - } - - return x -} - -func (x *Int) Sub(y *Int, z *Int) *Int { - x.v.UnsafeSub(&y.v, &z.v) - - ys := y.Signum() - zs := z.Signum() - - if ys > 0 && zs < 0 && x.Signum() < 0 { - panic("I256 Sub overflow") - } - - if ys < 0 && zs > 0 && x.Signum() > 0 { - panic("I256 Sub underflow") - } - - return x -} - -func (x *Int) Mul(y *Int, z *Int) *Int { - x.v.Mul(&y.v, &z.v) - - ys := y.Signum() - zs := z.Signum() - - if ys > 0 && zs > 0 && x.Signum() < 0 { - panic("I256 Mul overflow #1") - } - - if ys < 0 && zs < 0 && x.Signum() < 0 { - panic("I256 Mul overflow #2") - } - - if ys > 0 && zs < 0 && x.Signum() > 0 { - panic("I256 Mul underflow #1") - } - - if ys < 0 && zs > 0 && x.Signum() > 0 { - panic("I256 Mul underflow #2") - } - - return x -} - -func (x *Int) Lsh(y *Int, n uint) *Int { - x.v.Lsh(&y.v, n) - return x -} - -func (x *Int) Rsh(y *Int, n uint) *Int { - x.v.Rsh(&y.v, n) - return x -} - -func (x *Int) Eq(y *Int) bool { - return x.v.Eq(&y.v) -} - -func (x *Int) IsZero() bool { - return x.v.IsZero() -} - -func (x *Int) Signum() int { - if x.v.arr[3] == 0 && x.v.arr[2] == 0 && x.v.arr[1] == 0 && x.v.arr[0] == 0 { - return 0 - } - if x.v.arr[3] < 0x8000000000000000 { - return 1 - } - return -1 -} - -func (x *Int) Gt(y *Int) bool { - xs := x.Signum() - ys := y.Signum() - - if xs != ys { - return xs > ys - } - if xs == 0 { - return false - } - if xs > 0 { - return x.v.Gt(&y.v) - } - return y.v.Gt(&x.v) -} - -func (x *Int) Lte(y *Int) bool { - return !x.Gt(y) -} - -func (x *Int) Gte(y *Int) bool { - xs := x.Signum() - ys := y.Signum() - - if xs != ys { - return xs > ys - } - if xs == 0 { - return true - } - if xs > 0 { - return x.v.Gte(&y.v) - } - return y.v.Gte(&x.v) -} - -func (x *Int) Int64() int64 { - // TODO: overflow check - if x.v.arr[3] < 0x8000000000000000 { - return int64(x.v.arr[0]) - } - // TODO: check if this is correct - return -int64(^x.v.arr[0] + 1) -} - -func (x *Int) Abs() *Uint { - if x.Signum() > 0 { - return &x.v - } - x1 := &Int{v: x.v} // so that we don't modify x - return &x1.Neg().v -} - -func (x *Int) Neg() *Int { - if x.Signum() == 0 { - return x - } - - // twos complement - x.v.Not(&x.v) - x.v.Add(&x.v, &Uint{arr: [4]uint64{1, 0, 0, 0}}) - return x -} - -func (x *Int) Dec() string { - if x.Signum() < 0 { - return "-" + x.Abs().Dec() - } - return x.Abs().Dec() -} - -func (x *Int) Uint() *Uint { - if x.Signum() < 0 { - // panic("I256 Uint negative") - return &x.Neg().v // r3v4_xxx: safe ?? - } - return &x.v -} - -func (z *Int) Or(x, y *Int) *Int { - z.v.Or(&x.v, &y.v) - return z -} - -func (z *Int) NilToZero() *Int { - if z == nil { - z = NewInt(0) - } - - return z -} - -// Clone creates a new Int identical to z -func (z *Int) Clone() *Int { - var x Int - - x.Sub(z, NewInt(0)) - return &x -} - -// // Clone creates a new Int identical to z -// func (z *Uint) Clone() *Uint { -// var x Uint -// x.arr[0] = z.arr[0] -// x.arr[1] = z.arr[1] -// x.arr[2] = z.arr[2] -// x.arr[3] = z.arr[3] - -// return &x -// } diff --git a/mod.go b/mod.go new file mode 100644 index 0000000..a6bd3c8 --- /dev/null +++ b/mod.go @@ -0,0 +1,605 @@ +package u256 + +import ( + "math/bits" +) + +// Some utility functions + +// Reciprocal computes a 320-bit value representing 1/m +// +// Notes: +// - specialized for m.arr[3] != 0, hence limited to 2^192 <= m < 2^256 +// - returns zero if m.arr[3] == 0 +// - starts with a 32-bit division, refines with newton-raphson iterations +func Reciprocal(m *Uint) (mu [5]uint64) { + if m.arr[3] == 0 { + return mu + } + + s := bits.LeadingZeros64(m.arr[3]) // Replace with leadingZeros(m) for general case + p := 255 - s // floor(log_2(m)), m>0 + + // 0 or a power of 2? + + // Check if at least one bit is set in m.arr[2], m.arr[1] or m.arr[0], + // or at least two bits in m.arr[3] + + if m.arr[0]|m.arr[1]|m.arr[2]|(m.arr[3]&(m.arr[3]-1)) == 0 { + + mu[4] = ^uint64(0) >> uint(p&63) + mu[3] = ^uint64(0) + mu[2] = ^uint64(0) + mu[1] = ^uint64(0) + mu[0] = ^uint64(0) + + return mu + } + + // Maximise division precision by left-aligning divisor + + var ( + y Uint // left-aligned copy of m + r0 uint32 // estimate of 2^31/y + ) + + y.Lsh(m, uint(s)) // 1/2 < y < 1 + + // Extract most significant 32 bits + + yh := uint32(y.arr[3] >> 32) + + if yh == 0x80000000 { // Avoid overflow in division + r0 = 0xffffffff + } else { + r0, _ = bits.Div32(0x80000000, 0, yh) + } + + // First iteration: 32 -> 64 + + t1 := uint64(r0) // 2^31/y + t1 *= t1 // 2^62/y^2 + t1, _ = bits.Mul64(t1, y.arr[3]) // 2^62/y^2 * 2^64/y / 2^64 = 2^62/y + + r1 := uint64(r0) << 32 // 2^63/y + r1 -= t1 // 2^63/y - 2^62/y = 2^62/y + r1 *= 2 // 2^63/y + + if (r1 | (y.arr[3] << 1)) == 0 { + r1 = ^uint64(0) + } + + // Second iteration: 64 -> 128 + + // square: 2^126/y^2 + a2h, a2l := bits.Mul64(r1, r1) + + // multiply by y: e2h:e2l:b2h = 2^126/y^2 * 2^128/y / 2^128 = 2^126/y + b2h, _ := bits.Mul64(a2l, y.arr[2]) + c2h, c2l := bits.Mul64(a2l, y.arr[3]) + d2h, d2l := bits.Mul64(a2h, y.arr[2]) + e2h, e2l := bits.Mul64(a2h, y.arr[3]) + + b2h, c := bits.Add64(b2h, c2l, 0) + e2l, c = bits.Add64(e2l, c2h, c) + e2h, _ = bits.Add64(e2h, 0, c) + + _, c = bits.Add64(b2h, d2l, 0) + e2l, c = bits.Add64(e2l, d2h, c) + e2h, _ = bits.Add64(e2h, 0, c) + + // subtract: t2h:t2l = 2^127/y - 2^126/y = 2^126/y + t2l, b := bits.Sub64(0, e2l, 0) + t2h, _ := bits.Sub64(r1, e2h, b) + + // double: r2h:r2l = 2^127/y + r2l, c := bits.Add64(t2l, t2l, 0) + r2h, _ := bits.Add64(t2h, t2h, c) + + if (r2h | r2l | (y.arr[3] << 1)) == 0 { + r2h = ^uint64(0) + r2l = ^uint64(0) + } + + // Third iteration: 128 -> 192 + + // square r2 (keep 256 bits): 2^190/y^2 + a3h, a3l := bits.Mul64(r2l, r2l) + b3h, b3l := bits.Mul64(r2l, r2h) + c3h, c3l := bits.Mul64(r2h, r2h) + + a3h, c = bits.Add64(a3h, b3l, 0) + c3l, c = bits.Add64(c3l, b3h, c) + c3h, _ = bits.Add64(c3h, 0, c) + + a3h, c = bits.Add64(a3h, b3l, 0) + c3l, c = bits.Add64(c3l, b3h, c) + c3h, _ = bits.Add64(c3h, 0, c) + + // multiply by y: q = 2^190/y^2 * 2^192/y / 2^192 = 2^190/y + + x0 := a3l + x1 := a3h + x2 := c3l + x3 := c3h + + var q0, q1, q2, q3, q4, t0 uint64 + + q0, _ = bits.Mul64(x2, y.arr[0]) + q1, t0 = bits.Mul64(x3, y.arr[0]) + q0, c = bits.Add64(q0, t0, 0) + q1, _ = bits.Add64(q1, 0, c) + + t1, _ = bits.Mul64(x1, y.arr[1]) + q0, c = bits.Add64(q0, t1, 0) + q2, t0 = bits.Mul64(x3, y.arr[1]) + q1, c = bits.Add64(q1, t0, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x2, y.arr[1]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[2]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q3, t0 = bits.Mul64(x3, y.arr[2]) + q2, c = bits.Add64(q2, t0, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, _ = bits.Mul64(x0, y.arr[2]) + q0, c = bits.Add64(q0, t1, 0) + t1, t0 = bits.Mul64(x2, y.arr[2]) + q1, c = bits.Add64(q1, t0, c) + q2, c = bits.Add64(q2, t1, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[3]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + q4, t0 = bits.Mul64(x3, y.arr[3]) + q3, c = bits.Add64(q3, t0, c) + q4, _ = bits.Add64(q4, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[3]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[3]) + q2, c = bits.Add64(q2, t0, c) + q3, c = bits.Add64(q3, t1, c) + q4, _ = bits.Add64(q4, 0, c) + + // subtract: t3 = 2^191/y - 2^190/y = 2^190/y + _, b = bits.Sub64(0, q0, 0) + _, b = bits.Sub64(0, q1, b) + t3l, b := bits.Sub64(0, q2, b) + t3m, b := bits.Sub64(r2l, q3, b) + t3h, _ := bits.Sub64(r2h, q4, b) + + // double: r3 = 2^191/y + r3l, c := bits.Add64(t3l, t3l, 0) + r3m, c := bits.Add64(t3m, t3m, c) + r3h, _ := bits.Add64(t3h, t3h, c) + + // Fourth iteration: 192 -> 320 + + // square r3 + + a4h, a4l := bits.Mul64(r3l, r3l) + b4h, b4l := bits.Mul64(r3l, r3m) + c4h, c4l := bits.Mul64(r3l, r3h) + d4h, d4l := bits.Mul64(r3m, r3m) + e4h, e4l := bits.Mul64(r3m, r3h) + f4h, f4l := bits.Mul64(r3h, r3h) + + b4h, c = bits.Add64(b4h, c4l, 0) + e4l, c = bits.Add64(e4l, c4h, c) + e4h, _ = bits.Add64(e4h, 0, c) + + a4h, c = bits.Add64(a4h, b4l, 0) + d4l, c = bits.Add64(d4l, b4h, c) + d4h, c = bits.Add64(d4h, e4l, c) + f4l, c = bits.Add64(f4l, e4h, c) + f4h, _ = bits.Add64(f4h, 0, c) + + a4h, c = bits.Add64(a4h, b4l, 0) + d4l, c = bits.Add64(d4l, b4h, c) + d4h, c = bits.Add64(d4h, e4l, c) + f4l, c = bits.Add64(f4l, e4h, c) + f4h, _ = bits.Add64(f4h, 0, c) + + // multiply by y + + x1, x0 = bits.Mul64(d4h, y.arr[0]) + x3, x2 = bits.Mul64(f4h, y.arr[0]) + t1, t0 = bits.Mul64(f4l, y.arr[0]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + x3, _ = bits.Add64(x3, 0, c) + + t1, t0 = bits.Mul64(d4h, y.arr[1]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + x4, t0 := bits.Mul64(f4h, y.arr[1]) + x3, c = bits.Add64(x3, t0, c) + x4, _ = bits.Add64(x4, 0, c) + t1, t0 = bits.Mul64(d4l, y.arr[1]) + x0, c = bits.Add64(x0, t0, 0) + x1, c = bits.Add64(x1, t1, c) + t1, t0 = bits.Mul64(f4l, y.arr[1]) + x2, c = bits.Add64(x2, t0, c) + x3, c = bits.Add64(x3, t1, c) + x4, _ = bits.Add64(x4, 0, c) + + t1, t0 = bits.Mul64(a4h, y.arr[2]) + x0, c = bits.Add64(x0, t0, 0) + x1, c = bits.Add64(x1, t1, c) + t1, t0 = bits.Mul64(d4h, y.arr[2]) + x2, c = bits.Add64(x2, t0, c) + x3, c = bits.Add64(x3, t1, c) + x5, t0 := bits.Mul64(f4h, y.arr[2]) + x4, c = bits.Add64(x4, t0, c) + x5, _ = bits.Add64(x5, 0, c) + t1, t0 = bits.Mul64(d4l, y.arr[2]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + t1, t0 = bits.Mul64(f4l, y.arr[2]) + x3, c = bits.Add64(x3, t0, c) + x4, c = bits.Add64(x4, t1, c) + x5, _ = bits.Add64(x5, 0, c) + + t1, t0 = bits.Mul64(a4h, y.arr[3]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + t1, t0 = bits.Mul64(d4h, y.arr[3]) + x3, c = bits.Add64(x3, t0, c) + x4, c = bits.Add64(x4, t1, c) + x6, t0 := bits.Mul64(f4h, y.arr[3]) + x5, c = bits.Add64(x5, t0, c) + x6, _ = bits.Add64(x6, 0, c) + t1, t0 = bits.Mul64(a4l, y.arr[3]) + x0, c = bits.Add64(x0, t0, 0) + x1, c = bits.Add64(x1, t1, c) + t1, t0 = bits.Mul64(d4l, y.arr[3]) + x2, c = bits.Add64(x2, t0, c) + x3, c = bits.Add64(x3, t1, c) + t1, t0 = bits.Mul64(f4l, y.arr[3]) + x4, c = bits.Add64(x4, t0, c) + x5, c = bits.Add64(x5, t1, c) + x6, _ = bits.Add64(x6, 0, c) + + // subtract + _, b = bits.Sub64(0, x0, 0) + _, b = bits.Sub64(0, x1, b) + r4l, b := bits.Sub64(0, x2, b) + r4k, b := bits.Sub64(0, x3, b) + r4j, b := bits.Sub64(r3l, x4, b) + r4i, b := bits.Sub64(r3m, x5, b) + r4h, _ := bits.Sub64(r3h, x6, b) + + // Multiply candidate for 1/4y by y, with full precision + + x0 = r4l + x1 = r4k + x2 = r4j + x3 = r4i + x4 = r4h + + q1, q0 = bits.Mul64(x0, y.arr[0]) + q3, q2 = bits.Mul64(x2, y.arr[0]) + q5, q4 := bits.Mul64(x4, y.arr[0]) + + t1, t0 = bits.Mul64(x1, y.arr[0]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[0]) + q3, c = bits.Add64(q3, t0, c) + q4, c = bits.Add64(q4, t1, c) + q5, _ = bits.Add64(q5, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[1]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[1]) + q3, c = bits.Add64(q3, t0, c) + q4, c = bits.Add64(q4, t1, c) + q6, t0 := bits.Mul64(x4, y.arr[1]) + q5, c = bits.Add64(q5, t0, c) + q6, _ = bits.Add64(q6, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[1]) + q2, c = bits.Add64(q2, t0, 0) + q3, c = bits.Add64(q3, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[1]) + q4, c = bits.Add64(q4, t0, c) + q5, c = bits.Add64(q5, t1, c) + q6, _ = bits.Add64(q6, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[2]) + q2, c = bits.Add64(q2, t0, 0) + q3, c = bits.Add64(q3, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[2]) + q4, c = bits.Add64(q4, t0, c) + q5, c = bits.Add64(q5, t1, c) + q7, t0 := bits.Mul64(x4, y.arr[2]) + q6, c = bits.Add64(q6, t0, c) + q7, _ = bits.Add64(q7, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[2]) + q3, c = bits.Add64(q3, t0, 0) + q4, c = bits.Add64(q4, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[2]) + q5, c = bits.Add64(q5, t0, c) + q6, c = bits.Add64(q6, t1, c) + q7, _ = bits.Add64(q7, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[3]) + q3, c = bits.Add64(q3, t0, 0) + q4, c = bits.Add64(q4, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[3]) + q5, c = bits.Add64(q5, t0, c) + q6, c = bits.Add64(q6, t1, c) + q8, t0 := bits.Mul64(x4, y.arr[3]) + q7, c = bits.Add64(q7, t0, c) + q8, _ = bits.Add64(q8, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[3]) + q4, c = bits.Add64(q4, t0, 0) + q5, c = bits.Add64(q5, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[3]) + q6, c = bits.Add64(q6, t0, c) + q7, c = bits.Add64(q7, t1, c) + q8, _ = bits.Add64(q8, 0, c) + + // Final adjustment + + // subtract q from 1/4 + _, b = bits.Sub64(0, q0, 0) + _, b = bits.Sub64(0, q1, b) + _, b = bits.Sub64(0, q2, b) + _, b = bits.Sub64(0, q3, b) + _, b = bits.Sub64(0, q4, b) + _, b = bits.Sub64(0, q5, b) + _, b = bits.Sub64(0, q6, b) + _, b = bits.Sub64(0, q7, b) + _, b = bits.Sub64(uint64(1)<<62, q8, b) + + // decrement the result + x0, t := bits.Sub64(r4l, 1, 0) + x1, t = bits.Sub64(r4k, 0, t) + x2, t = bits.Sub64(r4j, 0, t) + x3, t = bits.Sub64(r4i, 0, t) + x4, _ = bits.Sub64(r4h, 0, t) + + // commit the decrement if the subtraction underflowed (reciprocal was too large) + if b != 0 { + r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0 + } + + // Shift to correct bit alignment, truncating excess bits + + p = (p & 63) - 1 + + x0, c = bits.Add64(r4l, r4l, 0) + x1, c = bits.Add64(r4k, r4k, c) + x2, c = bits.Add64(r4j, r4j, c) + x3, c = bits.Add64(r4i, r4i, c) + x4, _ = bits.Add64(r4h, r4h, c) + + if p < 0 { + r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0 + p = 0 // avoid negative shift below + } + + { + r := uint(p) // right shift + l := uint(64 - r) // left shift + + x0 = (r4l >> r) | (r4k << l) + x1 = (r4k >> r) | (r4j << l) + x2 = (r4j >> r) | (r4i << l) + x3 = (r4i >> r) | (r4h << l) + x4 = (r4h >> r) + } + + if p > 0 { + r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0 + } + + mu[0] = r4l + mu[1] = r4k + mu[2] = r4j + mu[3] = r4i + mu[4] = r4h + + return mu +} + +// reduce4 computes the least non-negative residue of x modulo m +// +// requires a four-word modulus (m.arr[3] > 1) and its inverse (mu) +func reduce4(x [8]uint64, m *Uint, mu [5]uint64) (z Uint) { + // NB: Most variable names in the comments match the pseudocode for + // Barrett reduction in the Handbook of Applied Cryptography. + + // q1 = x/2^192 + + x0 := x[3] + x1 := x[4] + x2 := x[5] + x3 := x[6] + x4 := x[7] + + // q2 = q1 * mu; q3 = q2 / 2^320 + + var q0, q1, q2, q3, q4, q5, t0, t1, c uint64 + + q0, _ = bits.Mul64(x3, mu[0]) + q1, t0 = bits.Mul64(x4, mu[0]) + q0, c = bits.Add64(q0, t0, 0) + q1, _ = bits.Add64(q1, 0, c) + + t1, _ = bits.Mul64(x2, mu[1]) + q0, c = bits.Add64(q0, t1, 0) + q2, t0 = bits.Mul64(x4, mu[1]) + q1, c = bits.Add64(q1, t0, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x3, mu[1]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x2, mu[2]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q3, t0 = bits.Mul64(x4, mu[2]) + q2, c = bits.Add64(q2, t0, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, _ = bits.Mul64(x1, mu[2]) + q0, c = bits.Add64(q0, t1, 0) + t1, t0 = bits.Mul64(x3, mu[2]) + q1, c = bits.Add64(q1, t0, c) + q2, c = bits.Add64(q2, t1, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, _ = bits.Mul64(x0, mu[3]) + q0, c = bits.Add64(q0, t1, 0) + t1, t0 = bits.Mul64(x2, mu[3]) + q1, c = bits.Add64(q1, t0, c) + q2, c = bits.Add64(q2, t1, c) + q4, t0 = bits.Mul64(x4, mu[3]) + q3, c = bits.Add64(q3, t0, c) + q4, _ = bits.Add64(q4, 0, c) + + t1, t0 = bits.Mul64(x1, mu[3]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + t1, t0 = bits.Mul64(x3, mu[3]) + q2, c = bits.Add64(q2, t0, c) + q3, c = bits.Add64(q3, t1, c) + q4, _ = bits.Add64(q4, 0, c) + + t1, t0 = bits.Mul64(x0, mu[4]) + _, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + t1, t0 = bits.Mul64(x2, mu[4]) + q2, c = bits.Add64(q2, t0, c) + q3, c = bits.Add64(q3, t1, c) + q5, t0 = bits.Mul64(x4, mu[4]) + q4, c = bits.Add64(q4, t0, c) + q5, _ = bits.Add64(q5, 0, c) + + t1, t0 = bits.Mul64(x1, mu[4]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + t1, t0 = bits.Mul64(x3, mu[4]) + q3, c = bits.Add64(q3, t0, c) + q4, c = bits.Add64(q4, t1, c) + q5, _ = bits.Add64(q5, 0, c) + + // Drop the fractional part of q3 + + q0 = q1 + q1 = q2 + q2 = q3 + q3 = q4 + q4 = q5 + + // r1 = x mod 2^320 + + x0 = x[0] + x1 = x[1] + x2 = x[2] + x3 = x[3] + x4 = x[4] + + // r2 = q3 * m mod 2^320 + + var r0, r1, r2, r3, r4 uint64 + + r4, r3 = bits.Mul64(q0, m.arr[3]) + _, t0 = bits.Mul64(q1, m.arr[3]) + r4, _ = bits.Add64(r4, t0, 0) + + t1, r2 = bits.Mul64(q0, m.arr[2]) + r3, c = bits.Add64(r3, t1, 0) + _, t0 = bits.Mul64(q2, m.arr[2]) + r4, _ = bits.Add64(r4, t0, c) + + t1, t0 = bits.Mul64(q1, m.arr[2]) + r3, c = bits.Add64(r3, t0, 0) + r4, _ = bits.Add64(r4, t1, c) + + t1, r1 = bits.Mul64(q0, m.arr[1]) + r2, c = bits.Add64(r2, t1, 0) + t1, t0 = bits.Mul64(q2, m.arr[1]) + r3, c = bits.Add64(r3, t0, c) + r4, _ = bits.Add64(r4, t1, c) + + t1, t0 = bits.Mul64(q1, m.arr[1]) + r2, c = bits.Add64(r2, t0, 0) + r3, c = bits.Add64(r3, t1, c) + _, t0 = bits.Mul64(q3, m.arr[1]) + r4, _ = bits.Add64(r4, t0, c) + + t1, r0 = bits.Mul64(q0, m.arr[0]) + r1, c = bits.Add64(r1, t1, 0) + t1, t0 = bits.Mul64(q2, m.arr[0]) + r2, c = bits.Add64(r2, t0, c) + r3, c = bits.Add64(r3, t1, c) + _, t0 = bits.Mul64(q4, m.arr[0]) + r4, _ = bits.Add64(r4, t0, c) + + t1, t0 = bits.Mul64(q1, m.arr[0]) + r1, c = bits.Add64(r1, t0, 0) + r2, c = bits.Add64(r2, t1, c) + t1, t0 = bits.Mul64(q3, m.arr[0]) + r3, c = bits.Add64(r3, t0, c) + r4, _ = bits.Add64(r4, t1, c) + + // r = r1 - r2 + + var b uint64 + + r0, b = bits.Sub64(x0, r0, 0) + r1, b = bits.Sub64(x1, r1, b) + r2, b = bits.Sub64(x2, r2, b) + r3, b = bits.Sub64(x3, r3, b) + r4, b = bits.Sub64(x4, r4, b) + + // if r<0 then r+=m + + if b != 0 { + r0, c = bits.Add64(r0, m.arr[0], 0) + r1, c = bits.Add64(r1, m.arr[1], c) + r2, c = bits.Add64(r2, m.arr[2], c) + r3, c = bits.Add64(r3, m.arr[3], c) + r4, _ = bits.Add64(r4, 0, c) + } + + // while (r>=m) r-=m + + for { + // q = r - m + q0, b = bits.Sub64(r0, m.arr[0], 0) + q1, b = bits.Sub64(r1, m.arr[1], b) + q2, b = bits.Sub64(r2, m.arr[2], b) + q3, b = bits.Sub64(r3, m.arr[3], b) + q4, b = bits.Sub64(r4, 0, b) + + // if borrow break + if b != 0 { + break + } + + // r = q + r4, r3, r2, r1, r0 = q4, q3, q2, q1, q0 + } + + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = r3, r2, r1, r0 + + return z +} diff --git a/u256.go b/u256.go index ad048b2..f33ccb4 100644 --- a/u256.go +++ b/u256.go @@ -1,30 +1,16 @@ -// Ported from https://github.com/holiman/uint256/ - +// Ported from https://github.com/holiman/uint256 +// This package provides a 256-bit unsigned integer type, Uint256, and associated functions. package u256 import ( "errors" + "math/bits" ) -const MaxUint64 = 1<<64 - 1 - -// TODO: remove -// const MaxUint256 bigint = 115792089237316195423570985008687907853269984665640564039457584007913129639935 - -func Zero() *Uint { - return NewUint(0) -} - -func One() *Uint { - return NewUint(1) -} - -func (x *Uint) Min(y *Uint) *Uint { - if x.Lt(y) { - return x - } - return y -} +const ( + MaxUint64 = 1<<64 - 1 + uintSize = 32 << (^uint(0) >> 63) +) // Uint is represented as an array of 4 uint64, in little-endian order, // so that Uint[3] is the most significant, and Uint[0] is the least significant @@ -32,80 +18,20 @@ type Uint struct { arr [4]uint64 } -func (x *Uint) Int() *Int { - // panic if x > MaxInt64 - if x.arr[3] > 0x7fffffffffffffff { - panic("U256 Int overflow") - } - - return &Int{v: *x} -} - -// // TODO: to be removed -// func FromBigint(x bigint) *Uint { -// if x > MaxUint256 { -// panic("U256 FromBigint overflow") -// } - -// if x < 0 { -// panic("U256 FromBigint underflow") -// } - -// var z Uint -// z.arr[0] = uint64(x % (1 << 64)) -// z.arr[1] = uint64((x >> 64) % (1 << 64)) -// z.arr[2] = uint64((x >> 128) % (1 << 64)) -// z.arr[3] = uint64((x >> 192) % (1 << 64)) - -// return &z -// } - -// func (x *Uint) Bigint() bigint { -// return (bigint(x.arr[0]) + -// bigint(x.arr[1])*(1<<64) + -// bigint(x.arr[2])*(1<<128) + -// bigint(x.arr[3])*(1<<192)) -// } - // NewUint returns a new initialized Uint. func NewUint(val uint64) *Uint { z := &Uint{arr: [4]uint64{val, 0, 0, 0}} return z } -// Uint64 returns the lower 64-bits of z -func (z *Uint) Uint64() uint64 { - return z.arr[0] -} - -func (z *Uint) Int32() int32 { - x := z.arr[0] - if x > 0x7fffffff { - panic("U256 Int32 overflow") - } - return int32(x) -} - -// SetUint64 sets z to the value x -func (z *Uint) SetUint64(x uint64) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, x - return z -} - -// IsUint64 reports whether z can be represented as a uint64. -func (z *Uint) IsUint64() bool { - return (z.arr[1] | z.arr[2] | z.arr[3]) == 0 -} - -// IsZero returns true if z == 0 -func (z *Uint) IsZero() bool { - return (z.arr[0] | z.arr[1] | z.arr[2] | z.arr[3]) == 0 +// Zero returns a new Uint initialized to zero. +func Zero() *Uint { + return NewUint(0) } -// Clear sets z to 0 -func (z *Uint) Clear() *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, 0 - return z +// One returns a new Uint initialized to one. +func One() *Uint { + return NewUint(1) } // SetAllOne sets all the bits of z to 1 @@ -114,60 +40,6 @@ func (z *Uint) SetAllOne() *Uint { return z } -// Not sets z = ^x and returns z. -func (z *Uint) Not(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = ^x.arr[3], ^x.arr[2], ^x.arr[1], ^x.arr[0] - return z -} - -// Gt returns true if z > x -func (z *Uint) Gt(x *Uint) bool { - return x.Lt(z) -} - -func (z *Uint) Gte(x *Uint) bool { - return !z.Lt(x) -} - -func (z *Uint) Lte(x *Uint) bool { - return !x.Gt(z) -} - -// Lt returns true if z < x -func (z *Uint) Lt(x *Uint) bool { - // z < x <=> z - x < 0 i.e. when subtraction overflows. - _, carry := Sub64(z.arr[0], x.arr[0], 0) - _, carry = Sub64(z.arr[1], x.arr[1], carry) - _, carry = Sub64(z.arr[2], x.arr[2], carry) - _, carry = Sub64(z.arr[3], x.arr[3], carry) - return carry != 0 -} - -// Eq returns true if z == x -func (z *Uint) Eq(x *Uint) bool { - return (z.arr[0] == x.arr[0]) && (z.arr[1] == x.arr[1]) && (z.arr[2] == x.arr[2]) && (z.arr[3] == x.arr[3]) -} - -// Cmp compares z and x and returns: -// -// -1 if z < x -// 0 if z == x -// +1 if z > x -func (z *Uint) Cmp(x *Uint) (r int) { - // z < x <=> z - x < 0 i.e. when subtraction overflows. - d0, carry := Sub64(z.arr[0], x.arr[0], 0) - d1, carry := Sub64(z.arr[1], x.arr[1], carry) - d2, carry := Sub64(z.arr[2], x.arr[2], carry) - d3, carry := Sub64(z.arr[3], x.arr[3], carry) - if carry == 1 { - return -1 - } - if d0|d1|d2|d3 == 0 { - return 0 - } - return 1 -} - // Set sets z to x and returns z. func (z *Uint) Set(x *Uint) *Uint { *z = *x @@ -181,677 +53,6 @@ func (z *Uint) SetOne() *Uint { return z } -func (z *Uint) AddInt(x *Uint, y *Int) *Uint { - if y.IsNeg() { - return z.Sub(x, y.Abs()) - } - return z.Add(x, y.Uint()) -} - -// Add sets z to the sum x+y -func (z *Uint) Add(x, y *Uint) *Uint { - var carry uint64 - z.arr[0], carry = Add64(x.arr[0], y.arr[0], 0) - z.arr[1], carry = Add64(x.arr[1], y.arr[1], carry) - z.arr[2], carry = Add64(x.arr[2], y.arr[2], carry) - z.arr[3], _ = Add64(x.arr[3], y.arr[3], carry) - // Different from the original implementation! - // We panic on overflow - if carry != 0 { - panic("U256 Add overflow") - } - return z -} - -// AddOverflow sets z to the sum x+y, and returns z and whether overflow occurred -func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) { - var carry uint64 - z.arr[0], carry = Add64(x.arr[0], y.arr[0], 0) - z.arr[1], carry = Add64(x.arr[1], y.arr[1], carry) - z.arr[2], carry = Add64(x.arr[2], y.arr[2], carry) - z.arr[3], carry = Add64(x.arr[3], y.arr[3], carry) - return z, carry != 0 -} - -// SubOverflow sets z to the difference x-y and returns z and true if the operation underflowed -func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) { - var carry uint64 - z.arr[0], carry = Sub64(x.arr[0], y.arr[0], 0) - z.arr[1], carry = Sub64(x.arr[1], y.arr[1], carry) - z.arr[2], carry = Sub64(x.arr[2], y.arr[2], carry) - z.arr[3], carry = Sub64(x.arr[3], y.arr[3], carry) - return z, carry != 0 -} - -// Sub sets z to the difference x-y -func (z *Uint) Sub(x, y *Uint) *Uint { - var carry uint64 - z.arr[0], carry = Sub64(x.arr[0], y.arr[0], 0) - z.arr[1], carry = Sub64(x.arr[1], y.arr[1], carry) - z.arr[2], carry = Sub64(x.arr[2], y.arr[2], carry) - z.arr[3], _ = Sub64(x.arr[3], y.arr[3], carry) - - // Different from the original implementation! - // We panic on underflow - // r3v4 -> mconcat : why do we panic? - if carry != 0 { - panic("U256 Sub underflow") - } - return z -} - -// Sub sets z to the difference x-y -func (z *Uint) UnsafeSub(x, y *Uint) *Uint { - var carry uint64 - z.arr[0], carry = Sub64(x.arr[0], y.arr[0], 0) - z.arr[1], carry = Sub64(x.arr[1], y.arr[1], carry) - z.arr[2], carry = Sub64(x.arr[2], y.arr[2], carry) - z.arr[3], _ = Sub64(x.arr[3], y.arr[3], carry) - - return z -} - -// commented out for possible overflow -// Mul sets z to the product x*y -func (z *Uint) Mul(x, y *Uint) *Uint { - var ( - res Uint - carry uint64 - res1, res2, res3 uint64 - ) - - carry, res.arr[0] = Mul64(x.arr[0], y.arr[0]) - carry, res1 = umulHop(carry, x.arr[1], y.arr[0]) - carry, res2 = umulHop(carry, x.arr[2], y.arr[0]) - res3 = x.arr[3]*y.arr[0] + carry - - carry, res.arr[1] = umulHop(res1, x.arr[0], y.arr[1]) - carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry) - res3 = res3 + x.arr[2]*y.arr[1] + carry - - carry, res.arr[2] = umulHop(res2, x.arr[0], y.arr[2]) - res3 = res3 + x.arr[1]*y.arr[2] + carry - - res.arr[3] = res3 + x.arr[0]*y.arr[3] - - return z.Set(&res) -} - -// MulOverflow sets z to the product x*y, and returns z and whether overflow occurred -func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) { - p := umul(x, y) - copy(z.arr[:], p[:4]) - return z, (p[4] | p[5] | p[6] | p[7]) != 0 -} - -// umulStep computes (hi * 2^64 + lo) = z + (x * y) + carry. -func umulStep(z, x, y, carry uint64) (hi, lo uint64) { - hi, lo = Mul64(x, y) - lo, carry = Add64(lo, carry, 0) - hi, _ = Add64(hi, 0, carry) - lo, carry = Add64(lo, z, 0) - hi, _ = Add64(hi, 0, carry) - return hi, lo -} - -// umulHop computes (hi * 2^64 + lo) = z + (x * y) -func umulHop(z, x, y uint64) (hi, lo uint64) { - hi, lo = Mul64(x, y) - lo, carry := Add64(lo, z, 0) - hi, _ = Add64(hi, 0, carry) - return hi, lo -} - -// umul computes full 256 x 256 -> 512 multiplication. -func umul(x, y *Uint) [8]uint64 { - var ( - res [8]uint64 - carry, carry4, carry5, carry6 uint64 - res1, res2, res3, res4, res5 uint64 - ) - - carry, res[0] = Mul64(x.arr[0], y.arr[0]) - carry, res1 = umulHop(carry, x.arr[1], y.arr[0]) - carry, res2 = umulHop(carry, x.arr[2], y.arr[0]) - carry4, res3 = umulHop(carry, x.arr[3], y.arr[0]) - - carry, res[1] = umulHop(res1, x.arr[0], y.arr[1]) - carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry) - carry, res3 = umulStep(res3, x.arr[2], y.arr[1], carry) - carry5, res4 = umulStep(carry4, x.arr[3], y.arr[1], carry) - - carry, res[2] = umulHop(res2, x.arr[0], y.arr[2]) - carry, res3 = umulStep(res3, x.arr[1], y.arr[2], carry) - carry, res4 = umulStep(res4, x.arr[2], y.arr[2], carry) - carry6, res5 = umulStep(carry5, x.arr[3], y.arr[2], carry) - - carry, res[3] = umulHop(res3, x.arr[0], y.arr[3]) - carry, res[4] = umulStep(res4, x.arr[1], y.arr[3], carry) - carry, res[5] = umulStep(res5, x.arr[2], y.arr[3], carry) - res[7], res[6] = umulStep(carry6, x.arr[3], y.arr[3], carry) - - return res -} - -// commented out for possible overflow -// Div sets z to the quotient x/y for returns z. -// If y == 0, z is set to 0 -func (z *Uint) Div(x, y *Uint) *Uint { - if y.IsZero() || y.Gt(x) { - return z.Clear() - } - if x.Eq(y) { - return z.SetOne() - } - // Shortcut some cases - if x.IsUint64() { - return z.SetUint64(x.Uint64() / y.Uint64()) - } - - // At this point, we know - // x/y ; x > y > 0 - - var quot Uint - udivrem(quot.arr[:], x.arr[:], y) - return z.Set(") -} - -// udivrem divides u by d and produces both quotient and remainder. -// The quotient is stored in provided quot - len(u)-len(d)+1 words. -// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words. -// See Knuth, Volume 2, section 4.3.1, Algorithm D. -func udivrem(quot, u []uint64, d *Uint) (rem Uint) { - var dLen int - for i := len(d.arr) - 1; i >= 0; i-- { - if d.arr[i] != 0 { - dLen = i + 1 - break - } - } - - shift := uint(LeadingZeros64(d.arr[dLen-1])) - - var dnStorage Uint - dn := dnStorage.arr[:dLen] - for i := dLen - 1; i > 0; i-- { - dn[i] = (d.arr[i] << shift) | (d.arr[i-1] >> (64 - shift)) - } - dn[0] = d.arr[0] << shift - - var uLen int - for i := len(u) - 1; i >= 0; i-- { - if u[i] != 0 { - uLen = i + 1 - break - } - } - - if uLen < dLen { - copy(rem.arr[:], u) - return rem - } - - var unStorage [9]uint64 - un := unStorage[:uLen+1] - un[uLen] = u[uLen-1] >> (64 - shift) - for i := uLen - 1; i > 0; i-- { - un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift)) - } - un[0] = u[0] << shift - - // TODO: Skip the highest word of numerator if not significant. - - if dLen == 1 { - r := udivremBy1(quot, un, dn[0]) - rem.SetUint64(r >> shift) - return rem - } - - udivremKnuth(quot, un, dn) - - for i := 0; i < dLen-1; i++ { - rem.arr[i] = (un[i] >> shift) | (un[i+1] << (64 - shift)) - } - rem.arr[dLen-1] = un[dLen-1] >> shift - - return rem -} - -// udivremBy1 divides u by single normalized word d and produces both quotient and remainder. -// The quotient is stored in provided quot. -func udivremBy1(quot, u []uint64, d uint64) (rem uint64) { - reciprocal := reciprocal2by1(d) - rem = u[len(u)-1] // Set the top word as remainder. - for j := len(u) - 2; j >= 0; j-- { - quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal) - } - return rem -} - -// udivremKnuth implements the division of u by normalized multiple word d from the Knuth's division algorithm. -// The quotient is stored in provided quot - len(u)-len(d) words. -// Updates u to contain the remainder - len(d) words. -func udivremKnuth(quot, u, d []uint64) { - dh := d[len(d)-1] - dl := d[len(d)-2] - reciprocal := reciprocal2by1(dh) - - for j := len(u) - len(d) - 1; j >= 0; j-- { - u2 := u[j+len(d)] - u1 := u[j+len(d)-1] - u0 := u[j+len(d)-2] - - var qhat, rhat uint64 - if u2 >= dh { // Division overflows. - qhat = ^uint64(0) - // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). - } else { - qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal) - ph, pl := Mul64(qhat, dl) - if ph > rhat || (ph == rhat && pl > u0) { - qhat-- - // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). - } - } - - // Multiply and subtract. - borrow := subMulTo(u[j:], d, qhat) - u[j+len(d)] = u2 - borrow - if u2 < borrow { // Too much subtracted, add back. - qhat-- - u[j+len(d)] += addTo(u[j:], d) - } - - quot[j] = qhat // Store quotient digit. - } -} - -// isBitSet returns true if bit n-th is set, where n = 0 is LSB. -// The n must be <= 255. -func (z *Uint) isBitSet(n uint) bool { - return (z.arr[n/64] & (1 << (n % 64))) != 0 -} - -// addTo computes x += y. -// Requires len(x) >= len(y). -func addTo(x, y []uint64) uint64 { - var carry uint64 - for i := 0; i < len(y); i++ { - x[i], carry = Add64(x[i], y[i], carry) - } - return carry -} - -// subMulTo computes x -= y * multiplier. -// Requires len(x) >= len(y). -func subMulTo(x, y []uint64, multiplier uint64) uint64 { - var borrow uint64 - for i := 0; i < len(y); i++ { - s, carry1 := Sub64(x[i], borrow, 0) - ph, pl := Mul64(y[i], multiplier) - t, carry2 := Sub64(s, pl, 0) - x[i] = t - borrow = ph + carry1 + carry2 - } - return borrow -} - -// reciprocal2by1 computes <^d, ^0> / d. -func reciprocal2by1(d uint64) uint64 { - reciprocal, _ := Div64(^d, ^uint64(0), d) - return reciprocal -} - -// udivrem2by1 divides / d and produces both quotient and remainder. -// It uses the provided d's reciprocal. -// Implementation ported from https://github.com/chfast/intx and is based on -// "Improved division by invariant integers", Algorithm 4. -func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) { - qh, ql := Mul64(reciprocal, uh) - ql, carry := Add64(ql, ul, 0) - qh, _ = Add64(qh, uh, carry) - qh++ - - r := ul - qh*d - - if r > ql { - qh-- - r += d - } - - if r >= d { - qh++ - r -= d - } - - return qh, r -} - -// Lsh sets z = x << n and returns z. -func (z *Uint) Lsh(x *Uint, n uint) *Uint { - // n % 64 == 0 - if n&0x3f == 0 { - switch n { - case 0: - return z.Set(x) - case 64: - return z.lsh64(x) - case 128: - return z.lsh128(x) - case 192: - return z.lsh192(x) - default: - return z.Clear() - } - } - var ( - a, b uint64 - ) - // Big swaps first - switch { - case n > 192: - if n > 256 { - return z.Clear() - } - z.lsh192(x) - n -= 192 - goto sh192 - case n > 128: - z.lsh128(x) - n -= 128 - goto sh128 - case n > 64: - z.lsh64(x) - n -= 64 - goto sh64 - default: - z.Set(x) - } - - // remaining shifts - a = z.arr[0] >> (64 - n) - z.arr[0] = z.arr[0] << n - -sh64: - b = z.arr[1] >> (64 - n) - z.arr[1] = (z.arr[1] << n) | a - -sh128: - a = z.arr[2] >> (64 - n) - z.arr[2] = (z.arr[2] << n) | b - -sh192: - z.arr[3] = (z.arr[3] << n) | a - - return z -} - -// Rsh sets z = x >> n and returns z. -func (z *Uint) Rsh(x *Uint, n uint) *Uint { - // n % 64 == 0 - if n&0x3f == 0 { - switch n { - case 0: - return z.Set(x) - case 64: - return z.rsh64(x) - case 128: - return z.rsh128(x) - case 192: - return z.rsh192(x) - default: - return z.Clear() - } - } - var ( - a, b uint64 - ) - // Big swaps first - switch { - case n > 192: - if n > 256 { - return z.Clear() - } - z.rsh192(x) - n -= 192 - goto sh192 - case n > 128: - z.rsh128(x) - n -= 128 - goto sh128 - case n > 64: - z.rsh64(x) - n -= 64 - goto sh64 - default: - z.Set(x) - } - - // remaining shifts - a = z.arr[3] << (64 - n) - z.arr[3] = z.arr[3] >> n - -sh64: - b = z.arr[2] << (64 - n) - z.arr[2] = (z.arr[2] >> n) | a - -sh128: - a = z.arr[1] << (64 - n) - z.arr[1] = (z.arr[1] >> n) | b - -sh192: - z.arr[0] = (z.arr[0] >> n) | a - - return z -} - -// SRsh (Signed/Arithmetic right shift) -// considers z to be a signed integer, during right-shift -// and sets z = x >> n and returns z. -func (z *Uint) SRsh(x *Uint, n uint) *Uint { - // If the MSB is 0, SRsh is same as Rsh. - if !x.isBitSet(255) { - return z.Rsh(x, n) - } - if n%64 == 0 { - switch n { - case 0: - return z.Set(x) - case 64: - return z.srsh64(x) - case 128: - return z.srsh128(x) - case 192: - return z.srsh192(x) - default: - return z.SetAllOne() - } - } - var ( - a uint64 = MaxUint64 << (64 - n%64) - ) - // Big swaps first - switch { - case n > 192: - if n > 256 { - return z.SetAllOne() - } - z.srsh192(x) - n -= 192 - goto sh192 - case n > 128: - z.srsh128(x) - n -= 128 - goto sh128 - case n > 64: - z.srsh64(x) - n -= 64 - goto sh64 - default: - z.Set(x) - } - - // remaining shifts - z.arr[3], a = (z.arr[3]>>n)|a, z.arr[3]<<(64-n) - -sh64: - z.arr[2], a = (z.arr[2]>>n)|a, z.arr[2]<<(64-n) - -sh128: - z.arr[1], a = (z.arr[1]>>n)|a, z.arr[1]<<(64-n) - -sh192: - z.arr[0] = (z.arr[0] >> n) | a - - return z -} - -func (z *Uint) lsh64(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[2], x.arr[1], x.arr[0], 0 - return z -} -func (z *Uint) lsh128(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[1], x.arr[0], 0, 0 - return z -} -func (z *Uint) lsh192(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[0], 0, 0, 0 - return z -} -func (z *Uint) rsh64(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, x.arr[3], x.arr[2], x.arr[1] - return z -} -func (z *Uint) rsh128(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, x.arr[3], x.arr[2] - return z -} -func (z *Uint) rsh192(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, x.arr[3] - return z -} -func (z *Uint) srsh64(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, x.arr[3], x.arr[2], x.arr[1] - return z -} -func (z *Uint) srsh128(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, MaxUint64, x.arr[3], x.arr[2] - return z -} -func (z *Uint) srsh192(x *Uint) *Uint { - z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, MaxUint64, MaxUint64, x.arr[3] - return z -} - -// Or sets z = x | y and returns z. -func (z *Uint) Or(x, y *Uint) *Uint { - z.arr[0] = x.arr[0] | y.arr[0] - z.arr[1] = x.arr[1] | y.arr[1] - z.arr[2] = x.arr[2] | y.arr[2] - z.arr[3] = x.arr[3] | y.arr[3] - return z -} - -// And sets z = x & y and returns z. -func (z *Uint) And(x, y *Uint) *Uint { - z.arr[0] = x.arr[0] & y.arr[0] - z.arr[1] = x.arr[1] & y.arr[1] - z.arr[2] = x.arr[2] & y.arr[2] - z.arr[3] = x.arr[3] & y.arr[3] - return z -} - -// Xor sets z = x ^ y and returns z. -func (z *Uint) Xor(x, y *Uint) *Uint { - z.arr[0] = x.arr[0] ^ y.arr[0] - z.arr[1] = x.arr[1] ^ y.arr[1] - z.arr[2] = x.arr[2] ^ y.arr[2] - z.arr[3] = x.arr[3] ^ y.arr[3] - return z -} - -// MarshalJSON implements json.Marshaler. -// MarshalJSON marshals using the 'decimal string' representation. This is _not_ compatible -// with big.Uint: big.Uint marshals into JSON 'native' numeric format. -// -// The JSON native format is, on some platforms, (e.g. javascript), limited to 53-bit large -// integer space. Thus, U256 uses string-format, which is not compatible with -// big.int (big.Uint refuses to unmarshal a string representation). -func (z *Uint) MarshalJSON() ([]byte, error) { - return []byte(`"` + z.Dec() + `"`), nil -} - -// UnmarshalJSON implements json.Unmarshaler. UnmarshalJSON accepts either -// - Quoted string: either hexadecimal OR decimal -// - Not quoted string: only decimal -func (z *Uint) UnmarshalJSON(input []byte) error { - if len(input) < 2 || input[0] != '"' || input[len(input)-1] != '"' { - // if not quoted, it must be decimal - return z.fromDecimal(string(input)) - } - return z.UnmarshalText(input[1 : len(input)-1]) -} - -// MarshalText implements encoding.TextMarshaler -// MarshalText marshals using the decimal representation (compatible with big.Uint) -func (z *Uint) MarshalText() ([]byte, error) { - return []byte(z.Dec()), nil -} - -// UnmarshalText implements encoding.TextUnmarshaler. This method -// can unmarshal either hexadecimal or decimal. -// - For hexadecimal, the input _must_ be prefixed with 0x or 0X -func (z *Uint) UnmarshalText(input []byte) error { - if len(input) >= 2 && input[0] == '0' && (input[1] == 'x' || input[1] == 'X') { - return z.fromHex(string(input)) - } - return z.fromDecimal(string(input)) -} - -const ( - hextable = "0123456789abcdef" - bintable = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x01\x02\x03\x04\x05\x06\a\b\t\xff\xff\xff\xff\xff\xff\xff\n\v\f\r\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\n\v\f\r\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" - badNibble = 0xff -) - -// fromHex is the internal implementation of parsing a hex-string. -func (z *Uint) fromHex(hex string) error { - if err := checkNumberS(hex); err != nil { - return err - } - if len(hex) > 66 { - return ErrBig256Range - } - z.Clear() - end := len(hex) - for i := 0; i < 4; i++ { - start := end - 16 - if start < 2 { - start = 2 - } - for ri := start; ri < end; ri++ { - nib := bintable[hex[ri]] - if nib == badNibble { - return ErrSyntax - } - z.arr[i] = z.arr[i] << 4 - z.arr[i] += uint64(nib) - } - end = start - } - return nil -} - -// FromDecimal is a convenience-constructor to create an Uint from a -// decimal (base 10) string. Numbers larger than 256 bits are not accepted. -func FromDecimal(decimal string) *Uint { - var z Uint - if err := z.SetFromDecimal(decimal); err != nil { - panic(err.Error()) - } - return &z -} - const twoPow256Sub1 = "115792089237316195423570985008687907853269984665640564039457584007913129639935" // SetFromDecimal sets z from the given string, interpreted as a decimal number. @@ -888,42 +89,34 @@ func (z *Uint) SetFromDecimal(s string) (err error) { return ErrBig256Range } -var ( - ErrEmptyString = errors.New("empty hex string") - ErrSyntax = errors.New("invalid hex string") - ErrMissingPrefix = errors.New("hex string without 0x prefix") - ErrEmptyNumber = errors.New("hex string \"0x\"") - ErrLeadingZero = errors.New("hex number with leading zero digits") - ErrBig256Range = errors.New("hex number > 256 bits") - ErrBadBufferLength = errors.New("bad ssz buffer length") - ErrBadEncodedLength = errors.New("bad ssz encoded length") -) - -func checkNumberS(input string) error { - l := len(input) - if l == 0 { - return ErrEmptyString - } - if l < 2 || input[0] != '0' || - (input[1] != 'x' && input[1] != 'X') { - return ErrMissingPrefix - } - if l == 2 { - return ErrEmptyNumber +// FromDecimal is a convenience-constructor to create an Uint from a +// decimal (base 10) string. Numbers larger than 256 bits are not accepted. +func FromDecimal(decimal string) (*Uint, error) { + var z Uint + if err := z.SetFromDecimal(decimal); err != nil { + return nil, err } - if len(input) > 3 && input[2] == '0' { - return ErrLeadingZero + return &z, nil +} + +// MustFromDecimal is a convenience-constructor to create an Uint from a +// decimal (base 10) string. +// Returns a new Uint and panics if any error occurred. +func MustFromDecimal(decimal string) *Uint { + var z Uint + if err := z.SetFromDecimal(decimal); err != nil { + panic(err) } - return nil + return &z } // multipliers holds the values that are needed for fromDecimal var multipliers = [5]*Uint{ nil, // represents first round, no multiplication needed - &Uint{[4]uint64{10000000000000000000, 0, 0, 0}}, // 10 ^ 19 - &Uint{[4]uint64{687399551400673280, 5421010862427522170, 0, 0}}, // 10 ^ 38 - &Uint{[4]uint64{5332261958806667264, 17004971331911604867, 2938735877055718769, 0}}, // 10 ^ 57 - &Uint{[4]uint64{0, 8607968719199866880, 532749306367912313, 1593091911132452277}}, // 10 ^ 76 + {[4]uint64{10000000000000000000, 0, 0, 0}}, // 10 ^ 19 + {[4]uint64{687399551400673280, 5421010862427522170, 0, 0}}, // 10 ^ 38 + {[4]uint64{5332261958806667264, 17004971331911604867, 2938735877055718769, 0}}, // 10 ^ 57 + {[4]uint64{0, 8607968719199866880, 532749306367912313, 1593091911132452277}}, // 10 ^ 76 } // fromDecimal is a helper function to only ever be called via SetFromDecimal @@ -974,250 +167,119 @@ func (z *Uint) fromDecimal(bs string) error { return nil } -// lower(c) is a lower-case letter if and only if -// c is either that lower-case letter or the equivalent upper-case letter. -// Instead of writing c == 'x' || c == 'X' one can write lower(c) == 'x'. -// Note that lower of non-letters can produce other non-letters. -func lower(c byte) byte { - return c | ('x' - 'X') -} - -// ParseUint is like ParseUint but for unsigned numbers. -// -// A sign prefix is not permitted. -func parseUint(s string, base int, bitSize int) (uint64, error) { - const fnParseUint = "ParseUint" - - if s == "" { - return 0, errors.New("syntax error: ParseUint empty string") - } - - base0 := base == 0 - - s0 := s - switch { - case 2 <= base && base <= 36: - // valid base; nothing to do - - case base == 0: - // Look for octal, hex prefix. - base = 10 - if s[0] == '0' { - switch { - case len(s) >= 3 && lower(s[1]) == 'b': - base = 2 - s = s[2:] - case len(s) >= 3 && lower(s[1]) == 'o': - base = 8 - s = s[2:] - case len(s) >= 3 && lower(s[1]) == 'x': - base = 16 - s = s[2:] - default: - base = 8 - s = s[1:] - } +// Byte sets z to the value of the byte at position n, +// with 'z' considered as a big-endian 32-byte integer +// if 'n' > 32, f is set to 0 +// Example: f = '5', n=31 => 5 +func (z *Uint) Byte(n *Uint) *Uint { + // in z, z.arr[0] is the least significant + if number, overflow := n.Uint64WithOverflow(); !overflow { + if number < 32 { + number := z.arr[4-1-number/8] + offset := (n.arr[0] & 0x7) << 3 // 8*(n.d % 8) + z.arr[0] = (number & (0xff00000000000000 >> offset)) >> (56 - offset) + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + return z } - - default: - return 0, errors.New("invalid base") } - if bitSize == 0 { - bitSize = UintSize - } else if bitSize < 0 || bitSize > 64 { - return 0, errors.New("invalid bit size") - } + return z.Clear() +} - // Cutoff is the smallest number such that cutoff*base > maxUint64. - // Use compile-time constants for common cases. - var cutoff uint64 - switch base { - case 10: - cutoff = MaxUint64/10 + 1 - case 16: - cutoff = MaxUint64/16 + 1 +// BitLen returns the number of bits required to represent z +func (z *Uint) BitLen() int { + switch { + case z.arr[3] != 0: + return 192 + bits.Len64(z.arr[3]) + case z.arr[2] != 0: + return 128 + bits.Len64(z.arr[2]) + case z.arr[1] != 0: + return 64 + bits.Len64(z.arr[1]) default: - cutoff = MaxUint64/uint64(base) + 1 + return bits.Len64(z.arr[0]) } +} - maxVal := uint64(1)<= byte(base) { - return 0, errors.New("syntax error") - } - - if n >= cutoff { - // n*base overflows - return maxVal, errors.New("range error") - } - n *= uint64(base) +// ByteLen returns the number of bytes required to represent z +func (z *Uint) ByteLen() int { + return (z.BitLen() + 7) / 8 +} - n1 := n + uint64(d) - if n1 < n || n1 > maxVal { - // n+d overflows - return maxVal, errors.New("range error") - } - n = n1 - } +// Clear sets z to 0 +func (z *Uint) Clear() *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, 0 + return z +} - if underscores && !underscoreOK(s0) { - return 0, errors.New("syntax error") - } +const ( + // hextable = "0123456789abcdef" + bintable = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x01\x02\x03\x04\x05\x06\a\b\t\xff\xff\xff\xff\xff\xff\xff\n\v\f\r\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\n\v\f\r\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + badNibble = 0xff +) - return n, nil +// SetFromHex sets z from the given string, interpreted as a hexadecimal number. +// OBS! This method is _not_ strictly identical to the (*big.Int).SetString(..., 16) method. +// Notable differences: +// - This method _require_ "0x" or "0X" prefix. +// - This method does not accept zero-prefixed hex, e.g. "0x0001" +// - This method does not accept underscore input, e.g. "100_000", +// - This method does not accept negative zero as valid, e.g "-0x0", +// - (this method does not accept any negative input as valid) +func (z *Uint) SetFromHex(hex string) error { + return z.fromHex(hex) } -// underscoreOK reports whether the underscores in s are allowed. -// Checking them in this one function lets all the parsers skip over them simply. -// Underscore must appear only between digits or between a base prefix and a digit. -func underscoreOK(s string) bool { - // saw tracks the last character (class) we saw: - // ^ for beginning of number, - // 0 for a digit or base prefix, - // _ for an underscore, - // ! for none of the above. - saw := '^' - i := 0 - - // Optional sign. - if len(s) >= 1 && (s[0] == '-' || s[0] == '+') { - s = s[1:] +// fromHex is the internal implementation of parsing a hex-string. +func (z *Uint) fromHex(hex string) error { + if err := checkNumberS(hex); err != nil { + return err } - - // Optional base prefix. - hex := false - if len(s) >= 2 && s[0] == '0' && (lower(s[1]) == 'b' || lower(s[1]) == 'o' || lower(s[1]) == 'x') { - i = 2 - saw = '0' // base prefix counts as a digit for "underscore as digit separator" - hex = lower(s[1]) == 'x' + if len(hex) > 66 { + return ErrBig256Range } - - // Number proper. - for ; i < len(s); i++ { - // Digits are always okay. - if '0' <= s[i] && s[i] <= '9' || hex && 'a' <= lower(s[i]) && lower(s[i]) <= 'f' { - saw = '0' - continue + z.Clear() + end := len(hex) + for i := 0; i < 4; i++ { + start := end - 16 + if start < 2 { + start = 2 } - // Underscore must follow digit. - if s[i] == '_' { - if saw != '0' { - return false + for ri := start; ri < end; ri++ { + nib := bintable[hex[ri]] + if nib == badNibble { + return ErrSyntax } - saw = '_' - continue - } - // Underscore must also be followed by digit. - if saw == '_' { - return false + z.arr[i] = z.arr[i] << 4 + z.arr[i] += uint64(nib) } - // Saw non-digit, non-underscore. - saw = '!' + end = start } - return saw != '_' + return nil } -// Dec returns the decimal representation of z. -func (z *Uint) Dec() string { // toString() - if z.IsZero() { - return "0" - } - if z.IsUint64() { - return FormatUint(z.Uint64(), 10) - } - - // The max uint64 value being 18446744073709551615, the largest - // power-of-ten below that is 10000000000000000000. - // When we do a DivMod using that number, the remainder that we - // get back is the lower part of the output. - // - // The ascii-output of remainder will never exceed 19 bytes (since it will be - // below 10000000000000000000). - // - // Algorithm example using 100 as divisor - // - // 12345 % 100 = 45 (rem) - // 12345 / 100 = 123 (quo) - // -> output '45', continue iterate on 123 - var ( - // out is 98 bytes long: 78 (max size of a string without leading zeroes, - // plus slack so we can copy 19 bytes every iteration). - // We init it with zeroes, because when strconv appends the ascii representations, - // it will omit leading zeroes. - out = []byte("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") - divisor = NewUint(10000000000000000000) // 20 digits - y = new(Uint).Set(z) // copy to avoid modifying z - pos = len(out) // position to write to - buf = make([]byte, 0, 19) // buffer to write uint64:s to - ) - for { - // Obtain Q and R for divisor - var quot Uint - rem := udivrem(quot.arr[:], y.arr[:], divisor) - y.Set(") // Set Q for next loop - // Convert the R to ascii representation - buf = AppendUint(buf[:0], rem.Uint64(), 10) - // Copy in the ascii digits - copy(out[pos-len(buf):], buf) - if y.IsZero() { - break - } - // Move 19 digits left - pos -= 19 +// FromHex is a convenience-constructor to create an Uint from +// a hexadecimal string. The string is required to be '0x'-prefixed +// Numbers larger than 256 bits are not accepted. +func FromHex(hex string) (*Uint, error) { + var z Uint + if err := z.fromHex(hex); err != nil { + return nil, err } - // skip leading zeroes by only using the 'used size' of buf - return string(out[pos-len(buf):]) + return &z, nil } -// Mod sets z to the modulus x%y for y != 0 and returns z. -// If y == 0, z is set to 0 (OBS: differs from the big.Uint) -func (z *Uint) Mod(x, y *Uint) *Uint { - if x.IsZero() || y.IsZero() { - return z.Clear() - } - switch x.Cmp(y) { - case -1: - // x < y - copy(z.arr[:], x.arr[:]) - return z - case 0: - // x == y - return z.Clear() // They are equal - } - - // At this point: - // x != 0 - // y != 0 - // x > y - - // Shortcut trivial case - if x.IsUint64() { - return z.SetUint64(x.Uint64() % y.Uint64()) +// MustFromHex is a convenience-constructor to create an Uint from +// a hexadecimal string. +// Returns a new Uint and panics if any error occurred. +func MustFromHex(hex string) *Uint { + var z Uint + if err := z.fromHex(hex); err != nil { + panic(err) } - - var quot Uint - *z = udivrem(quot.arr[:], x.arr[:], y) - return z + return &z } -// Clone creates a new Int identical to z +// Clone creates a new Uint identical to z func (z *Uint) Clone() *Uint { var x Uint x.arr[0] = z.arr[0] @@ -1227,15 +289,3 @@ func (z *Uint) Clone() *Uint { return &x } - -func (z *Uint) IsNil() bool { - return z == nil -} - -func (z *Uint) NilToZero() *Uint { - if z == nil { - z = NewUint(0) - } - - return z -} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..de0b384 --- /dev/null +++ b/utils.go @@ -0,0 +1,180 @@ +package u256 + +// lower(c) is a lower-case letter if and only if +// c is either that lower-case letter or the equivalent upper-case letter. +// Instead of writing c == 'x' || c == 'X' one can write lower(c) == 'x'. +// Note that lower of non-letters can produce other non-letters. +func lower(c byte) byte { + return c | ('x' - 'X') +} + +// underscoreOK reports whether the underscores in s are allowed. +// Checking them in this one function lets all the parsers skip over them simply. +// Underscore must appear only between digits or between a base prefix and a digit. +func underscoreOK(s string) bool { + // saw tracks the last character (class) we saw: + // ^ for beginning of number, + // 0 for a digit or base prefix, + // _ for an underscore, + // ! for none of the above. + saw := '^' + i := 0 + + // Optional sign. + if len(s) >= 1 && (s[0] == '-' || s[0] == '+') { + s = s[1:] + } + + // Optional base prefix. + hex := false + if len(s) >= 2 && s[0] == '0' && (lower(s[1]) == 'b' || lower(s[1]) == 'o' || lower(s[1]) == 'x') { + i = 2 + saw = '0' // base prefix counts as a digit for "underscore as digit separator" + hex = lower(s[1]) == 'x' + } + + // Number proper. + for ; i < len(s); i++ { + // Digits are always okay. + if '0' <= s[i] && s[i] <= '9' || hex && 'a' <= lower(s[i]) && lower(s[i]) <= 'f' { + saw = '0' + continue + } + // Underscore must follow digit. + if s[i] == '_' { + if saw != '0' { + return false + } + saw = '_' + continue + } + // Underscore must also be followed by digit. + if saw == '_' { + return false + } + // Saw non-digit, non-underscore. + saw = '!' + } + return saw != '_' +} + +func checkNumberS(input string) error { + const fn = "UnmarshalText" + l := len(input) + if l == 0 { + return errEmptyString(fn, input) + } + if l < 2 || input[0] != '0' || + (input[1] != 'x' && input[1] != 'X') { + return errMissingPrefix(fn, input) + } + if l == 2 { + return errEmptyNumber(fn, input) + } + if len(input) > 3 && input[2] == '0' { + return errLeadingZero(fn, input) + } + return nil +} + +// ParseUint is like ParseUint but for unsigned numbers. +// +// A sign prefix is not permitted. +func parseUint(s string, base int, bitSize int) (uint64, error) { + const fnParseUint = "ParseUint" + + if s == "" { + return 0, errSyntax(fnParseUint, s) + } + + base0 := base == 0 + + s0 := s + switch { + case 2 <= base && base <= 36: + // valid base; nothing to do + + case base == 0: + // Look for octal, hex prefix. + base = 10 + if s[0] == '0' { + switch { + case len(s) >= 3 && lower(s[1]) == 'b': + base = 2 + s = s[2:] + case len(s) >= 3 && lower(s[1]) == 'o': + base = 8 + s = s[2:] + case len(s) >= 3 && lower(s[1]) == 'x': + base = 16 + s = s[2:] + default: + base = 8 + s = s[1:] + } + } + + default: + return 0, errInvalidBase(fnParseUint, base) + } + + if bitSize == 0 { + bitSize = uintSize + } else if bitSize < 0 || bitSize > 64 { + return 0, errInvalidBitSize(fnParseUint, bitSize) + } + + // Cutoff is the smallest number such that cutoff*base > maxUint64. + // Use compile-time constants for common cases. + var cutoff uint64 + switch base { + case 10: + cutoff = MaxUint64/10 + 1 + case 16: + cutoff = MaxUint64/16 + 1 + default: + cutoff = MaxUint64/uint64(base) + 1 + } + + maxVal := uint64(1)<= byte(base) { + return 0, errSyntax(fnParseUint, s0) + } + + if n >= cutoff { + // n*base overflows + return maxVal, errRange(fnParseUint, s0) + } + n *= uint64(base) + + n1 := n + uint64(d) + if n1 < n || n1 > maxVal { + // n+d overflows + return maxVal, errRange(fnParseUint, s0) + } + n = n1 + } + + if underscores && !underscoreOK(s0) { + return 0, errSyntax(fnParseUint, s0) + } + + return n, nil +}