diff --git a/Project.toml b/Project.toml index bd0a73f..340d926 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StatsPlots" uuid = "f3b207a7-027a-5e70-b257-86293d7955fd" -version = "0.14.22" +version = "0.14.23" [deps] Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" diff --git a/src/distributions.jl b/src/distributions.jl index f7547dc..6f6b2a7 100644 --- a/src/distributions.jl +++ b/src/distributions.jl @@ -7,32 +7,34 @@ function default_range(dist::Distribution, alpha = 0.0001) end function default_range(m::Distributions.MixtureModel, alpha = 0.0001) - minval = maxval = 0.0 - for c in m.components - thismin = isfinite(minimum(c)) ? minimum(c) : quantile(c, alpha) - thismax = isfinite(maximum(c)) ? maximum(c) : quantile(c, 1-alpha) - if thismin < minval - minval = thismin - end - if thismax > maxval - maxval = thismax - end - end - minval, maxval + mapreduce(c -> default_range(c, alpha), _minmax, m.components) end +_minmax((xmin, xmax), (ymin, ymax)) = (min(xmin, ymin), max(xmax, ymax)) + yz_args(dist) = default_range(dist) -yz_args(dist::Distribution{N, T}) where N where T<:Discrete = (UnitRange(Int.(default_range(dist))...),) +function yz_args(dist::DiscreteUnivariateDistribution) + minval, maxval = extrema(dist) + if isfinite(minval) && isfinite(maxval) # bounded + sup = support(dist) + return sup isa AbstractVector ? (sup,) : ([sup...],) + else # unbounded + return (UnitRange(default_range(dist)...),) + end +end # this "user recipe" adds a default x vector based on the distribution's μ and σ @recipe function f(dist::Distribution) - if dist isa Distribution{Univariate,Discrete} - seriestype --> :scatterpath + if dist isa DiscreteUnivariateDistribution + seriestype --> :sticks end (dist, yz_args(dist)...) end @recipe function f(m::Distributions.MixtureModel; components = true) + if m isa DiscreteUnivariateDistribution + seriestype --> :sticks + end if components for c in m.components @series begin @@ -48,8 +50,8 @@ end for di in distvec @series begin seriesargs = isempty(yz) ? yz_args(di) : yz - if di isa Distribution{Univariate,Discrete} - seriestype --> :scatterpath + if di isa DiscreteUnivariateDistribution + seriestype --> :sticks end (di, seriesargs...) end diff --git a/test/runtests.jl b/test/runtests.jl index 67a7a4f..71eb71c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test using StableRNGs using NaNMath using Clustering +using Distributions @testset "Grouped histogram" begin rng = StableRNG(1337) @@ -57,3 +58,43 @@ end # testset 0.0 17.0 0.0 28.0 ] end + +@testset "Distributions" begin + @testset "univariate" begin + @testset "discrete" begin + pbern = plot(Bernoulli(0.25)) + @test pbern[1][1][:x][1:2] == zeros(2) + @test pbern[1][1][:x][4:5] == ones(2) + @test pbern[1][1][:y][[1, 4]] == zeros(2) + @test pbern[1][1][:y][[2, 5]] == [0.75, 0.25] + + pdirac = plot(Dirac(0.25)) + @test pdirac[1][1][:x][1:2] == [0.25, 0.25] + @test pdirac[1][1][:y][1:2] == [0, 1] + + ppois_unbounded = plot(Poisson(1)) + @test ppois_unbounded[1][1][:x] isa AbstractVector + @test ppois_unbounded[1][1][:x][1:2] == zeros(2) + @test ppois_unbounded[1][1][:x][4:5] == ones(2) + @test ppois_unbounded[1][1][:y][[1, 4]] == zeros(2) + @test ppois_unbounded[1][1][:y][[2, 5]] == pdf.(Poisson(1), ppois_unbounded[1][1][:x][[1, 4]]) + + pnonint = plot(Bernoulli(0.75) - 1//2) + @test pnonint[1][1][:x][1:2] == [-1//2, -1//2] + @test pnonint[1][1][:x][4:5] == [1//2, 1//2] + @test pnonint[1][1][:y][[1, 4]] == zeros(2) + @test pnonint[1][1][:y][[2, 5]] == [0.25, 0.75] + + pmix = plot(MixtureModel([Bernoulli(0.75), Bernoulli(0.5)], [0.5, 0.5]); components=false) + @test pmix[1][1][:x][1:2] == zeros(2) + @test pmix[1][1][:x][4:5] == ones(2) + @test pmix[1][1][:y][[1, 4]] == zeros(2) + @test pmix[1][1][:y][[2, 5]] == [0.375, 0.625] + + dzip = MixtureModel([Dirac(0), Poisson(1)], [0.1, 0.9]) + pzip = plot(dzip; components=false) + @test pzip[1][1][:x] isa AbstractVector + @test pzip[1][1][:y][2:3:end] == pdf.(dzip, Int.(pzip[1][1][:x][1:3:end])) + end + end +end