diff --git a/Batteries.lean b/Batteries.lean index f04a7a1f76..613fb429e4 100644 --- a/Batteries.lean +++ b/Batteries.lean @@ -29,6 +29,7 @@ import Batteries.Data.MLList import Batteries.Data.Nat import Batteries.Data.PairingHeap import Batteries.Data.RBMap +import Batteries.Data.Random import Batteries.Data.Range import Batteries.Data.Rat import Batteries.Data.String diff --git a/Batteries/Data/Random.lean b/Batteries/Data/Random.lean new file mode 100644 index 0000000000..cf1e720ee0 --- /dev/null +++ b/Batteries/Data/Random.lean @@ -0,0 +1 @@ +import Batteries.Data.Random.MersenneTwister diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean new file mode 100644 index 0000000000..1318e68c23 --- /dev/null +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -0,0 +1,140 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ +import Batteries.Data.Vector + +/-! # Mersenne Twister + +Reference implementation for the Mersenne Twister pseudorandom number generator. + +### References: + +- Matsumoto, Makoto and Nishimura, Takuji (1998), + [**Mersenne twister: A 623-dimensionally equidistributed uniform pseudo-random number generator**](https://doi.org/10.1145/272991.272995), + ACM Trans. Model. Comput. Simul. 8, No. 1, 3-30. + [ZBL0917.65005](https://zbmath.org/?q=an:0917.65005). + +- Nishimura, Takuji (2000), + [**Tables of 64-bit Mersenne twisters**](https://doi.org/10.1145/369534.369540), + ACM Trans. Model. Comput. Simul. 10, No. 4, 348-357. + [ZBL1390.65014](https://zbmath.org/?q=an:1390.65014). +-/ + +namespace Batteries.Random.MersenneTwister + +/-- +Mersenne Twister configuration. + +Letters in parentheses correspond to variable names used by Matsumoto and Nishimura (1998) and +Nishimura (2000). +-/ +structure Config where + /-- Word size (`w`). -/ + wordSize : Nat + /-- Degree of recurrence (`n`). -/ + stateSize : Nat + /-- Middle word (`m`). -/ + shiftSize : Fin stateSize + /-- Twist value (`r`). -/ + maskBits : Fin wordSize + /-- Coefficients of the twist matrix (`a`). -/ + xorMask : BitVec wordSize + /-- Tempering shift parameters (`u`, `s`, `t`, `l`). -/ + temperingShifts : Nat × Nat × Nat × Nat + /-- Tempering mask parameters (`d`, `b`, `c`). -/ + temperingMasks : BitVec wordSize × BitVec wordSize × BitVec wordSize + /-- Initialization multiplier (`f`). -/ + initMult : BitVec wordSize + /-- Default initialization seed value. -/ + initSeed : BitVec wordSize + +private abbrev Config.uMask (cfg : Config) : BitVec cfg.wordSize := + BitVec.allOnes cfg.wordSize <<< cfg.maskBits.val + +private abbrev Config.lMask (cfg : Config) : BitVec cfg.wordSize := + BitVec.allOnes cfg.wordSize >>> (cfg.wordSize - cfg.maskBits.val) + +@[simp] theorem Config.zero_lt_wordSize (cfg : Config) : 0 < cfg.wordSize := + Nat.zero_lt_of_lt cfg.maskBits.is_lt + +@[simp] theorem Config.zero_lt_stateSize (cfg : Config) : 0 < cfg.stateSize := + Nat.zero_lt_of_lt cfg.shiftSize.is_lt + +/-- Mersenne Twister State. -/ +structure State (cfg : Config) where + /-- Data for current state. -/ + data : Vector (BitVec cfg.wordSize) cfg.stateSize + /-- Current data index. -/ + index : Fin cfg.stateSize + +/-- Mersenne Twister initialization given an optional seed. -/ +@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config) + (seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg := + ⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ +where + loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) := + if heq : v.size = cfg.stateSize then ⟨v, heq⟩ else + let v := v.push w + let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size + loop w v (by simp only [v, Array.size_push]; omega) + +/-- Update the state by a number of generation steps (default 1). -/ +@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg := + loop state steps +where + @[inline] loop (s : State cfg) (c : Nat) : State cfg := + if c = 0 then s else + let i := s.index + let i' : Fin cfg.stateSize := + if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ + let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask + let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 + loop ⟨s.data.set i x, i'⟩ (c-1) + +/-- Mersenne Twister iteration. -/ +@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg := + let i := state.index + let s := state.update + (temper s.data[i], s) +where + @[inline] temper (x : BitVec cfg.wordSize) := + match cfg.temperingShifts, cfg.temperingMasks with + | (u, s, t, l), (d, b, c) => + let x := x ^^^ x >>> u &&& d + let x := x ^^^ x <<< s &&& b + let x := x ^^^ x <<< t &&& c + x ^^^ x >>> l + +instance (cfg) : RandomGen (State cfg) where + range _ := (0, 2 ^ cfg.wordSize - 1) + next s := match s.next with | (r, s) => (r.toNat, s) + split s := let (a, s) := s.next; (s, cfg.init a) + +instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where + next? s := s.next + +/-- 32 bit Mersenne Twister (MT19937) configuration. -/ +def mt19937 : Config where + wordSize := 32 + stateSize := 624 + shiftSize := 397 + maskBits := 31 + xorMask := 0x9908b0df + temperingShifts := (11, 7, 15, 18) + temperingMasks := (0xffffffff, 0x9d2c5680, 0xefc60000) + initMult := 1812433253 + initSeed := 4357 + +/-- 64 bit Mersenne Twister (MT19937-64) configuration. -/ +def mt19937_64 : Config where + wordSize := 64 + stateSize := 312 + shiftSize := 156 + maskBits := 31 + xorMask := 0xb5026f5aa96619e9 + temperingShifts := (29, 17, 37, 43) + temperingMasks := (0x5555555555555555, 0x71d67fffeda60000, 0xfff7eee000000000) + initMult := 6364136223846793005 + initSeed := 19650218