diff --git a/src/LinearAlgebra.jl b/src/LinearAlgebra.jl index 7be8c25a..d4ab47f3 100644 --- a/src/LinearAlgebra.jl +++ b/src/LinearAlgebra.jl @@ -180,7 +180,8 @@ public AbstractTriangular, symmetric_type, zeroslike, matprod_dest, - fillstored! + fillstored!, + fillband! const BlasFloat = Union{Float64,Float32,ComplexF64,ComplexF32} const BlasReal = Union{Float64,Float32} diff --git a/src/bidiag.jl b/src/bidiag.jl index 2c0714e7..3a910361 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -1570,3 +1570,30 @@ function Base._sum(A::Bidiagonal, dims::Integer) end res end + +function fillband!(B::Bidiagonal, x, l, u) + if l > u + return B + end + if ((B.uplo == 'U' && (l < 0 || u > 1)) || + (B.uplo == 'L' && (l < -1 || u > 0))) && !iszero(x) + throw_fillband_error(l, u, x) + else + if B.uplo == 'U' + if l <= 1 <= u + fill!(B.ev, x) + end + if l <= 0 <= u + fill!(B.dv, x) + end + else # B.uplo == 'L' + if l <= 0 <= u + fill!(B.dv, x) + end + if l <= -1 <= u + fill!(B.ev, x) + end + end + end + return B +end diff --git a/src/dense.jl b/src/dense.jl index 9e7eff4f..5fc9e61f 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -205,6 +205,23 @@ tril(M::Matrix, k::Integer) = tril!(copy(M), k) fillband!(A::AbstractMatrix, x, l, u) Fill the band between diagonals `l` and `u` with the value `x`. + +# Examples +```jldoctest +julia> A = zeros(4,4) +4×4 Matrix{Float64}: + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + +julia> LinearAlgebra.fillband!(A, 2, 0, 1) +4×4 Matrix{Float64}: + 2.0 2.0 0.0 0.0 + 0.0 2.0 2.0 0.0 + 0.0 0.0 2.0 2.0 + 0.0 0.0 0.0 2.0 +``` """ function fillband!(A::AbstractMatrix{T}, x, l, u) where T require_one_based_indexing(A) diff --git a/src/diagonal.jl b/src/diagonal.jl index 3f72c16c..76572498 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -1219,3 +1219,18 @@ end uppertriangular(D::Diagonal) = D lowertriangular(D::Diagonal) = D + +throw_fillband_error(l, u, x) = throw(ArgumentError(lazy"cannot set bands $l:$u to a nonzero value ($x)")) + +function fillband!(D::Diagonal, x, l, u) + if l > u + return D + end + if (l < 0 || u > 0) && !iszero(x) + throw_fillband_error(l, u, x) + end + if l <= 0 <= u + fill!(D.diag, x) + end + return D +end diff --git a/src/hessenberg.jl b/src/hessenberg.jl index 297ea737..89c609e0 100644 --- a/src/hessenberg.jl +++ b/src/hessenberg.jl @@ -124,6 +124,17 @@ lmul!(x::Number, H::UpperHessenberg) = (lmul!(x, H.data); H) fillstored!(H::UpperHessenberg, x) = (fillband!(H.data, x, -1, size(H,2)-1); H) +function fillband!(H::UpperHessenberg, x, l, u) + if l > u + return H + end + if l < -1 && !iszero(x) + throw_fillband_error(l, u, x) + end + fillband!(H.data, x, l, u) + return H +end + +(A::UpperHessenberg, B::UpperHessenberg) = UpperHessenberg(A.data+B.data) -(A::UpperHessenberg, B::UpperHessenberg) = UpperHessenberg(A.data-B.data) diff --git a/src/triangular.jl b/src/triangular.jl index 5b476d24..a82b3dc1 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -930,6 +930,36 @@ fillstored!(A::UnitLowerTriangular, x) = (fillband!(A.data, x, 1-size(A,1), -1); fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1); A) fillstored!(A::UnitUpperTriangular, x) = (fillband!(A.data, x, 1, size(A,2)-1); A) +function fillband!(A::LowerOrUnitLowerTriangular, x, l, u) + if l > u + return A + end + if u > 0 && !iszero(x) + throw_fillband_error(l, u, x) + end + isunit = A isa UnitLowerTriangular + if isunit && u >= 0 && x != oneunit(x) + throw(ArgumentError(lazy"cannot set the diagonal band to a non-unit value ($x)")) + end + fillband!(A.data, x, l, min(u, -isunit)) + return A +end + +function fillband!(A::UpperOrUnitUpperTriangular, x, l, u) + if l > u + return A + end + if l < 0 && !iszero(x) + throw_fillband_error(l, u, x) + end + isunit = A isa UnitUpperTriangular + if isunit && l <= 0 && x != oneunit(x) + throw(ArgumentError(lazy"cannot set the diagonal band to a non-unit value ($x)")) + end + fillband!(A.data, x, max(l, isunit), u) + return A +end + # Binary operations # use broadcasting if the parents are strided, where we loop only over the triangular part function +(A::UpperTriangular, B::UpperTriangular) diff --git a/src/tridiag.jl b/src/tridiag.jl index 0e8a5119..e31709fa 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -1189,3 +1189,44 @@ function _opnorm1Inf(A::SymTridiagonal, p::Real) ), normfirst, normend) end + +function fillband!(T::Tridiagonal, x, l, u) + if l > u + return T + end + if (l < -1 || u > 1) && !iszero(x) + throw_fillband_error(l, u, x) + else + if l <= -1 <= u + fill!(T.dl, x) + end + if l <= 0 <= u + fill!(T.d, x) + end + if l <= 1 <= u + fill!(T.du, x) + end + end + return T +end + +function fillband!(T::SymTridiagonal, x, l, u) + if l > u + return T + end + if (l <= 1 <= u) != (l <= -1 <= u) + throw(ArgumentError(lazy"cannot set only one off-diagonal band of a SymTridiagonal")) + elseif (l < -1 || u > 1) && !iszero(x) + throw_fillband_error(l, u, x) + elseif l <= 0 <= u && !issymmetric(x) + throw(ArgumentError(lazy"cannot set entries in the diagonal band of a SymTridiagonal to an asymmetric value $x")) + else + if l <= 0 <= u + fill!(T.dv, x) + end + if l <= 1 <= u + fill!(T.ev, x) + end + end + return T +end diff --git a/test/bidiag.jl b/test/bidiag.jl index c5ed26e1..99a50f5a 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -1228,4 +1228,56 @@ end @test_throws BoundsError B[LinearAlgebra.BandIndex(0,size(B,1)+1)] end +@testset "fillband!" begin + @testset "uplo = :U" begin + B = Bidiagonal(zeros(4), zeros(3), :U) + LinearAlgebra.fillband!(B, 2, 1, 1) + @test all(==(2), diagview(B,1)) + LinearAlgebra.fillband!(B, 3, 0, 0) + @test all(==(3), diagview(B,0)) + @test all(==(2), diagview(B,1)) + LinearAlgebra.fillband!(B, 4, 0, 1) + @test all(==(4), diagview(B,0)) + @test all(==(4), diagview(B,1)) + @test_throws ArgumentError LinearAlgebra.fillband!(B, 3, -1, 0) + + LinearAlgebra.fillstored!(B, 1) + LinearAlgebra.fillband!(B, 0, -3, 3) + @test iszero(B) + LinearAlgebra.fillband!(B, 0, -10, 10) + @test iszero(B) + LinearAlgebra.fillstored!(B, 1) + B2 = copy(B) + LinearAlgebra.fillband!(B, 0, -1, -3) + @test B == B2 + LinearAlgebra.fillband!(B, 0, 10, 10) + @test B == B2 + end + + @testset "uplo = :L" begin + B = Bidiagonal(zeros(4), zeros(3), :L) + LinearAlgebra.fillband!(B, 2, -1, -1) + @test all(==(2), diagview(B,-1)) + LinearAlgebra.fillband!(B, 3, 0, 0) + @test all(==(3), diagview(B,0)) + @test all(==(2), diagview(B,-1)) + LinearAlgebra.fillband!(B, 4, -1, 0) + @test all(==(4), diagview(B,0)) + @test all(==(4), diagview(B,-1)) + @test_throws ArgumentError LinearAlgebra.fillband!(B, 3, 0, 1) + + LinearAlgebra.fillstored!(B, 1) + LinearAlgebra.fillband!(B, 0, -3, 3) + @test iszero(B) + LinearAlgebra.fillband!(B, 0, -10, 10) + @test iszero(B) + LinearAlgebra.fillstored!(B, 1) + B2 = copy(B) + LinearAlgebra.fillband!(B, 0, -1, -3) + @test B == B2 + LinearAlgebra.fillband!(B, 0, 10, 10) + @test B == B2 + end +end + end # module TestBidiagonal diff --git a/test/diagonal.jl b/test/diagonal.jl index 02410707..44c65a72 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -1499,4 +1499,26 @@ end @test_throws BoundsError D[LinearAlgebra.BandIndex(0,size(D,1)+1)] end +@testset "fillband!" begin + D = Diagonal(zeros(4)) + LinearAlgebra.fillband!(D, 2, 0, 0) + @test all(==(2), diagview(D,0)) + @test all(==(0), diagview(D,-1)) + @test_throws ArgumentError LinearAlgebra.fillband!(D, 3, -2, 2) + + LinearAlgebra.fillstored!(D, 1) + LinearAlgebra.fillband!(D, 0, -3, 3) + @test iszero(D) + LinearAlgebra.fillstored!(D, 1) + LinearAlgebra.fillband!(D, 0, -10, 10) + @test iszero(D) + + LinearAlgebra.fillstored!(D, 1) + D2 = copy(D) + LinearAlgebra.fillband!(D, 0, -1, -3) + @test D == D2 + LinearAlgebra.fillband!(D, 0, 10, 10) + @test D == D2 +end + end # module TestDiagonal diff --git a/test/hessenberg.jl b/test/hessenberg.jl index bf2a725f..40602765 100644 --- a/test/hessenberg.jl +++ b/test/hessenberg.jl @@ -300,4 +300,24 @@ end @test_throws DimensionMismatch hessenberg(zeros(0,0)).Q * ones(1, 2) end +@testset "fillband" begin + U = UpperHessenberg(zeros(4,4)) + @test_throws ArgumentError LinearAlgebra.fillband!(U, 1, -2, 1) + @test iszero(U) + + LinearAlgebra.fillband!(U, 10, -1, 2) + @test all(==(10), diagview(U,-1)) + @test all(==(10), diagview(U,2)) + @test all(==(0), diagview(U,3)) + + LinearAlgebra.fillband!(U, 0, -5, 5) + @test iszero(U) + + U2 = copy(U) + LinearAlgebra.fillband!(U, -10, 1, -2) + @test U == U2 + LinearAlgebra.fillband!(U, -10, 10, 10) + @test U == U2 +end + end # module TestHessenberg diff --git a/test/triangular.jl b/test/triangular.jl index aeb41aa6..d01d79d4 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -934,4 +934,73 @@ end end end +@testset "fillband!" begin + @testset for TT in (UpperTriangular, UnitUpperTriangular) + U = TT(zeros(4,4)) + @test_throws ArgumentError LinearAlgebra.fillband!(U, 1, -1, 1) + if U isa UnitUpperTriangular + @test_throws ArgumentError LinearAlgebra.fillband!(U, 2, 0, 1) + end + # check that the error paths do not mutate the array + if U isa UpperTriangular + @test iszero(U) + end + + LinearAlgebra.fillband!(U, 1, 0, 1) + @test all(==(1), diagview(U,0)) + @test all(==(1), diagview(U,1)) + @test all(==(0), diagview(U,2)) + + LinearAlgebra.fillband!(U, 10, 1, 2) + @test all(==(10), diagview(U,1)) + @test all(==(10), diagview(U,2)) + @test all(==(1), diagview(U,0)) + @test all(==(0), diagview(U,3)) + + if U isa UpperTriangular + LinearAlgebra.fillband!(U, 0, -5, 5) + @test iszero(U) + end + + U2 = copy(U) + LinearAlgebra.fillband!(U, -10, 1, -2) + @test U == U2 + LinearAlgebra.fillband!(U, -10, 10, 10) + @test U == U2 + end + @testset for TT in (LowerTriangular, UnitLowerTriangular) + L = TT(zeros(4,4)) + @test_throws ArgumentError LinearAlgebra.fillband!(L, 1, -1, 1) + if L isa UnitLowerTriangular + @test_throws ArgumentError LinearAlgebra.fillband!(L, 2, -1, 0) + end + # check that the error paths do not mutate the array + if L isa LowerTriangular + @test iszero(L) + end + + LinearAlgebra.fillband!(L, 1, -1, 0) + @test all(==(1), diagview(L,0)) + @test all(==(1), diagview(L,-1)) + @test all(==(0), diagview(L,-2)) + + LinearAlgebra.fillband!(L, 10, -2, -1) + @test all(==(10), diagview(L,-1)) + @test all(==(10), diagview(L,-2)) + @test all(==(1), diagview(L,0)) + @test all(==(0), diagview(L,-3)) + + if L isa LowerTriangular + LinearAlgebra.fillband!(L, 0, -5, 5) + @test iszero(L) + end + + L2 = copy(L) + LinearAlgebra.fillband!(L, -10, -1, -2) + @test L == L2 + LinearAlgebra.fillband!(L, -10, -10, -10) + @test L == L2 + end +end + end # module TestTriangular diff --git a/test/tridiag.jl b/test/tridiag.jl index effea2b0..5f263488 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1207,4 +1207,61 @@ end @test_throws BoundsError S[LinearAlgebra.BandIndex(0,size(S,1)+1)] end +@testset "fillband!" begin + @testset "Tridiagonal" begin + T = Tridiagonal(zeros(3), zeros(4), zeros(3)) + LinearAlgebra.fillband!(T, 2, 1, 1) + @test all(==(2), diagview(T,1)) + @test all(==(0), diagview(T,0)) + @test all(==(0), diagview(T,-1)) + LinearAlgebra.fillband!(T, 3, 0, 0) + @test all(==(3), diagview(T,0)) + @test all(==(2), diagview(T,1)) + @test all(==(0), diagview(T,-1)) + LinearAlgebra.fillband!(T, 4, -1, 1) + @test all(==(4), diagview(T,-1)) + @test all(==(4), diagview(T,0)) + @test all(==(4), diagview(T,1)) + @test_throws ArgumentError LinearAlgebra.fillband!(T, 3, -2, 2) + + LinearAlgebra.fillstored!(T, 1) + LinearAlgebra.fillband!(T, 0, -3, 3) + @test iszero(T) + LinearAlgebra.fillstored!(T, 1) + LinearAlgebra.fillband!(T, 0, -10, 10) + @test iszero(T) + + LinearAlgebra.fillstored!(T, 1) + T2 = copy(T) + LinearAlgebra.fillband!(T, 0, -1, -3) + @test T == T2 + LinearAlgebra.fillband!(T, 0, 10, 10) + @test T == T2 + end + @testset "SymTridiagonal" begin + S = SymTridiagonal(zeros(4), zeros(3)) + @test_throws ArgumentError LinearAlgebra.fillband!(S, 2, -1, -1) + @test_throws ArgumentError LinearAlgebra.fillband!(S, 2, -2, 2) + + LinearAlgebra.fillband!(S, 1, -1, 1) + @test all(==(1), diagview(S,-1)) + @test all(==(1), diagview(S,0)) + @test all(==(1), diagview(S,1)) + + LinearAlgebra.fillstored!(S, 1) + LinearAlgebra.fillband!(S, 0, -3, 3) + @test iszero(S) + LinearAlgebra.fillstored!(S, 1) + LinearAlgebra.fillband!(S, 0, -10, 10) + @test iszero(S) + + LinearAlgebra.fillstored!(S, 1) + S2 = copy(S) + LinearAlgebra.fillband!(S, 0, -1, -3) + @test S == S2 + LinearAlgebra.fillband!(S, 0, 10, 10) + @test S == S2 + end +end + end # module TestTridiagonal