Skip to content

Commit

Permalink
remove SpDiagIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre committed Aug 17, 2017
1 parent 7af9ab2 commit 00905db
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 34 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ Deprecated or removed
* `Base.cpad` has been removed; use an appropriate combination of `rpad` and `lpad`
instead ([#23187]).

* `Base.SparseArrays.SpDiagIterator` has been removed ([#23261]).

Command-line option changes
---------------------------

Expand Down
55 changes: 22 additions & 33 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3380,51 +3380,40 @@ function expandptr(V::Vector{<:Integer})
res
end

## diag and related using an iterator

struct SpDiagIterator{Tv,Ti}
A::SparseMatrixCSC{Tv,Ti}
d::Int # diagonal to iterate over
end

start(it::SpDiagIterator) = it.d < 0 ? (1-it.d, 1) : (1, it.d+1)
done(it::SpDiagIterator, rc) = rc[1] > size(it.A,1) || rc[2] > size(it.A,2)
function next(it::SpDiagIterator{Tv}, rc) where Tv
r, c = rc
A = it.A
r1 = Int(A.colptr[c])
r2 = Int(A.colptr[c+1]-1)
(r1 > r2) && (return (zero(Tv), (r+1, c+1)))
r1 = searchsortedfirst(A.rowval, r, r1, r2, Forward)
(((r1 > r2) || (A.rowval[r1] != r)) ? zero(Tv) : A.nzval[r1], (r+1, c+1))
end

function diag(A::SparseMatrixCSC{Tv,Ti}, d::Int=0) where {Tv,Ti}
m, n = size(A)
if !(-m <= d <= n)
throw(ArgumentError("requested diagonal, $d, out of bounds in matrix of size ($m, $n)"))
end
l = d < 0 ? min(m+d,n) : min(n-d,m)
ind = Vector{Ti}(); sizehint!(ind, min(l, nnz(A)))
val = Vector{Tv}(); sizehint!(val, min(l, nnz(A)))
for (i, v) in enumerate(SpDiagIterator(A, d))
if !iszero(v)
push!(ind, i)
push!(val, v)
end
if d <= 0
rrange = (1-d):min(m, min(m,n)-d)
crange = 1:min(n, m+d)
else # d > 0
rrange = 1:min(m, n-d)
crange = (1+d):min(n, min(m,n)+d)
end
ind = Vector{Ti}()
val = Vector{Tv}()
for (i, (r, c)) in enumerate(zip(rrange, crange))
r1 = Int(A.colptr[c])
r2 = Int(A.colptr[c+1]-1)
r1 > r2 && continue
r1 = searchsortedfirst(A.rowval, r, r1, r2, Forward)
((r1 > r2) || (A.rowval[r1] != r)) && continue
push!(ind, i)
push!(val, A.nzval[r1])
end
return SparseVector{Tv,Ti}(l, ind, val)
return SparseVector{Tv,Ti}(length(rrange), ind, val)
end

function trace(A::SparseMatrixCSC{Tv}) where Tv
if size(A,1) != size(A,2)
throw(DimensionMismatch("expected square matrix"))
end
n = checksquare(A)
s = zero(Tv)
for d in SpDiagIterator(A,0)
s += d
for i in 1:n
s += A[i,i]
end
s
return s
end

function diagm(v::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
Expand Down
2 changes: 1 addition & 1 deletion test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ end
for S in (S1, S2, S3)
A = Matrix(S)
@test diag(S)::SparseVector{T,Int} == diag(A)
for k in -5:5
for k in -size(S,1):size(S,2)
@test diag(S, k)::SparseVector{T,Int} == diag(A, k)
end
@test_throws ArgumentError diag(S, -size(S,1)-1)
Expand Down

0 comments on commit 00905db

Please sign in to comment.