Skip to content

Commit 9be3ba2

Browse files
authored
kron for RectDiagonal fill (#272)
* kron for RectDiagonal fill * specialize sparse for Diagonal Fill
1 parent e3c8eb9 commit 9be3ba2

File tree

4 files changed

+45
-4
lines changed

4 files changed

+45
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "1.3.0"
3+
version = "1.4.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

+15-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
66
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
77
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
88
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
9-
show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat
9+
show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat,
10+
parent
1011

1112
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
1213
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
@@ -369,6 +370,8 @@ axes(T::UpperOrLowerTriangular{<:Any,<:AbstractFill}) = axes(parent(T))
369370
axes(rd::RectDiagonal) = rd.axes
370371
size(rd::RectDiagonal) = map(length, rd.axes)
371372

373+
parent(rd::RectDiagonal) = rd.diag
374+
372375
@inline function getindex(rd::RectDiagonal{T}, i::Integer, j::Integer) where T
373376
@boundscheck checkbounds(rd, i, j)
374377
if i == j
@@ -411,7 +414,8 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac
411414

412415

413416
const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}}
414-
const RectDiagonalEye{T} = RectDiagonal{T,<:Ones{T,1}}
417+
const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes}
418+
const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V}
415419
const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}}
416420
const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}}
417421

@@ -537,6 +541,15 @@ convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} =
537541
convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} =
538542
convert(SparseMatrixCSC{Tv,Ti}, Z)
539543

544+
function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv}
545+
SparseMatrixCSC{Tv,eltype(axes(R,1))}(R)
546+
end
547+
function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti}
548+
Base.require_one_based_indexing(R)
549+
v = parent(R)
550+
J = getindex_value(v)*I
551+
SparseMatrixCSC{Tv,Ti}(J, size(R))
552+
end
540553

541554
#########
542555
# maximum/minimum

src/fillalgebra.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,4 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat)
453453
sz = _kronsize(f, g)
454454
_kron(f, g, sz)
455455
end
456-
kron(E1::RectDiagonalEye, E2::RectDiagonalEye) = kron(sparse(E1), sparse(E2))
456+
kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2))

test/runtests.jl

+28
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,28 @@ end
514514
convert(AbstractSparseMatrix{Float64,Int},Mat) ==
515515
SMat
516516
end
517+
518+
function testsparsediag(E)
519+
S = @inferred SparseMatrixCSC(E)
520+
@test S == E
521+
S = @inferred SparseMatrixCSC{Float64}(E)
522+
@test S == E
523+
@test S isa SparseMatrixCSC{Float64}
524+
@test convert(SparseMatrixCSC{Float64}, E) == S
525+
S = @inferred SparseMatrixCSC{Float64,Int32}(E)
526+
@test S == E
527+
@test S isa SparseMatrixCSC{Float64,Int32}
528+
@test convert(SparseMatrixCSC{Float64,Int32}, E) == S
529+
end
530+
531+
for f in (Fill(Int8(4),3), Ones{Int8}(3), Zeros{Int8}(3))
532+
E = Diagonal(f)
533+
testsparsediag(E)
534+
for sz in ((3,6), (6,3), (3,3))
535+
E = RectDiagonal(f, sz)
536+
testsparsediag(E)
537+
end
538+
end
517539
end
518540

519541
@testset "==" begin
@@ -1534,6 +1556,12 @@ end
15341556
C = collect(E)
15351557
@test K == kron(C, C)
15361558
@test issparse(kron(E,E))
1559+
1560+
E = RectDiagonal(Fill(4,3), (6,3))
1561+
C = collect(E)
1562+
K = kron(E, E)
1563+
@test K == kron(C, C)
1564+
@test issparse(K)
15371565
end
15381566

15391567
@testset "dot products" begin

0 commit comments

Comments
 (0)