Skip to content

Commit

Permalink
feat: add Mersenne Twister pseudorandom generator
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdorais committed Oct 12, 2024
1 parent daf1ed9 commit b5c9d8e
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 0 deletions.
1 change: 1 addition & 0 deletions Batteries.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import Batteries.Data.MLList
import Batteries.Data.Nat
import Batteries.Data.PairingHeap
import Batteries.Data.RBMap
import Batteries.Data.Random
import Batteries.Data.Range
import Batteries.Data.Rat
import Batteries.Data.String
Expand Down
1 change: 1 addition & 0 deletions Batteries/Data/Random.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import Batteries.Data.Random.MersenneTwister
143 changes: 143 additions & 0 deletions Batteries/Data/Random/MersenneTwister.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/-
Copyright (c) 2024 François G. Dorais. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: François G. Dorais
-/
import Batteries.Data.Vector

/-! # Mersenne Twister
Reference implementation for the Mersenne Twister pseudorandom number generator.
### References:
- Matsumoto, Makoto and Nishimura, Takuji (1998),
[**Mersenne twister: A 623-dimensionally equidistributed uniform pseudo-random number generator**](https://doi.org/10.1145/272991.272995),
ACM Trans. Model. Comput. Simul. 8, No. 1, 3-30.
[ZBL0917.65005](https://zbmath.org/?q=an:0917.65005).
- Nishimura, Takuji (2000),
[**Tables of 64-bit Mersenne twisters**](https://doi.org/10.1145/369534.369540),
ACM Trans. Model. Comput. Simul. 10, No. 4, 348-357.
[ZBL1390.65014](https://zbmath.org/?q=an:1390.65014).
-/

namespace Batteries.Random.MersenneTwister

/--
Mersenne Twister configuration.
Letters in parentheses correspond to variable names used by Matsumoto and Nishimura (1998) and
Nishimura (2000).
-/
structure Config where
/-- Word size (`w`). -/
wordSize : Nat
/-- Degree of recurrence (`n`). -/
stateSize : Nat
/-- Middle word (`m`). -/
shiftSize : Fin stateSize
/-- Twist value (`r`). -/
maskBits : Fin wordSize
/-- Coefficients of the twist matrix (`a`). -/
xorMask : BitVec wordSize
/-- Tempering shift parameters (`u`, `s`, `t`, `l`). -/
temperingShifts : Nat × Nat × Nat × Nat
/-- Tempering mask parameters (`d`, `b`, `c`). -/
temperingMasks : BitVec wordSize × BitVec wordSize × BitVec wordSize
/-- Initialization multiplier (`f`). -/
initMult : BitVec wordSize
/-- Default initialization seed value. -/
initSeed : BitVec wordSize

private abbrev Config.uMask (cfg : Config) : BitVec cfg.wordSize :=
BitVec.allOnes cfg.wordSize <<< cfg.maskBits.val

private abbrev Config.lMask (cfg : Config) : BitVec cfg.wordSize :=
BitVec.allOnes cfg.wordSize >>> (cfg.wordSize - cfg.maskBits.val)

@[simp] theorem Config.zero_lt_wordSize (cfg : Config) : 0 < cfg.wordSize :=
Nat.zero_lt_of_lt cfg.maskBits.is_lt

@[simp] theorem Config.zero_lt_stateSize (cfg : Config) : 0 < cfg.stateSize :=
Nat.zero_lt_of_lt cfg.shiftSize.is_lt

/-- Mersenne Twister State. -/
structure State (cfg : Config) where
/-- Data for current state. -/
data : Vector (BitVec cfg.wordSize) cfg.stateSize
/-- Current data index. -/
index : Fin cfg.stateSize

/-- Mersenne Twister initialization given an optional seed. -/
@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config)
(seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg :=
⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩
where
/-- Inner loop for Mersenne Twister initalization. -/
loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) :=
if heq : v.size = cfg.stateSize then ⟨v, heq⟩ else
let v := v.push w
let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size
loop w v (by simp only [v, Array.size_push]; omega)

/-- Update the state by a number of generation steps (default 1). -/
@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg :=
loop state steps
where
/-- Inner loop for Mersenne Twister update. -/
@[inline] loop (s : State cfg) (c : Nat) : State cfg :=
if c = 0 then s else
let i := s.index
let i' : Fin cfg.stateSize :=
if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else0, cfg.zero_lt_stateSize⟩
let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask
let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1
loop ⟨s.data.set i x, i'⟩ (c-1)

/-- Mersenne Twister iteration. -/
@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg :=
let i := state.index
let s := state.update
(temper s.data[i], s)
where
/-- Tempering step for Mersenne Twister. -/
@[inline] temper (x : BitVec cfg.wordSize) :=
match cfg.temperingShifts, cfg.temperingMasks with
| (u, s, t, l), (d, b, c) =>
let x := x ^^^ x >>> u &&& d
let x := x ^^^ x <<< s &&& b
let x := x ^^^ x <<< t &&& c
x ^^^ x >>> l

instance (cfg) : RandomGen (State cfg) where
range _ := (0, 2 ^ cfg.wordSize - 1)
next s := match s.next with | (r, s) => (r.toNat, s)
split s := let (a, s) := s.next; (s, cfg.init a)

instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where
next? s := s.next

/-- 32 bit Mersenne Twister (MT19937) configuration. -/
def mt19937 : Config where
wordSize := 32
stateSize := 624
shiftSize := 397
maskBits := 31
xorMask := 0x9908b0df
temperingShifts := (11, 7, 15, 18)
temperingMasks := (0xffffffff, 0x9d2c5680, 0xefc60000)
initMult := 1812433253
initSeed := 4357

/-- 64 bit Mersenne Twister (MT19937-64) configuration. -/
def mt19937_64 : Config where
wordSize := 64
stateSize := 312
shiftSize := 156
maskBits := 31
xorMask := 0xb5026f5aa96619e9
temperingShifts := (29, 17, 37, 43)
temperingMasks := (0x5555555555555555, 0x71d67fffeda60000, 0xfff7eee000000000)
initMult := 6364136223846793005
initSeed := 19650218

0 comments on commit b5c9d8e

Please sign in to comment.