Skip to content

Commit

Permalink
Merge pull request #61 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.1.11 release
  • Loading branch information
ablaom authored Sep 13, 2023
2 parents f691205 + a07635d commit afcd606
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CategoricalDistributions"
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.10"
version = "0.1.11"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Arrays of `UnivariateFinite` distributions are defined using the same
constructor. Broadcasting methods, such as `pdf`, are optimized for
such arrays:

```
```julia
julia> v = UnivariateFinite(["no", "yes"], [0.1, 0.2, 0.3, 0.4], augment=true, pool=data)
4-element UnivariateFiniteArray{Multiclass{3}, String, UInt32, Float64, 1}:
UnivariateFinite{Multiclass{3}}(no=>0.9, yes=>0.1)
Expand Down Expand Up @@ -119,7 +119,6 @@ julia> pdf(v, L)
0.0 0.6 0.4
```


## Measures over finite labeled sets

There is, in fact, no enforcement that probabilities in a
Expand All @@ -128,7 +127,6 @@ to a type `T` for which `zero(T)` is defined. In particular
`UnivariateFinite` objects implement arbitrary non-negative, signed,
or complex measures over a finite labeled set.


## What does this package provide?

- A new type `UnivariateFinite{S}` for representing probability
Expand All @@ -144,7 +142,7 @@ or complex measures over a finite labeled set.
- Implementations of `rand` for generating random samples of a
`UnivariateFinite` distribution.

- Implementations of the `pdf`, `logpdf` and `mode` methods of
- Implementations of the `pdf`, `logpdf`, `mode` and `modes` methods of
Distributions.jl, with efficient broadcasting over the new array
type.

Expand Down
4 changes: 2 additions & 2 deletions src/CategoricalDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using Random

const Dist = Distributions

import Distributions: pdf, logpdf, support, mode
import Distributions: pdf, logpdf, support, mode, modes

include("utilities.jl")
include("types.jl")
Expand All @@ -28,7 +28,7 @@ include("arithmetic.jl")
export UnivariateFinite, UnivariateFiniteArray, UnivariateFiniteVector

# re-export from Distributions:
export pdf, logpdf, support, mode
export pdf, logpdf, support, mode, modes

# re-export from ScientificTypesBase:
export Multiclass, OrderedFactor
Expand Down
22 changes: 21 additions & 1 deletion src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ Base.Broadcast.broadcasted(
c::Missing) where {S,V,R,P,N} = Missings.missings(P, length(u))


## PERFORMANT BROADCASTING OF mode:
## PERFORMANT BROADCASTING OF mode(s):

function Base.Broadcast.broadcasted(::typeof(mode),
u::UniFinArr{S,V,R,P,N}) where {S,V,R,P,N}
Expand All @@ -298,6 +298,26 @@ function Base.Broadcast.broadcasted(::typeof(mode),
return reshape(mode_flat, size(u))
end

# Specialised broadcasting of `modes` over a `UnivariateFiniteArray`: for each
# linear index `i`, gather every reference whose probability at `i` equals the
# largest probability at `i`, decode that collection with `u.decoder`, and
# finally give the per-element results the shape of `u`.
function Base.Broadcast.broadcasted(::typeof(modes),
                                    u::UniFinArr{S,V,R,P,N}) where {S,V,R,P,N}
    prob_given_ref = u.prob_given_ref
    refs = keys(prob_given_ref)

    # using linear indexing:
    decoded_flat = map(1:length(u)) do i
        peak = maximum(prob_given_ref[ref][i] for ref in refs)

        # the `NaN` check is applied to the maximum; see the comment in
        # `broadcasted(::typeof(mode), ...)` above
        throw_nan_error_if_needed(peak)

        tied = R[ref for ref in refs if prob_given_ref[ref][i] == peak]
        return u.decoder(tied)
    end
    return reshape(decoded_flat, size(u))
end

## EXTENSION OF CLASSES TO ARRAYS OF UNIVARIATE FINITE

Expand Down
32 changes: 24 additions & 8 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ end
# TODO: It would be useful to define == as well.

"""
Dist.pdf(d::UnivariateFinite, x)
Distributions.pdf(d::UnivariateFinite, x)
Probability of `d` at `x`.
Expand Down Expand Up @@ -178,15 +178,31 @@ function Dist.mode(d::UnivariateFinite)
return d.decoder(m)
end

"""
    Distributions.modes(d::UnivariateFinite)

Return every mode of `d`: all references whose probability equals the largest
probability of `d` are collected and decoded with `d.decoder`. Unlike `mode`,
ties are all reported.

Throws `ERR_NAN_FOUND` (a `DomainError`) if the largest probability is `NaN`.
"""
function Dist.modes(d::UnivariateFinite{S,V,R,P}) where {S,V,R,P}
    prob_given_ref = d.prob_given_ref
    peak = maximum(values(prob_given_ref))

    # see comment in `mode` above
    throw_nan_error_if_needed(peak)

    # references tied for the maximal probability, in dictionary order
    tied = R[ref for (ref, prob) in prob_given_ref if prob == peak]
    return d.decoder(tied)
end

"""
    ERR_NAN_FOUND

The `DomainError` raised by `throw_nan_error_if_needed`. It is kept as a
constant so that callers and tests can refer to the exact exception instance.
"""
const ERR_NAN_FOUND = DomainError(
    NaN,
    "`mode(s)` is invalid for a `UnivariateFinite` distribution "*
    "with `pdf` containing `NaN`s"
)

"""
    throw_nan_error_if_needed(x)

Throw `ERR_NAN_FOUND` if `isnan(x)`; otherwise return `nothing`.

The shared constant is the only error thrown here. Building a second
`DomainError` inline would shadow it and let the two messages drift apart.
"""
function throw_nan_error_if_needed(x)
    if isnan(x)
        throw(ERR_NAN_FOUND)
    end
    return nothing
end

Expand Down
51 changes: 46 additions & 5 deletions test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Random
using Missings
using ScientificTypes

import CategoricalDistributions: classes
import CategoricalDistributions: classes, ERR_NAN_FOUND
import CategoricalArrays.unwrap

rng = StableRNG(111)
Expand Down Expand Up @@ -198,10 +198,10 @@ end
probs = rand(rng, n)
u = UnivariateFinite(probs, augment = true, pool=missing)
supp = Distributions.support(u)
modes = mode.(u)
@test modes isa CategoricalArray
_modes = mode.(u)
@test _modes isa CategoricalArray
expected = [ifelse(p > 0.5, supp[2], supp[1]) for p in probs]
@test all(modes .== expected)
@test all(_modes .== expected)

# multiclass
rng = StableRNG(554)
Expand All @@ -220,7 +220,48 @@ end
],
pool=missing
)
@test_throws DomainError mode.(unf_arr)
@test_throws ERR_NAN_FOUND mode.(unf_arr)
end

@testset "broadcasting modes" begin
    # binary: each `p` is expected to yield the single mode `supp[2]` when
    # `p > 0.5`, and `supp[1]` otherwise
    rng = StableRNG(668)
    probs = rand(rng, n)
    u = UnivariateFinite(probs, augment = true, pool=missing)
    supp = Distributions.support(u)
    _modes = modes.(u)
    @test _modes isa Vector{<:CategoricalArray}
    expected = [ifelse(p > 0.5, [supp[2]], [supp[1]]) for p in probs]
    @test all(_modes .== expected)

    # multiclass, bimodal: in every row, copy the row maximum into a second,
    # different column, so that each row has at least two equal maxima
    rng = StableRNG(554)
    P = rand(rng, n, c)
    M, M_idx = findmax(P, dims=2)
    M_idx = getindex.(M_idx, 2)
    for i in axes(P,1)
        m = M[i]
        j = M_idx[i]
        while j == M_idx[i]
            # draw from the seeded `rng` (not the global RNG) so that the
            # fixture is reproducible from run to run
            j = rand(rng, axes(P,2))
        end
        P[i,j] = m
    end
    P ./= sum(P, dims=2)
    u = UnivariateFinite(P, pool=missing)
    # the broadcast method must agree with element-by-element `modes`
    expected = modes.([u...])
    @test all(modes.(u) .== expected)

    # `modes` broadcasting of `UnivariateFinite` objects containing `NaN` in probs.
    unf_arr = UnivariateFinite(
        [
            0.1 0.2 NaN 0.1 NaN;
            0.2 0.1 0.1 0.4 0.2;
            0.3 NaN 0.2 NaN 0.3
        ],
        pool=missing
    )
    @test_throws ERR_NAN_FOUND modes.(unf_arr)
end

@testset "cat for UnivariateFiniteArray" begin
Expand Down
25 changes: 21 additions & 4 deletions test/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Random
rng = StableRNG(123)
using ScientificTypes

import CategoricalDistributions: classes
import CategoricalDistributions: classes, ERR_NAN_FOUND

v = categorical(collect("asqfasqffqsaaaa"), ordered=true)
V = categorical(collect("asqfasqffqsaaaa"))
Expand All @@ -19,8 +19,10 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
@testset "set 1" begin

# ordered (OrderedFactor)
dict = Dict(s=>0.1, q=> 0.2, f=> 0.7)
dict = Dict(s=>0.1, q=>0.2, f=>0.7)
d = UnivariateFinite(dict)
dict_bimodal = Dict(a=>0.1, s=>0.1, q=>0.4, f=>0.4)
d_bimodal = UnivariateFinite(dict_bimodal)
@test classes(d) == [a, f, q, s]
@test classes(d) == classes(s)
@test levels(d) == levels(s)
Expand All @@ -45,6 +47,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
@test logpdf(d, 'f') log(0.7)
@test isinf(logpdf(d, a))
@test mode(d) == f
@test modes(d_bimodal) == [f, q]

@test UnivariateFinite(support(d), [0.7, 0.2, 0.1]) d

Expand Down Expand Up @@ -72,7 +75,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
@test isapprox(freq[q]/N, ffreq[q]/N)

# unordered (Multiclass):
dict = Dict(S=>0.1, Q=> 0.2, F=> 0.7)
dict = Dict(S=>0.1, Q=>0.2, F=>0.7)
d = UnivariateFinite(dict)
@test classes(d) == [a, f, q, s]
@test classes(d) == classes(s)
Expand Down Expand Up @@ -178,7 +181,21 @@ end

# `mode` of `Univariate` objects containing `NaN` in probs.
unf = UnivariateFinite([0.1, 0.2, NaN, 0.1, NaN], pool=missing)
@test_throws DomainError mode(unf)
@test_throws ERR_NAN_FOUND mode(unf)
end

@testset "Univariate modes, bimodal" begin
    # give classes 24 and 42 the same weight: twice the largest weight drawn
    labels = categorical(1:101)
    weights = rand(rng, 101)
    peak = 2*maximum(weights)
    weights[24] = peak
    weights[42] = peak
    weights = weights/sum(weights)
    d = UnivariateFinite(labels, weights)
    @test modes(d) == [24, 42]

    # `modes` of `UnivariateFinite` objects containing `NaN` in probs.
    unf = UnivariateFinite([0.1, 0.2, NaN, 0.1, NaN], pool=missing)
    @test_throws ERR_NAN_FOUND modes(unf)
end

@testset "UnivariateFinite methods" begin
Expand Down

0 comments on commit afcd606

Please sign in to comment.