Skip to content

Commit

Permalink
Improvements to plotting PMFs for discrete distributions (#451)
Browse files Browse the repository at this point in the history
* Support distributions with non-integer support

* Use support unless unbounded

* Test discrete plots

* Increment version number

* Update src/distributions.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/distributions.jl

Co-authored-by: David Widmann <[email protected]>

* Default to hairplot for discrete dists

* Only add markers if unstyled by user

* Revert "Update src/distributions.jl"

This reverts commit 0ba4549.

* Remove whitespace

* Plot discrete mixtures with sticks

* Promote range bounds

Necessary to call UnitRange on them

* Only promote to float if necessary

* Test zero-inflated Poisson

* Test values of Poisson

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Revert "Promote range bounds"

This reverts commit 0f428de.

* Use mapreduce instead of explicitly looping

* Don't default to showing markers

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
sethaxen and devmotion authored Jun 23, 2021
1 parent 6e65c64 commit fadcdf7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StatsPlots"
uuid = "f3b207a7-027a-5e70-b257-86293d7955fd"
version = "0.14.22"
version = "0.14.23"

[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Expand Down
36 changes: 19 additions & 17 deletions src/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,34 @@ function default_range(dist::Distribution, alpha = 0.0001)
end

function default_range(m::Distributions.MixtureModel, alpha = 0.0001)
minval = maxval = 0.0
for c in m.components
thismin = isfinite(minimum(c)) ? minimum(c) : quantile(c, alpha)
thismax = isfinite(maximum(c)) ? maximum(c) : quantile(c, 1-alpha)
if thismin < minval
minval = thismin
end
if thismax > maxval
maxval = thismax
end
end
minval, maxval
mapreduce(c -> default_range(c, alpha), _minmax, m.components)
end

_minmax((xmin, xmax), (ymin, ymax)) = (min(xmin, ymin), max(xmax, ymax))

yz_args(dist) = default_range(dist)
yz_args(dist::Distribution{N, T}) where N where T<:Discrete = (UnitRange(Int.(default_range(dist))...),)
function yz_args(dist::DiscreteUnivariateDistribution)
minval, maxval = extrema(dist)
if isfinite(minval) && isfinite(maxval) # bounded
sup = support(dist)
return sup isa AbstractVector ? (sup,) : ([sup...],)
else # unbounded
return (UnitRange(default_range(dist)...),)
end
end

# this "user recipe" adds a default x vector based on the distribution's μ and σ
@recipe function f(dist::Distribution)
if dist isa Distribution{Univariate,Discrete}
seriestype --> :scatterpath
if dist isa DiscreteUnivariateDistribution
seriestype --> :sticks
end
(dist, yz_args(dist)...)
end

@recipe function f(m::Distributions.MixtureModel; components = true)
if m isa DiscreteUnivariateDistribution
seriestype --> :sticks
end
if components
for c in m.components
@series begin
Expand All @@ -48,8 +50,8 @@ end
for di in distvec
@series begin
seriesargs = isempty(yz) ? yz_args(di) : yz
if di isa Distribution{Univariate,Discrete}
seriestype --> :scatterpath
if di isa DiscreteUnivariateDistribution
seriestype --> :sticks
end
(di, seriesargs...)
end
Expand Down
41 changes: 41 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test
using StableRNGs
using NaNMath
using Clustering
using Distributions

@testset "Grouped histogram" begin
rng = StableRNG(1337)
Expand Down Expand Up @@ -57,3 +58,43 @@ end # testset
0.0 17.0 0.0 28.0
]
end

@testset "Distributions" begin
@testset "univariate" begin
@testset "discrete" begin
pbern = plot(Bernoulli(0.25))
@test pbern[1][1][:x][1:2] == zeros(2)
@test pbern[1][1][:x][4:5] == ones(2)
@test pbern[1][1][:y][[1, 4]] == zeros(2)
@test pbern[1][1][:y][[2, 5]] == [0.75, 0.25]

pdirac = plot(Dirac(0.25))
@test pdirac[1][1][:x][1:2] == [0.25, 0.25]
@test pdirac[1][1][:y][1:2] == [0, 1]

ppois_unbounded = plot(Poisson(1))
@test ppois_unbounded[1][1][:x] isa AbstractVector
@test ppois_unbounded[1][1][:x][1:2] == zeros(2)
@test ppois_unbounded[1][1][:x][4:5] == ones(2)
@test ppois_unbounded[1][1][:y][[1, 4]] == zeros(2)
@test ppois_unbounded[1][1][:y][[2, 5]] == pdf.(Poisson(1), ppois_unbounded[1][1][:x][[1, 4]])

pnonint = plot(Bernoulli(0.75) - 1//2)
@test pnonint[1][1][:x][1:2] == [-1//2, -1//2]
@test pnonint[1][1][:x][4:5] == [1//2, 1//2]
@test pnonint[1][1][:y][[1, 4]] == zeros(2)
@test pnonint[1][1][:y][[2, 5]] == [0.25, 0.75]

pmix = plot(MixtureModel([Bernoulli(0.75), Bernoulli(0.5)], [0.5, 0.5]); components=false)
@test pmix[1][1][:x][1:2] == zeros(2)
@test pmix[1][1][:x][4:5] == ones(2)
@test pmix[1][1][:y][[1, 4]] == zeros(2)
@test pmix[1][1][:y][[2, 5]] == [0.375, 0.625]

dzip = MixtureModel([Dirac(0), Poisson(1)], [0.1, 0.9])
pzip = plot(dzip; components=false)
@test pzip[1][1][:x] isa AbstractVector
@test pzip[1][1][:y][2:3:end] == pdf.(dzip, Int.(pzip[1][1][:x][1:3:end]))
end
end
end

4 comments on commit fadcdf7

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BeastyBlacksmith can you register a new release here?

@BeastyBlacksmith
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do. Just wanted to be sure, you were fine with merging

@BeastyBlacksmith
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/39474

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.23 -m "<description of version>" fadcdf7f2d9c91798b9a789b68835fa7896a2183
git push origin v0.14.23

Please sign in to comment.