Skip to content

Commit

Permalink
Merge pull request #67 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.1.12 release
  • Loading branch information
ablaom authored Oct 19, 2023
2 parents afcd606 + 1ef8a99 commit d642ed0
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 31 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
version:
- '1.6'
- '1' # automatically expands to the latest stable 1.x release of Julia.
- '~1.9.0-0'
os:
- ubuntu-latest
arch:
Expand Down
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CategoricalDistributions"
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.11"
version = "0.1.12"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -19,6 +19,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
UnivariateFiniteDisplayExt = "UnicodePlots"

[compat]
BenchmarkTools = "1.3.2"
CategoricalArrays = "0.9, 0.10"
Distributions = "0.25"
Missings = "0.4, 1"
Expand All @@ -28,11 +29,12 @@ UnicodePlots = "2, 3"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[targets]
test = ["FillArrays", "Random", "StableRNGs", "Test", "UnicodePlots"]
test = ["BenchmarkTools", "FillArrays", "Random", "StableRNGs", "Test", "UnicodePlots"]
37 changes: 21 additions & 16 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ end
"""
_cumulative(d::UnivariateFinite)
**Private method.**
Return the cumulative probability vector `C` for the distribution `d`,
using only classes in the support of `d`, ordered according to the
categorical elements used at instantiation of `d`. Used only to
Expand All @@ -238,6 +240,8 @@ end
"""
_rand(rng, p_cumulative, R)
**Private method.**
Randomly sample the distribution with discrete support `R(1):R(n)`
which has cumulative probability vector `p_cumulative` (see
[`_cummulative`](@ref)).
Expand All @@ -256,26 +260,27 @@ function _rand(rng, p_cumulative, R)
return index
end

function Base.rand(rng::AbstractRNG,
d::UnivariateFinite{<:Any,<:Any,R}) where R
p_cumulative = _cumulative(d)
return Dist.support(d)[_rand(rng, p_cumulative, R)]
Random.eltype(::Type{<:UnivariateFinite{<:Any,V}}) where V = V

# The Sampler hook into Random's API is discussed in the Julia documentation, in the
# Standard Library section on Random.
function Random.Sampler(
::AbstractRNG,
d::UnivariateFinite,
::Random.Repetition,
)
data = (_cumulative(d), Dist.support(d))
Random.SamplerSimple(d, data)
end

function Base.rand(rng::AbstractRNG,
d::UnivariateFinite{<:Any,<:Any,R},
dim1::Int, moredims::Int...) where R # ref type
p_cumulative = _cumulative(d)
A = Array{R}(undef, dim1, moredims...)
for i in eachindex(A)
@inbounds A[i] = _rand(rng, p_cumulative, R)
end
support = Dist.support(d)
return broadcast(i -> support[i], A)
function Base.rand(
rng::AbstractRNG,
sampler::Random.SamplerSimple{<:UnivariateFinite{<:Any,<:Any,R}},
) where R
p_cumulative, support = sampler.data
return support[_rand(rng, p_cumulative, R)]
end

rng(d::UnivariateFinite, args...) = rng(Random.GLOBAL_RNG, d, args...)

function Dist.fit(d::Type{<:UnivariateFinite},
v::AbstractVector{C}) where C
C <: CategoricalValue ||
Expand Down
25 changes: 15 additions & 10 deletions test/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
module TestArithmetic

using Test
import BenchmarkTools: @belapsed
using CategoricalDistributions
using StableRNGs
rng = StableRNG(123)

macro belapsed1(ex)
:(@belapsed $ex seconds=1 evals=1)
end

L = ["yes", "no"]
d1 = UnivariateFinite(L, rand(rng, 2), pool=missing)
d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
Expand Down Expand Up @@ -46,27 +51,27 @@ fast = UnivariateFinite(L, P, pool=missing);

@testset "performant arithmetic for UnivariateFiniteArray" begin
@test pdf(slow + slow, L) == pdf(fast + fast, L)
t_slow = @elapsed @eval slow + slow
t_fast = @elapsed @eval fast + fast
t_slow = @belapsed1 $slow + $slow
t_fast = @belapsed1 $fast + $fast
@test t_slow/t_fast > 10

@test pdf(slow - slow, L) == pdf(fast - fast, L)
t_slow = @elapsed @eval slow - slow
t_fast = @elapsed @eval fast - fast
t_slow = @belapsed1 $slow - $slow
t_fast = @belapsed1 $fast - $fast
@test t_slow/t_fast > 10

@test pdf(42*slow, L) == pdf(42*fast, L)
@test pdf(slow*42, L) == pdf(fast*42, L)
t_slow = @elapsed @eval 42*slow
t_fast = @elapsed @eval 42*fast
t_slow = @belapsed1 42*$slow
t_fast = @belapsed1 42*$fast
@test t_slow/t_fast > 10
t_slow = @elapsed @eval slow*42
t_fast = @elapsed @eval fast*42
t_slow = @belapsed1 $slow*42
t_fast = @belapsed1 $fast*42
@test t_slow/t_fast > 10

@test pdf(slow/42, L) == pdf(fast/42, L)
t_slow = @elapsed @eval slow/42
t_fast = @elapsed @eval fast/42
t_slow = @belapsed1 $slow/42
t_fast = @belapsed1 $fast/42
@test t_slow/t_fast > 10
end

Expand Down
25 changes: 24 additions & 1 deletion test/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using StableRNGs
import Random
rng = StableRNG(123)
using ScientificTypes
import Random.default_rng

import CategoricalDistributions: classes, ERR_NAN_FOUND

Expand Down Expand Up @@ -127,7 +128,7 @@ end
@testset "broadcasting pdf over single UnivariateFinite object" begin
d = UnivariateFinite(["a", "b"], [0.1, 0.9], pool=missing);
@test pdf.(d, ["a", "b"]) == [0.1, 0.9]
end
end

@testset "constructor arguments not categorical values" begin
@test_throws ArgumentError UnivariateFinite(Dict('f'=>0.7, 'q'=>0.2))
Expand Down Expand Up @@ -299,6 +300,28 @@ end
@test displays_okay([5 + 3im, 4 - 7im])
end

@testset "rand signatures" begin
d = UnivariateFinite(
["maybe", "no", "yes"],
[0.5, 0.4, 0.1];
pool=missing,
)

# smoke test:
sampler = Random.Sampler(default_rng(), d, Val(1))
rand(default_rng(), sampler)

Random.seed!(123)
samples = [rand(default_rng(), d) for i in 1:30]
Random.seed!(123)
@test [rand(d) for i in 1:30] == samples

Random.seed!(123)
samples = rand(Random.default_rng(), d, 3, 5)
Random.seed!(123)
@test samples == rand(d, 3, 5)
end

end # module

true

0 comments on commit d642ed0

Please sign in to comment.