Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement fallbacks for (un)whiten(!) and (inv)quad(!) to AbstractArray #162

Merged
merged 11 commits into from
May 31, 2022
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "PDMats"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.10"
version = "0.11.11"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,18 @@ unwhiten!(a, x) # un-whitening transform inplace, updating `x`.
unwhiten!(r, a, x) # write the transformed result to `r`.
```

### Fallbacks for `AbstractArray`s
For ease of composability, some of these functions have generic fallbacks defined that work on `AbstractArray`s.
These fallbacks may not be as fast as the methods specializaed for `AbstractPDMat`s, but they let you more easily swap out types.
While in theory all of them can be defined, at present only the following subset has:

- `dim`
- `whiten`, `whiten!`
- `unwhiten`, `unwhiten!`
- `quad`, `quad!`
- `invquad`, `invquad!`

PRs to implement more generic fallbacks are welcome.

## Define Customized Subtypes

Expand Down
61 changes: 46 additions & 15 deletions src/generics.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Generic functions (on top of the type-specific implementations)

## Basic functions

Base.size(a::AbstractPDMat) = (dim(a), dim(a))
Base.size(a::AbstractPDMat, i::Integer) = 1 <= i <= 2 ? dim(a) : 1
Base.length(a::AbstractPDMat) = abs2(dim(a))

function dim(a::AbstractMatrix)
@check_argdims size(a, 1) == size(a, 2)
return size(a, 1)
end

## arithmetics

pdadd!(r::Matrix, a::Matrix, b::AbstractPDMat{T}) where {T<:Real} = pdadd!(r, a, b, one(T))
Expand All @@ -29,16 +33,25 @@ LinearAlgebra.isposdef(::AbstractPDMat) = true
LinearAlgebra.ishermitian(::AbstractPDMat) = true

## whiten and unwhiten
whiten!(a::AbstractPDMat, x::StridedVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractPDMat, x::StridedVecOrMat) = unwhiten!(x, a, x)

whiten!(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(x, a, x)

function whiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
ldiv!(chol_lower(cholesky(a)), v)
end

function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
lmul!(chol_lower(cholesky(a)), v)
end

"""
whiten(a::AbstractPDMat, x::StridedVecOrMat)
whiten!(a::AbstractPDMat, x::StridedVecOrMat)
whiten!(r::StridedVecOrMat, a::AbstractPDMat, x::StridedVecOrMat)
unwhiten(a::AbstractPDMat, x::StridedVecOrMat)
unwhiten!(a::AbstractPDMat, x::StridedVecOrMat)
unwhiten!(r::StridedVecOrMat, a::AbstractPDMat, x::StridedVecOrMat)
whiten(a::AbstractMatrix, x::AbstractVecOrMat)
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat)
unwhiten!(a::AbstractMatrix, x::AbstractVecOrMat)
unwhiten!(r::AbstractVecOrMat, a::AbstractPDMat, x::AbstractVecOrMat)

Allocating and in-place versions of the `whiten`ing transform (or its inverse) defined by `a` applied to `x`

Expand Down Expand Up @@ -68,28 +81,38 @@ julia> W * W'
0.0 1.0
```
"""
whiten(a::AbstractPDMat, x::StridedVecOrMat) = whiten!(similar(x), a, x)
unwhiten(a::AbstractPDMat, x::StridedVecOrMat) = unwhiten!(similar(x), a, x)
whiten(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(similar(x), a, x)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(similar(x), a, x)


## quad

"""
quad(a::AbstractPDMat, x::StridedVecOrMat)
quad(a::AbstractMatrix, x::AbstractVecOrMat)

Return the value of the quadratic form defined by `a` applied to `x`

If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
function quad(a::AbstractPDMat{T}, x::StridedMatrix{S}) where {T<:Real, S<:Real}
function quad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims dim(a) == size(x, 1)
quad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
end

quad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_upper(cholesky(a)) * x)
invquad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_lower(cholesky(a)) \ x)

"""
invquad(a::AbstractPDMat, x::StridedVecOrMat)
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)

Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`
"""
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a * x)


