From bac35c2840e6b5a71afd9ccad54f3a89b62ceb8c Mon Sep 17 00:00:00 2001 From: PoorvaGarg Date: Tue, 19 Dec 2023 12:41:42 -0800 Subject: [PATCH] bit blast geometric --- src/dist/number/fix.jl | 17 +++++++++++++++-- test/dist/number/fix_test.jl | 6 ++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/dist/number/fix.jl b/src/dist/number/fix.jl index c1394916..c25df082 100644 --- a/src/dist/number/fix.jl +++ b/src/dist/number/fix.jl @@ -3,7 +3,7 @@ using SymPy @vars varint @vars v2 -export DistFix, bitblast, bitblast_linear, bitblast_exponential, bitblast_exact, unit_exponential, exponential, laplace, unit_gamma, shift_point_gamma, n_unit_exponentials +export DistFix, bitblast, bitblast_linear, bitblast_exponential, bitblast_exact, unit_exponential, exponential, laplace, unit_gamma, shift_point_gamma, n_unit_exponentials, geometric ################################## # types, structs, and constructors @@ -53,7 +53,7 @@ function Base.convert(::Type{DistFix{W2, F2}}, x::DistFix{W1, F1}) where {W1,W2, mantissa = convert(DistInt{W1+(F2-F1)}, x.mantissa) mantissa <<= (F2-F1) else #F2 < F1 - mantissa = drop_bits(DistInt{W1+(F2-F1)}, x.mantissa; last=true) + mantissa = drop_bits(DistInt{W1+(F2-F1)}, x.mantissa; last=false) end convert(DistFix{W2, F2}, DistFix{W1+(F2-F1), F2}(mantissa)) end @@ -646,3 +646,16 @@ end # Y # end +###################################################################################################### +# bit blasting geometric distribution: https://en.wikipedia.org/wiki/Geometric_distribution +###################################################################################################### + +# Creates a geometric distribution over the integers in the range [0, stop-1] +function geometric(::Type{DistFix{W, F}}, success::Float64, stop::Int) where {W, F} + #TODO: use convert + @assert ispow2(stop) + bits = Int(log2(stop)) + @assert W - F > bits + + convert(DistFix{W, F}, DistFix{W, 0}(unit_exponential(DistFix{bits+1, bits}, log(1 - success)*2^bits).mantissa)) +end diff --git a/test/dist/number/fix_test.jl b/test/dist/number/fix_test.jl index 9c0293b7..a2c80f1c 100644 --- a/test/dist/number/fix_test.jl +++ b/test/dist/number/fix_test.jl @@ -388,5 +388,11 @@ end #TODO test for beta = 0 #TODO test for positive beta + # Truncated Geometric + x = Truncated(Geometric(0.8), 0, 32) + @test_throws Exception a = geometric(DistFix{6, 2}, 0.8, 32) + a = geometric(DistFix{8, 2}, 0.8, 32) + @test pr(a)[0.0] ≈ pdf(x, 0.0) + @test pr(a)[19] ≈ pdf(x, 19.0) end