Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Mersenne Twister PRNG #984

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
27 changes: 13 additions & 14 deletions Batteries/Data/Random/MersenneTwister.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ structure State (cfg : Config) where
/-- Mersenne Twister initialization given an optional seed. -/
@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config)
(seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg :=
⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩
⟨loop seed (.mkEmpty cfg.stateSize) (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩
where
/-- Inner loop for Mersenne Twister initalization. -/
loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) :=
Expand All @@ -81,24 +81,23 @@ where
let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size
loop w v (by simp only [v, Array.size_push]; omega)

/-- Apply the twisting transformation to the given state. -/
@[specialize cfg] protected def State.twist (state : State cfg) : State cfg :=
let i := state.index
let i' : Fin cfg.stateSize :=
if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩
let y := state.data[i] &&& cfg.uMask ||| state.data[i'] &&& cfg.lMask
let x := state.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1
⟨state.data.set i x, i'⟩

/-- Update the state by a number of generation steps (default 1). -/
@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg :=
loop state steps
where
/-- Inner loop for Mersenne Twister update. -/
@[inline] loop (s : State cfg) (c : Nat) : State cfg :=
if c = 0 then s else
let i := s.index
let i' : Fin cfg.stateSize :=
if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩
let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask
let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1
loop ⟨s.data.set i x, i'⟩ (c-1)
@[inline] protected def State.update (state : State cfg) (steps := 1) : State cfg :=
fgdorais marked this conversation as resolved.
Show resolved Hide resolved
if steps = 0 then state else state.twist.update (steps-1)

/-- Mersenne Twister iteration. -/
@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg :=
let i := state.index
let s := state.update
let s := state.twist
(temper s.data[i], s)
where
/-- Tempering step for Mersenne Twister. -/
Expand Down