Skip to content

Commit

Permalink
fix implementation of rand to close #65
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Oct 15, 2023
1 parent 66f23a3 commit 004a468
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ end
"""
_cumulative(d::UnivariateFinite)
**Private method.**
Return the cumulative probability vector `C` for the distribution `d`,
using only classes in the support of `d`, ordered according to the
categorical elements used at instantiation of `d`. Used only to
Expand All @@ -238,13 +240,15 @@ end
"""
_rand(rng, p_cumulative, R)
**Private method.**
Randomly sample the distribution with discrete support `R(1):R(n)`
which has cumulative probability vector `p_cumulative` (see
[`_cummulative`](@ref)).
"""
function _rand(rng, p_cumulative, R)
real_sample = rand(rng)*p_cumulative[end]
real_sample = Base.rand(rng)*p_cumulative[end]
K = R(length(p_cumulative))
index = K
for i in R(2):R(K)
Expand All @@ -261,10 +265,11 @@ function Base.rand(rng::AbstractRNG,
p_cumulative = _cumulative(d)
return Dist.support(d)[_rand(rng, p_cumulative, R)]
end
Base.rand(d::UnivariateFinite) = rand(Random.default_rng(), d)

Check warning on line 268 in src/methods.jl

View check run for this annotation

Codecov / codecov/patch

src/methods.jl#L268

Added line #L268 was not covered by tests

function Base.rand(rng::AbstractRNG,
d::UnivariateFinite{<:Any,<:Any,R},
dim1::Int, moredims::Int...) where R # ref type
dim1::Integer, moredims::Integer...) where R # ref type
p_cumulative = _cumulative(d)
A = Array{R}(undef, dim1, moredims...)
for i in eachindex(A)
Expand All @@ -274,7 +279,8 @@ function Base.rand(rng::AbstractRNG,
return broadcast(i -> support[i], A)
end

rng(d::UnivariateFinite, args...) = rng(Random.GLOBAL_RNG, d, args...)
Base.rand(d::UnivariateFinite, dim1::Integer, moredims::Integer...) =

Check warning on line 282 in src/methods.jl

View check run for this annotation

Codecov / codecov/patch

src/methods.jl#L282

Added line #L282 was not covered by tests
rand(Random.default_rng(), d, dim1, moredims...)

function Dist.fit(d::Type{<:UnivariateFinite},
v::AbstractVector{C}) where C
Expand Down
23 changes: 22 additions & 1 deletion test/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using StableRNGs
import Random
rng = StableRNG(123)
using ScientificTypes
import Random.default_rng

import CategoricalDistributions: classes, ERR_NAN_FOUND

Expand Down Expand Up @@ -127,7 +128,7 @@ end
@testset "broadcasting pdf over single UnivariateFinite object" begin
d = UnivariateFinite(["a", "b"], [0.1, 0.9], pool=missing);
@test pdf.(d, ["a", "b"]) == [0.1, 0.9]
end
end

@testset "constructor arguments not categorical values" begin
@test_throws ArgumentError UnivariateFinite(Dict('f'=>0.7, 'q'=>0.2))
Expand Down Expand Up @@ -299,6 +300,26 @@ end
@test displays_okay([5 + 3im, 4 - 7im])
end

if VERSION >= v"1.7"
@testset "rand signatures" begin
d = UnivariateFinite(
["maybe", "no", "yes"],
[0.5, 0.4, 0.1];
pool=missing,
)

Random.seed!(123)
samples = [rand(default_rng(), d) for i in 1:30]
Random.seed!(123)
@test [rand(d) for i in 1:30] == samples

Random.seed!(123)
samples = rand(Random.default_rng(), d, 3, 5)
Random.seed!(123)
@test samples == rand(d, 3, 5)
end
end

end # module

true

0 comments on commit 004a468

Please sign in to comment.