Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Positive Definate Woodbury #2

Merged
merged 7 commits into from
Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
name = "PDMatsExtras"
uuid = "2c7acb1b-7338-470f-b38f-951d2bcb9193"
authors = ["Invenia Technical Computing"]
version = "2.0.0"
version = "2.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"

[compat]
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
ChainRulesCore = "0.9.17"
Distributions = "0.23, 0.24"
FiniteDifferences = "0.11"
PDMats = "0.9, 0.10"
Zygote = "0.5.5"
julia = "1"

[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributions", "Random", "SuiteSparse", "Test"]
test = ["Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,37 @@ julia> PSDMat(cholesky(X, Val(true); check=false))
2.0 -1.0 1.0 1.0
4.0 -4.0 1.0 6.0
```

## WoodburyPDMat
It is a positive definite Woodbury matrix.
This is a special case of the Symmetric Woodbury Matrix (see [WoodburyMatrices.jl's](https://github.com/timholy/WoodburyMatrices.jl/) `SymWoodbury` type) which is given by `A*D*A' + S` for `S` and `D` being diagonal,
which has the additional requirement that the diagonal matrices are also non-negative.

```julia
julia> using LinearAlgebra, PDMatsExtras

julia> A = Float64[
2.0 2.0 -8.0 5.0 -1.0 2.0 6.0
2.0 7.0 -1.0 -5.0 -4.0 8.0 7.0
-2.0 9.0 -9.0 -5.0 9.0 -5.0 -3.0
3.0 4.0 -6.0 -4.0 3.0 -3.0 -3.0
];

julia> D = Diagonal(Float64[1, 2, 3, 2, 2, 1, 5]);

julia> S = Diagonal(Float64[4, 2, 3, 6]);

julia> W = WoodburyPDMat(A, D, S)
4×4 WoodburyPDMat{Float64,Array{Float64,2},Diagonal{Float64,Array{Float64,1}},Diagonal{Float64,Array{Float64,1}}}:
444.0 240.0 80.0 24.0
240.0 498.0 -18.0 -33.0
80.0 -18.0 694.0 382.0
24.0 -33.0 382.0 259.0

julia> A*D*A' + S
4×4 Array{Float64,2}:
444.0 240.0 80.0 24.0
240.0 498.0 -18.0 -33.0
80.0 -18.0 694.0 382.0
24.0 -33.0 382.0 259.0
```
5 changes: 3 additions & 2 deletions src/PDMatsExtras.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module PDMatsExtras

using ChainRulesCore
using LinearAlgebra
using PDMats
import Base: *, \

export PSDMat
export PSDMat, WoodburyPDMat

include("psd_mat.jl")
include("woodbury_pd_mat.jl")

end
86 changes: 86 additions & 0 deletions src/woodbury_pd_mat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
WoodburyPDMat(
A::AbstractMatrix{T}, D::Diagonal{T}, S::Diagonal{T},
) where {T<:Real}

Lazily represents matrices of the form
```julia
W = A * D * A' + S
```
`D` and `S` must have only non-negative entries.

Using this matrix type is a good idea if `size(A, 1) > size(A, 2)` as the structure in the
matrix can be exploited to accelerate operations involving `W`'s inverse [1], such as
`invquad`, and it's determinant [2], such as `logdet`.

You probably don't want to use this matrix type if `size(A, 1) < size(A, 2)`.

[1] - https://en.wikipedia.org/wiki/Woodbury_matrix_identity
[2] - https://en.wikipedia.org/wiki/Matrix_determinant_lemma
"""
struct WoodburyPDMat{
T<:Real, TA<:AbstractMatrix{T}, TD<:Diagonal{T}, TS<:Diagonal{T},
} <: AbstractPDMat{T}
A::TA
D::TD
S::TS
function WoodburyPDMat(
A::AbstractMatrix{T}, D::Diagonal{T}, S::Diagonal{T},
) where {T<:Real}
validate_woodbury_arguments(A, D, S)
return new{T, typeof(A), typeof(D), typeof(S)}(A, D, S)
end
end

function WoodburyPDMat(
A::AbstractMatrix{T}, D::AbstractVector{T}, S::AbstractVector{T},
) where {T<:Real}
return WoodburyPDMat(A, Diagonal(D), Diagonal(S))
end

PDMats.dim(W::WoodburyPDMat) = size(W.A, 1)

# Convesion method. Primarily useful for testing purposes.
Base.Matrix(W::WoodburyPDMat) = W.A * W.D * W.A' + W.S

Base.getindex(W::WoodburyPDMat, inds...) = getindex(Matrix(W), inds...)

function validate_woodbury_arguments(A, D, S)
if size(A, 1) != size(S, 1)
throw(ArgumentError("size(A, 1) != size(S, 1)"))
end
if size(A, 2) != size(D, 1)
throw(ArgumentError("size(A, 2) != size(D, 1)"))
end
if any(x -> x < 0, diag(D))
throw(ArgumentError("Detected negative element on diagonal of D: $(D)"))
end
if any(x -> x < 0, diag(S))
throw(ArgumentError("Detected negative element on diagonal of S: $(S)"))
end
end

@non_differentiable validate_woodbury_arguments(A, D, S)

function LinearAlgebra.logdet(W::WoodburyPDMat)
C_S = cholesky(W.S)
B = C_S.U' \ (W.A * cholesky(W.D).U')
return logdet(C_S) + logdet(cholesky(Symmetric(I + B'B)))
end

# Utilises the matrix inversion lemma to produce an efficient implementation.
function PDMats.invquad(W::WoodburyPDMat{<:Real}, x::AbstractVector{<:Real})
C_S = cholesky(W.S)
B = C_S.U' \ (W.A * cholesky(W.D).U')
α = C_S.U' \ x
β = B' * α
return α'α - sum(abs2, cholesky(Symmetric(I + B'B)).U' \ β)
end

# This doesn't get us the computational wins, but it's unclear how to construct a
# root for a Woodbury matrix. Consequently, if performance is very important when sampling,
# it's necessary to implement a method of `rand` or `_rand` that explicitly uses ancestral
# sampling to exploit the approximately low-rank structre in a WoodburyPDMat.
function PDMats.unwhiten!(r::DenseVecOrMat, W::WoodburyPDMat{<:Real}, x::DenseVecOrMat)
return unwhiten!(r, PDMat(Symmetric(Matrix(W))), x)
end
15 changes: 10 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
using PDMatsExtras
using Test

using LinearAlgebra
using Random

using ChainRulesCore
using Distributions
using FiniteDifferences
using LinearAlgebra
using PDMats
using Random
using Test
using Zygote

Random.seed!(1)
@testset "PDMatsExtras.jl" begin
include("testutils.jl")
include("test_ad.jl")

include("psd_mat.jl")
include("woodbury_pd_mat.jl")
end
15 changes: 15 additions & 0 deletions test/test_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function test_ad(test_function, Δoutput, inputs...; atol=1e-6, rtol=1e-6)

# Verify that the forwards-pass produces the correct answer.
output, pb = Zygote.pullback(test_function, inputs...)
@test output ≈ test_function(inputs...)

# Compute the adjoints using AD and FiniteDifferences.
dW_ad = pb(Δoutput)
dW_fd = FiniteDifferences.j′vp(central_fdm(5, 1), test_function, Δoutput, inputs...)

# Compare AD and FiniteDifferences results.
@testset "$(typeof(test_function)) argument $n" for n in eachindex(inputs)
@test dW_ad[n] ≈ dW_fd[n] atol=atol rtol=rtol
end
end
65 changes: 65 additions & 0 deletions test/woodbury_pd_mat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
@testset "woodbury_pd_mat" begin
A = randn(4, 2)
D = Diagonal(randn(2).^2 .+ 1)
S = Diagonal(randn(4).^2 .+ 1)
x = randn(size(A, 1))

W = WoodburyPDMat(A, D, S)
W_dense = PDMat(Symmetric(Matrix(W)))

@testset "invalid constructors error" begin
@test_throws ArgumentError WoodburyPDMat(randn(5, 2), D, S)
@test_throws ArgumentError WoodburyPDMat(randn(4, 3), D, S)
@test_throws ArgumentError WoodburyPDMat(A, Diagonal(.-randn(2).^2), S)
@test_throws ArgumentError WoodburyPDMat(A, D, Diagonal(.-randn(4).^2))
end

@testset "Basic functionality" begin
# Checks getindex works.
@test all(isapprox.(W, W_dense))
end

@testset "unwhiten!" begin
@test PDMats.unwhiten!(similar(x), W, x) ≈ PDMats.unwhiten!(similar(x), W_dense, x)
end

@testset "logdet" begin
@test logdet(W) ≈ logdet(W_dense)
test_ad(randn(), A, D, S) do A, D, S
logdet(WoodburyPDMat(A, D, S))
end
end

@testset "invquad" begin
@test invquad(W, x) ≈ invquad(W_dense, x)

test_ad(randn(), A, D, S, x) do A, D, S, x
W = WoodburyPDMat(A, D, S)
return invquad(W, x)
end
end

@testset "MvNormal logpdf" begin
m = randn(size(A, 1))
@test logpdf(MvNormal(m, W), x) ≈ logpdf(MvNormal(m, Symmetric(Matrix(W))), x)

test_ad(randn(), m, A, D, S, x) do m, A, D, S, x
W = WoodburyPDMat(A, D, S)
return logpdf(MvNormal(m, W), x)
end
end

@testset "GenericMvTDist logpdf" begin
α = 2.1
m = randn(size(A, 1))
@test isapprox(
logpdf(Distributions.GenericMvTDist(α, m, W), x),
logpdf(Distributions.GenericMvTDist(α, m, W_dense), x),
)

test_ad(randn(), α, m, A, D, S, x) do α, m, A, D, S, x
W = WoodburyPDMat(A, D, S)
return logpdf(Distributions.GenericMvTDist(α, m, W), x)
end
end
rofinn marked this conversation as resolved.
Show resolved Hide resolved
end