"""
invquad(a::AbstractMatrix, x::AbstractVecOrMat)

Return the value of the quadratic form defined by `inv(a)` applied to `x`.

Expand All @@ -98,7 +121,15 @@ For most `PDMat` types this is done in a way that does not require evaluation of
If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
function invquad(a::AbstractPDMat{T}, x::StridedMatrix{S}) where {T<:Real, S<:Real}
invquad(a::AbstractMatrix, x::AbstractVecOrMat) = x' / a * x
function invquad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims dim(a) == size(x, 1)
invquad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
end

"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)

Overwrite `r` with the value of the quadratic form defined by `inv(a)` applied columnwise to `x`
"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a \ x)
33 changes: 0 additions & 33 deletions src/pdmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,39 +60,6 @@ LinearAlgebra.eigmin(a::PDMat) = eigmin(a.mat)
Base.kron(A::PDMat, B::PDMat) = PDMat(kron(A.mat, B.mat), Cholesky(kron(A.chol.U, B.chol.U), 'U', A.chol.info))
LinearAlgebra.sqrt(A::PDMat) = PDMat(sqrt(Hermitian(A.mat)))

### whiten and unwhiten
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

function whiten!(r::StridedVecOrMat, a::PDMat, x::StridedVecOrMat)
v = _rcopy!(r, x)
ldiv!(chol_lower(a.chol), v)
end

function unwhiten!(r::StridedVecOrMat, a::PDMat, x::StridedVecOrMat)
v = _rcopy!(r, x)
lmul!(chol_lower(a.chol), v)
end


### quadratic forms

quad(a::PDMat, x::AbstractVector) = sum(abs2, chol_upper(a.chol) * x)
invquad(a::PDMat, x::AbstractVector) = sum(abs2, chol_lower(a.chol) \ x)

"""
quad!(r::AbstractArray, a::AbstractPDMat, x::StridedMatrix)

Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`
"""
quad!(r::AbstractArray, a::PDMat, x::StridedMatrix) = colwise_dot!(r, x, a.mat * x)

"""
invquad!(r::AbstractArray, a::AbstractPDMat, x::StridedMatrix)

Overwrite `r` with the value of the quadratic form defined by `inv(a)` applied columnwise to `x`
"""
invquad!(r::AbstractArray, a::PDMat, x::StridedMatrix) = colwise_dot!(r, x, a.mat \ x)


### tri products

function X_A_Xt(a::PDMat, x::AbstractMatrix)
Expand Down
9 changes: 4 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ macro check_argdims(cond)
end
end

_rcopy!(r::StridedVecOrMat, x::StridedVecOrMat) = (r === x || copyto!(r, x); r)
_rcopy!(r, x) = (r === x || copyto!(r, x); r)


function _addscal!(r::Matrix, a::Matrix, b::Union{Matrix, SparseMatrixCSC}, c::Real)
Expand Down Expand Up @@ -69,11 +69,10 @@ function invwsumsq(w::AbstractVector, a::AbstractVector)
end

function colwise_dot!(r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix)
n = length(r)
@check_argdims n == size(a, 2) == size(b, 2) && size(a, 1) == size(b, 1)
for j = 1:n
@check_argdims(axes(a) == axes(b))
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
for j in axes(a, 2)
v = zero(promote_type(eltype(a), eltype(b)))
@simd for i = 1:size(a, 1)
@simd for i in axes(a, 1)
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
@inbounds v += a[i, j]*b[i, j]
end
r[j] = v
devmotion marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
23 changes: 23 additions & 0 deletions test/abstracttypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Test, PDMats

@testset "AbstractMatrix fallback functionality" begin
C = Cmat = [4. -2. -1.; -2. 5. -1.; -1. -1. 6.]

test_pdmat(C, Cmat;
verbose=2, # the level to display intermediate steps
cmat_eq=true, # require Cmat and Matrix(C) to be exactly equal
t_diag=false, # whether to test diag method
t_cholesky=false, # whether to test cholesky method
t_scale=false, # whether to test scaling
t_add=false, # whether to test pdadd
t_det=false, # whether to test det method
t_logdet=false, # whether to test logdet method
t_eig=false, # whether to test eigmax and eigmin
t_mul=false, # whether to test multiplication
t_div=false, # whether to test division
t_quad=true, # whether to test quad & invquad
t_triprod=false, # whether to test X_A_Xt, Xt_A_X, X_invA_Xt, and Xt_invA_X
t_whiten=true # whether to test whiten and unwhiten
)

end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
include("testutils.jl")
tests = ["pdmtypes", "addition", "generics", "kron", "chol", "specialarrays", "sqrt"]
tests = ["pdmtypes", "abstracttypes", "addition", "generics", "kron", "chol", "specialarrays", "sqrt"]
println("Running tests ...")

for t in tests
Expand Down
30 changes: 15 additions & 15 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const HAVE_CHOLMOD = isdefined(SuiteSparse, :CHOLMOD)
const PDMatType = HAVE_CHOLMOD ? Union{PDMat, PDSparseMat, PDiagMat} : Union{PDMat, PDiagMat}

## driver function
function test_pdmat(C::AbstractPDMat, Cmat::Matrix;
function test_pdmat(C, Cmat::Matrix;
verbose::Int=2, # the level to display intermediate steps
cmat_eq::Bool=false, # require Cmat and Matrix(C) to be exactly equal
t_diag::Bool=true, # whether to test diag method
Expand Down Expand Up @@ -62,7 +62,7 @@ end
_pdt(vb::Int, s) = (vb >= 2 && printstyled(" .. testing $s\n", color=:green))


function pdtest_basics(C::AbstractPDMat, Cmat::Matrix, d::Int, verbose::Int)
function pdtest_basics(C, Cmat::Matrix, d::Int, verbose::Int)
_pdt(verbose, "dim")
@test dim(C) == d

Expand Down Expand Up @@ -94,7 +94,7 @@ function pdtest_basics(C::AbstractPDMat, Cmat::Matrix, d::Int, verbose::Int)
end


function pdtest_cmat(C::AbstractPDMat, Cmat::Matrix, cmat_eq::Bool, verbose::Int)
function pdtest_cmat(C, Cmat::Matrix, cmat_eq::Bool, verbose::Int)
_pdt(verbose, "full")
if cmat_eq
@test Matrix(C) == Cmat
Expand All @@ -104,7 +104,7 @@ function pdtest_cmat(C::AbstractPDMat, Cmat::Matrix, cmat_eq::Bool, verbose::Int
end


function pdtest_diag(C::AbstractPDMat, Cmat::Matrix, cmat_eq::Bool, verbose::Int)
function pdtest_diag(C, Cmat::Matrix, cmat_eq::Bool, verbose::Int)
_pdt(verbose, "diag")
if cmat_eq
@test diag(C) == diag(Cmat)
Expand Down Expand Up @@ -133,14 +133,14 @@ if HAVE_CHOLMOD
end
end

function pdtest_scale(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
function pdtest_scale(C, Cmat::Matrix, verbose::Int)
_pdt(verbose, "scale")
@test Matrix(C * convert(eltype(C),2)) ≈ Cmat * convert(eltype(C),2)
@test Matrix(convert(eltype(C),2) * C) ≈ convert(eltype(C),2) * Cmat
end


function pdtest_add(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
function pdtest_add(C, Cmat::Matrix, verbose::Int)
M = rand(eltype(C),size(Cmat))
_pdt(verbose, "add")
@test C + M ≈ Cmat + M
Expand All @@ -156,7 +156,7 @@ function pdtest_add(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
@test Mr ≈ R
end

function pdtest_det(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
function pdtest_det(C, Cmat::Matrix, verbose::Int)
_pdt(verbose, "det")
@test det(C) ≈ det(Cmat)

Expand All @@ -166,7 +166,7 @@ function pdtest_det(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
end
end

function pdtest_logdet(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
function pdtest_logdet(C, Cmat::Matrix, verbose::Int)
_pdt(verbose, "logdet")
@test logdet(C) ≈ logdet(Cmat)

Expand All @@ -177,7 +177,7 @@ function pdtest_logdet(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
end


function pdtest_eig(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
function pdtest_eig(C, Cmat::Matrix, verbose::Int)
_pdt(verbose, "eigmax")
@test eigmax(C) ≈ eigmax(Cmat)

Expand All @@ -186,14 +186,14 @@ function pdtest_eig(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
end


function pdtest_mul(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
function pdtest_mul(C, Cmat::Matrix, verbose::Int)
n = 5
X = rand(eltype(C), dim(C), n)
pdtest_mul(C, Cmat, X, verbose)
end


function pdtest_mul(C::AbstractPDMat, Cmat::Matrix, X::Matrix, verbose::Int)
function pdtest_mul(C, Cmat::Matrix, X::Matrix, verbose::Int)
_pdt(verbose, "multiply")
d, n = size(X)
@assert d == dim(C)
Expand All @@ -217,7 +217,7 @@ function pdtest_mul(C::AbstractPDMat, Cmat::Matrix, X::Matrix, verbose::Int)
end


function pdtest_div(C::AbstractPDMat, Imat::Matrix, X::Matrix, verbose::Int)
function pdtest_div(C, Imat::Matrix, X::Matrix, verbose::Int)
_pdt(verbose, "divide")
d, n = size(X)
@assert d == dim(C)
Expand Down Expand Up @@ -246,7 +246,7 @@ function pdtest_div(C::AbstractPDMat, Imat::Matrix, X::Matrix, verbose::Int)
end


function pdtest_quad(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix, verbose::Int)
function pdtest_quad(C, Cmat::Matrix, Imat::Matrix, X::Matrix, verbose::Int)
n = size(X, 2)

_pdt(verbose, "quad")
Expand All @@ -271,7 +271,7 @@ function pdtest_quad(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix, ve
end


function pdtest_triprod(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix, verbose::Int)
function pdtest_triprod(C, Cmat::Matrix, Imat::Matrix, X::Matrix, verbose::Int)
d, n = size(X)
@assert d == dim(C)
Xt = copy(transpose(X))
Expand All @@ -298,7 +298,7 @@ function pdtest_triprod(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix,
end


function pdtest_whiten(C::AbstractPDMat, Cmat::Matrix, verbose::Int)
function pdtest_whiten(C, Cmat::Matrix, verbose::Int)
Y = PDMats.chol_lower(Cmat)
Q = qr(convert(Array{eltype(C),2},randn(size(Cmat)))).Q
Y = Y * Q' # generate a matrix Y such that Y * Y' = C
Expand Down