Skip to content

Commit 93e97b4

Browse files
authored
Cleanup and generalize functions of Hermitian matrices (#1340 encore) (#1358)
1 parent 3e4d569 commit 93e97b4

File tree

2 files changed

+59
-114
lines changed

2 files changed

+59
-114
lines changed

src/symmetric.jl

Lines changed: 49 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,15 @@ const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}}
224224
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
225225
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
226226
const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
227-
const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}
227+
const SelfAdjoint = Union{SymTridiagonal{<:Real}, Symmetric{<:Real}, Hermitian}
228228

229229
wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
230230
wrappertype(::Hermitian) = Hermitian
231231

232+
nonhermitianwrappertype(::SymSymTri{<:Real}) = Symmetric
233+
nonhermitianwrappertype(::Hermitian{<:Real}) = Symmetric
234+
nonhermitianwrappertype(::Hermitian) = identity
235+
232236
size(A::HermOrSym) = size(A.data)
233237
axes(A::HermOrSym) = axes(A.data)
234238
@inline function Base.isassigned(A::HermOrSym, i::Int, j::Int)
@@ -830,123 +834,75 @@ function svdvals!(A::RealHermSymComplexHerm)
830834
end
831835

832836
# Matrix functions
833-
^(A::Symmetric{<:Real}, p::Integer) = sympow(A, p)
834-
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
835-
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
836-
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
837-
function sympow(A::SymSymTri, p::Integer)
838-
if p < 0
839-
return Symmetric(Base.power_by_squaring(inv(A), -p))
840-
else
841-
return Symmetric(Base.power_by_squaring(A, p))
842-
end
843-
end
844-
for hermtype in (:Symmetric, :SymTridiagonal)
845-
@eval begin
846-
function ^(A::$hermtype{<:Real}, p::Real)
847-
isinteger(p) && return integerpow(A, p)
848-
F = eigen(A)
849-
if all-> λ 0, F.values)
850-
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
851-
else
852-
return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
853-
end
854-
end
855-
function ^(A::$hermtype{<:Complex}, p::Real)
856-
isinteger(p) && return integerpow(A, p)
857-
return Symmetric(schurpow(A, p))
858-
end
859-
end
860-
end
861-
function ^(A::Hermitian, p::Integer)
837+
^(A::SymSymTri{<:Complex}, p::Integer) = sympow(A, p)
838+
^(A::SelfAdjoint, p::Integer) = sympow(A, p)
839+
function sympow(A, p::Integer)
862840
if p < 0
863841
retmat = Base.power_by_squaring(inv(A), -p)
864842
else
865843
retmat = Base.power_by_squaring(A, p)
866844
end
867-
return Hermitian(retmat)
845+
return wrappertype(A)(retmat)
868846
end
869-
function ^(A::Hermitian{T}, p::Real) where T
847+
function ^(A::SelfAdjoint, p::Real)
870848
isinteger(p) && return integerpow(A, p)
871849
F = eigen(A)
872850
if all-> λ 0, F.values)
873851
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
874-
return Hermitian(retmat)
852+
return wrappertype(A)(retmat)
875853
else
876-
retmat = (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors'
877-
if T <: Real
878-
return Symmetric(retmat)
879-
else
880-
return retmat
881-
end
854+
retmat = (F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors'
855+
return nonhermitianwrappertype(A)(retmat)
882856
end
883857
end
858+
function ^(A::SymSymTri{<:Complex}, p::Real)
859+
isinteger(p) && return integerpow(A, p)
860+
return Symmetric(schurpow(A, p))
861+
end
884862

885863
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
886864
@eval begin
887-
function ($func)(A::RealHermSymSymTri)
888-
F = eigen(A)
889-
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
890-
end
891-
function ($func)(A::Hermitian{<:Complex})
865+
function ($func)(A::SelfAdjoint)
892866
F = eigen(A)
893867
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
894-
return Hermitian(retmat)
868+
return wrappertype(A)(retmat)
895869
end
896870
end
897871
end
898872

899-
function cis(A::RealHermSymSymTri)
873+
function cis(A::SelfAdjoint)
900874
F = eigen(A)
901-
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
875+
retmat = F.vectors .* cis.(F.values') * F.vectors'
876+
return nonhermitianwrappertype(A)(retmat)
902877
end
903-
function cis(A::Hermitian{<:Complex})
904-
F = eigen(A)
905-
return F.vectors .* cis.(F.values') * F.vectors'
906-
end
907-
908878

909879
for func in (:acos, :asin)
910880
@eval begin
911-
function ($func)(A::RealHermSymSymTri)
912-
F = eigen(A)
913-
if all-> -1 λ 1, F.values)
914-
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
915-
else
916-
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
917-
end
918-
end
919-
function ($func)(A::Hermitian{<:Complex})
881+
function ($func)(A::SelfAdjoint)
920882
F = eigen(A)
921883
if all-> -1 λ 1, F.values)
922884
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
923-
return Hermitian(retmat)
885+
return wrappertype(A)(retmat)
924886
else
925-
return (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
887+
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
888+
return nonhermitianwrappertype(A)(retmat)
926889
end
927890
end
928891
end
929892
end
930893

931-
function acosh(A::RealHermSymSymTri)
932-
F = eigen(A)
933-
if all-> λ 1, F.values)
934-
return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
935-
else
936-
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
937-
end
938-
end
939-
function acosh(A::Hermitian{<:Complex})
894+
function acosh(A::SelfAdjoint)
940895
F = eigen(A)
941896
if all-> λ 1, F.values)
942897
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
943-
return Hermitian(retmat)
898+
return wrappertype(A)(retmat)
944899
else
945-
return (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
900+
retmat = (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
901+
return nonhermitianwrappertype(A)(retmat)
946902
end
947903
end
948904

949-
function sincos(A::RealHermSymSymTri)
905+
function sincos(A::SelfAdjoint)
950906
n = checksquare(A)
951907
F = eigen(A)
952908
T = float(eltype(F.values))
@@ -956,49 +912,28 @@ function sincos(A::RealHermSymSymTri)
956912
end
957913
return wrappertype(A)((F.vectors * S) * F.vectors'), wrappertype(A)((F.vectors * C) * F.vectors')
958914
end
959-
function sincos(A::Hermitian{<:Complex})
960-
n = checksquare(A)
915+
916+
function log(A::SelfAdjoint)
961917
F = eigen(A)
962-
T = float(eltype(F.values))
963-
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
964-
for i in eachindex(S.diag, C.diag, F.values)
965-
S.diag[i], C.diag[i] = sincos(F.values[i])
966-
end
967-
retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors'
968-
for i in diagind(retmatS, IndexStyle(retmatS))
969-
retmatS[i] = real(retmatS[i])
970-
retmatC[i] = real(retmatC[i])
918+
if all-> λ 0, F.values)
919+
retmat = (F.vectors * Diagonal(log.(F.values))) * F.vectors'
920+
return wrappertype(A)(retmat)
921+
else
922+
retmat = (F.vectors * Diagonal(log.(complex.(F.values)))) * F.vectors'
923+
return nonhermitianwrappertype(A)(retmat)
971924
end
972-
return Hermitian(retmatS), Hermitian(retmatC)
973925
end
974926

975-
976-
for func in (:log, :sqrt)
977-
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
978-
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
979-
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
980-
@eval begin
981-
function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real}
982-
F = eigen(A)
983-
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
984-
if all-> λ λ₀, F.values)
985-
return wrappertype(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
986-
else
987-
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
988-
end
989-
end
990-
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
991-
n = checksquare(A)
992-
F = eigen(A)
993-
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
994-
if all-> λ λ₀, F.values)
995-
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
996-
return Hermitian(retmat)
997-
else
998-
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
999-
return retmat
1000-
end
1001-
end
927+
# sqrt has rtol kwarg to handle matrices that are semidefinite up to roundoff errors
928+
function sqrt(A::SelfAdjoint; rtol = eps(real(float(eltype(A)))) * size(A, 1))
929+
F = eigen(A)
930+
λ₀ = -maximum(abs, F.values) * rtol # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
931+
if all-> λ λ₀, F.values)
932+
retmat = (F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors'
933+
return wrappertype(A)(retmat)
934+
else
935+
retmat = (F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors'
936+
return nonhermitianwrappertype(A)(retmat)
1002937
end
1003938
end
1004939

test/symmetric.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,4 +1199,14 @@ end
11991199
end
12001200
end
12011201

1202+
@testset "asin/acos/acosh for matrix outside the real domain" begin
1203+
M = [0 2;2 0] #eigenvalues are ±2
1204+
for T (Float32, Float64, ComplexF32, ComplexF64)
1205+
M2 = Hermitian(T.(M))
1206+
@test sin(asin(M2)) M2
1207+
@test cos(acos(M2)) M2
1208+
@test cosh(acosh(M2)) M2
1209+
end
1210+
end
1211+
12021212
end # module TestSymmetric

0 commit comments

Comments
 (0)