Dictionaries-based implementation appears 45x faster for rand #64
Thanks @JockLawrie for the insights here. Started playing around with this but tripped on #65. I'll return to this after addressing that.
Okay, I've looked into this a bit more.

Choice of container

Currently the implementation already uses dictionaries, but: … If we swap out …

Precomputation of inverse pdf

In the current implementation, a cumulative pdf is pre-computed (following the implementation of Distributions.Categorical, though that may have changed since we adapted it). This avoids repeating arithmetic when drawing multiple samples at once. So, if we repeat your experiment for 10_000 samples but use …

Next steps for the single sample case

It doesn't seem likely to me that the slowdown you have observed is mostly due to the precomputation of the inverse CDF, but I haven't tested this. My guess is that the slowdown is somewhere else. Perhaps if we reimplemented …
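To make the precomputation point concrete, here is a minimal sketch of the idea, with hypothetical names (sample_batch is illustrative only, not the actual CategoricalDistributions.jl internals): the cumulative probabilities are computed once, and every draw in the batch reuses them via binary search.

using Random

# Hypothetical sketch: pre-compute the cumulative pdf once, then reuse it
# for every draw in a batch.
function sample_batch(rng::Random.AbstractRNG, levels, probs, n)
    cdf = cumsum(probs)  # the repeated arithmetic happens once, not per draw
    # each draw is then a binary search into the pre-computed cdf;
    # `min` guards against rounding leaving cdf[end] fractionally below 1
    [levels[min(searchsortedfirst(cdf, rand(rng)), length(cdf))] for _ in 1:n]
end

sample_batch(Random.default_rng(), ["maybe", "no", "yes"], [0.5, 0.4, 0.1], 10_000)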
Needs CategoricalDistributions 0.1.12:
using Dictionaries
using OrderedCollections
using Distributions
using Random
using CategoricalDistributions
struct CatDist{K} <: Distribution{Univariate, Discrete}
    dict::Dictionary{K, Float64}  # TODO: Constructor that checks probs, unique levels, etc
end

CatDist(levels, probs) = CatDist(Dictionary(levels, probs))

Base.eltype(d::CatDist{K}) where {K} = K

# Inverse-CDF sampling by linear scan: accumulate probabilities until the
# running total passes a single uniform draw.
function Base.rand(rng::Random.AbstractRNG, d::CatDist{K}) where {K}
    u = rand(rng)
    total = 0.0
    for (level, prob) in pairs(d.dict)
        total += prob
        u <= total && return level
    end
    # Guard: floating-point rounding can leave `total` fractionally below 1,
    # letting the loop fall through; return a valid level rather than `nothing`.
    return first(keys(d.dict))
end

struct LittleCatDist{K} <: Distribution{Univariate, Discrete}
    dict::LittleDict{K, Float64, Vector{K}, Vector{Float64}}
end

LittleCatDist(levels, probs) = LittleCatDist(LittleDict(levels, probs))

Base.eltype(d::LittleCatDist{K}) where {K} = K

# Same linear-scan sampler, but backed by a vector-based LittleDict.
function Base.rand(rng::Random.AbstractRNG, d::LittleCatDist{K}) where {K}
    u = rand(rng)
    total = 0.0
    for (level, prob) in pairs(d.dict)
        total += prob
        u <= total && return level
    end
    return first(keys(d.dict))  # same floating-point guard as for CatDist
end
## Benchmarks for single sample
using BenchmarkTools
categorical = Categorical([0.5, 0.4, 0.1])
catdist = CatDist(["maybe", "no", "yes"], [0.5, 0.4, 0.1])
littlecatdist = LittleCatDist(["maybe", "no", "yes"], [0.5, 0.4, 0.1])
d = UnivariateFinite(["maybe", "no", "yes"], [0.5, 0.4, 0.1]; pool=missing)
@btime rand($categorical) # 11.6 ns
@btime rand($catdist) # 13.7 ns
@btime rand($littlecatdist) # 12.1 ns
@btime rand($d) # 855 ns
## Benchmarks for multiple samples
N = 10_000
@btime rand($categorical, N); # 141 μs
@btime [rand($catdist) for i in 1:N]; # 136 μs
@btime [rand($littlecatdist) for i in 1:N]; # 128 μs
@btime rand($d, N); # 137 μs
## Benchmarks for larger numbers of classes
N = 100_000
M = 100
p = rand(M)
p = p ./ sum(p)
labels = 1:M
categorical = Categorical(p)
catdist = CatDist(labels, p)
littlecatdist = LittleCatDist(labels, p)
d = UnivariateFinite(labels, p; pool=missing)
@btime rand($categorical, N); # 1.41 ms
@btime [rand($catdist) for i in 1:N]; # 6.50 ms
@btime [rand($littlecatdist) for i in 1:N]; # 6.40 ms
@btime rand($d, N); # 4.62 ms
@JockLawrie For your use case, do you actually need a fast call of rand for a single sample?
Thanks for looking into this Anthony, much appreciated. I do need a fast rand …

Looks like the …

Perhaps a pre-computed CDF for … ?
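Along those lines, here is a hedged sketch of what a pre-computed CDF might look like for a bare-bones categorical type. CDFCatDist and its fields are hypothetical names, not an actual API of CategoricalDistributions.jl; single-sample rand becomes a binary search instead of a linear scan, which should matter most for larger numbers of classes.

using Distributions
using Random

# Hypothetical sketch: store the cumulative probabilities at construction,
# so a single-sample draw is O(log M) rather than O(M).
struct CDFCatDist{K} <: Distribution{Univariate, Discrete}
    levels::Vector{K}
    cdf::Vector{Float64}  # pre-computed cumulative probabilities
end

CDFCatDist(levels, probs) = CDFCatDist(collect(levels), cumsum(probs))

Base.eltype(::CDFCatDist{K}) where {K} = K

function Base.rand(rng::Random.AbstractRNG, d::CDFCatDist)
    u = rand(rng)
    # binary search; `min` guards against rounding leaving cdf[end] < u
    i = min(searchsortedfirst(d.cdf, u), length(d.cdf))
    return d.levels[i]
end

rand(CDFCatDist(["maybe", "no", "yes"], [0.5, 0.4, 0.1]))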
Closed by #68
Hi there,

I need a categorical distribution with a fast rand method to be used in simulations. I found the bare-bones implementation below to be about 45x faster for rand on my machine. It is based on Dictionaries.jl. Hopefully I've done the benchmarking correctly.

If so, is a Dictionaries.jl-based implementation suitable for CategoricalDistributions.jl? Can it cover the use cases that CategoricalDistributions.jl satisfies?