-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Mersenne Twister pseudorandom generator
- Loading branch information
Showing
3 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import Batteries.Data.Random.MersenneTwister |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
/- | ||
Copyright (c) 2024 François G. Dorais. All rights reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: François G. Dorais | ||
-/ | ||
import Batteries.Data.Vector | ||
|
||
/-! # Mersenne Twister | ||
Reference implementation for the Mersenne Twister pseudorandom number generator. | ||
### References: | ||
- Matsumoto, Makoto and Nishimura, Takuji (1998), | ||
[**Mersenne twister: A 623-dimensionally equidistributed uniform pseudo-random number generator**](https://doi.org/10.1145/272991.272995), | ||
ACM Trans. Model. Comput. Simul. 8, No. 1, 3-30. | ||
[ZBL0917.65005](https://zbmath.org/?q=an:0917.65005). | ||
- Nishimura, Takuji (2000), | ||
[**Tables of 64-bit Mersenne twisters**](https://doi.org/10.1145/369534.369540), | ||
ACM Trans. Model. Comput. Simul. 10, No. 4, 348-357. | ||
[ZBL1390.65014](https://zbmath.org/?q=an:1390.65014). | ||
-/ | ||
|
||
namespace Batteries.Random.MersenneTwister | ||
|
||
/-- | ||
Mersenne Twister configuration. | ||
Letters in parentheses correspond to variable names used by Matsumoto and Nishimura (1998) and | ||
Nishimura (2000). | ||
-/ | ||
structure Config where | ||
/-- Word size (`w`). -/ | ||
wordSize : Nat | ||
/-- Degree of recurrence (`n`). -/ | ||
stateSize : Nat | ||
/-- Middle word (`m`). -/ | ||
shiftSize : Fin stateSize | ||
/-- Twist value (`r`). -/ | ||
maskBits : Fin wordSize | ||
/-- Coefficients of the twist matrix (`a`). -/ | ||
xorMask : BitVec wordSize | ||
/-- Tempering shift parameters (`u`, `s`, `t`, `l`). -/ | ||
temperingShifts : Nat × Nat × Nat × Nat | ||
/-- Tempering mask parameters (`d`, `b`, `c`). -/ | ||
temperingMasks : BitVec wordSize × BitVec wordSize × BitVec wordSize | ||
/-- Initialization multiplier (`f`). -/ | ||
initMult : BitVec wordSize | ||
/-- Default initialization seed value. -/ | ||
initSeed : BitVec wordSize | ||
|
||
private abbrev Config.uMask (cfg : Config) : BitVec cfg.wordSize := | ||
BitVec.allOnes cfg.wordSize <<< cfg.maskBits.val | ||
|
||
private abbrev Config.lMask (cfg : Config) : BitVec cfg.wordSize := | ||
BitVec.allOnes cfg.wordSize >>> (cfg.wordSize - cfg.maskBits.val) | ||
|
||
@[simp] theorem Config.zero_lt_wordSize (cfg : Config) : 0 < cfg.wordSize := | ||
Nat.zero_lt_of_lt cfg.maskBits.is_lt | ||
|
||
@[simp] theorem Config.zero_lt_stateSize (cfg : Config) : 0 < cfg.stateSize := | ||
Nat.zero_lt_of_lt cfg.shiftSize.is_lt | ||
|
||
/-- Mersenne Twister State. -/ | ||
structure State (cfg : Config) where | ||
/-- Data for current state. -/ | ||
data : Vector (BitVec cfg.wordSize) cfg.stateSize | ||
/-- Current data index. -/ | ||
index : Fin cfg.stateSize | ||
|
||
/-- Mersenne Twister initialization given an optional seed. -/ | ||
@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config) | ||
(seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg := | ||
⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ | ||
where | ||
loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) := | ||
if heq : v.size = cfg.stateSize then ⟨v, heq⟩ else | ||
let v := v.push w | ||
let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size | ||
loop w v (by simp only [v, Array.size_push]; omega) | ||
|
||
/-- Update the state by a number of generation steps (default 1). -/ | ||
@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg := | ||
loop state steps | ||
where | ||
@[inline] loop (s : State cfg) (c : Nat) : State cfg := | ||
if c = 0 then s else | ||
let i := s.index | ||
let i' : Fin cfg.stateSize := | ||
if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ | ||
let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask | ||
let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 | ||
loop ⟨s.data.set i x, i'⟩ (c-1) | ||
|
||
/-- Mersenne Twister iteration. -/ | ||
@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg := | ||
let i := state.index | ||
let s := state.update | ||
(temper s.data[i], s) | ||
where | ||
@[inline] temper (x : BitVec cfg.wordSize) := | ||
match cfg.temperingShifts, cfg.temperingMasks with | ||
| (u, s, t, l), (d, b, c) => | ||
let x := x ^^^ x >>> u &&& d | ||
let x := x ^^^ x <<< s &&& b | ||
let x := x ^^^ x <<< t &&& c | ||
x ^^^ x >>> l | ||
|
||
instance (cfg) : RandomGen (State cfg) where | ||
range _ := (0, 2 ^ cfg.wordSize - 1) | ||
next s := match s.next with | (r, s) => (r.toNat, s) | ||
split s := let (a, s) := s.next; (s, cfg.init a) | ||
|
||
instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where | ||
next? s := s.next | ||
|
||
/-- 32 bit Mersenne Twister (MT19937) configuration. -/ | ||
def mt19937 : Config where | ||
wordSize := 32 | ||
stateSize := 624 | ||
shiftSize := 397 | ||
maskBits := 31 | ||
xorMask := 0x9908b0df | ||
temperingShifts := (11, 7, 15, 18) | ||
temperingMasks := (0xffffffff, 0x9d2c5680, 0xefc60000) | ||
initMult := 1812433253 | ||
initSeed := 4357 | ||
|
||
/-- 64 bit Mersenne Twister (MT19937-64) configuration. -/ | ||
def mt19937_64 : Config where | ||
wordSize := 64 | ||
stateSize := 312 | ||
shiftSize := 156 | ||
maskBits := 31 | ||
xorMask := 0xb5026f5aa96619e9 | ||
temperingShifts := (29, 17, 37, 43) | ||
temperingMasks := (0x5555555555555555, 0x71d67fffeda60000, 0xfff7eee000000000) | ||
initMult := 6364136223846793005 | ||
initSeed := 19650218 |