Skip to content

Commit 3ee621a

Browse files
authored
More matrix factorizations (#52)
1 parent 43553e3 commit 3ee621a

7 files changed

+487
-272
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.9"
4+
version = "0.2.10"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/MatrixAlgebra.jl

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
module MatrixAlgebra
2+
3+
export eigen,
4+
eigen!,
5+
eigvals,
6+
eigvals!,
7+
factorize,
8+
factorize!,
9+
lq,
10+
lq!,
11+
orth,
12+
orth!,
13+
polar,
14+
polar!,
15+
qr,
16+
qr!,
17+
svd,
18+
svd!,
19+
svdvals,
20+
svdvals!
21+
22+
using LinearAlgebra: LinearAlgebra
23+
using MatrixAlgebraKit
24+
25+
for (f, f_full, f_compact) in (
26+
(:qr, :qr_full, :qr_compact),
27+
(:qr!, :qr_full!, :qr_compact!),
28+
(:lq, :lq_full, :lq_compact),
29+
(:lq!, :lq_full!, :lq_compact!),
30+
)
31+
@eval begin
32+
function $f(A::AbstractMatrix; full::Bool=false, kwargs...)
33+
f = full ? $f_full : $f_compact
34+
return f(A; kwargs...)
35+
end
36+
end
37+
end
38+
39+
for (eigen, eigh_full, eig_full, eigh_trunc, eig_trunc) in (
40+
(:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc),
41+
(:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!),
42+
)
43+
@eval begin
44+
function $eigen(A::AbstractMatrix; trunc=nothing, ishermitian=nothing, kwargs...)
45+
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
46+
f = if !isnothing(trunc)
47+
ishermitian ? $eigh_trunc : $eig_trunc
48+
else
49+
ishermitian ? $eigh_full : $eig_full
50+
end
51+
return f(A; kwargs...)
52+
end
53+
end
54+
end
55+
56+
for (eigvals, eigh_vals, eig_vals) in
57+
((:eigvals, :eigh_vals, :eig_vals), (:eigvals!, :eigh_vals!, :eig_vals!))
58+
@eval begin
59+
function $eigvals(A::AbstractMatrix; ishermitian=nothing, kwargs...)
60+
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
61+
f = (ishermitian ? $eigh_vals : $eig_vals)
62+
return f(A; kwargs...)
63+
end
64+
end
65+
end
66+
67+
for (svd, svd_trunc, svd_full, svd_compact) in (
68+
(:svd, :svd_trunc, :svd_full, :svd_compact),
69+
(:svd!, :svd_trunc!, :svd_full!, :svd_compact!),
70+
)
71+
@eval begin
72+
function $svd(A::AbstractMatrix; full::Bool=false, trunc=nothing, kwargs...)
73+
return if !isnothing(trunc)
74+
@assert !full "Specified both full and truncation, currently not supported"
75+
$svd_trunc(A; trunc, kwargs...)
76+
else
77+
(full ? $svd_full : $svd_compact)(A; kwargs...)
78+
end
79+
end
80+
end
81+
end
82+
83+
for (svdvals, svd_vals) in ((:svdvals, :svd_vals), (:svdvals!, :svd_vals!))
84+
@eval begin
85+
function $svdvals(A::AbstractMatrix; ishermitian=nothing, kwargs...)
86+
return $svd_vals(A; kwargs...)
87+
end
88+
end
89+
end
90+
91+
for (polar, left_polar, right_polar) in
92+
((:polar, :left_polar, :right_polar), (:polar!, :left_polar!, :right_polar!))
93+
@eval begin
94+
function $polar(A::AbstractMatrix; side=:left, kwargs...)
95+
f = if side == :left
96+
$left_polar
97+
elseif side == :right
98+
$right_polar
99+
else
100+
throw(ArgumentError("`side=$side` not supported."))
101+
end
102+
return f(A; kwargs...)
103+
end
104+
end
105+
end
106+
107+
for (orth, left_orth, right_orth) in
108+
((:orth, :left_orth, :right_orth), (:orth!, :left_orth!, :right_orth!))
109+
@eval begin
110+
function $orth(A::AbstractMatrix; side=:left, kwargs...)
111+
f = if side == :left
112+
$left_orth
113+
elseif side == :right
114+
$right_orth
115+
else
116+
throw(ArgumentError("`side=$side` not supported."))
117+
end
118+
return f(A; kwargs...)
119+
end
120+
end
121+
end
122+
123+
for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, :orth!))
124+
@eval begin
125+
function $factorize(A::AbstractMatrix; orth=:left, kwargs...)
126+
f = if orth in (:left, :right)
127+
$orth_f
128+
else
129+
throw(ArgumentError("`orth=$orth` not supported."))
130+
end
131+
return f(A; side=orth, kwargs...)
132+
end
133+
end
134+
end
135+
136+
end

src/TensorAlgebra.jl

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,24 @@
11
module TensorAlgebra
22

3-
export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals
3+
export contract,
4+
contract!,
5+
eigen,
6+
eigvals,
7+
factorize,
8+
left_null,
9+
left_orth,
10+
left_polar,
11+
lq,
12+
qr,
13+
right_null,
14+
right_orth,
15+
right_polar,
16+
orth,
17+
polar,
18+
svd,
19+
svdvals
420

21+
include("MatrixAlgebra.jl")
522
include("blockedtuple.jl")
623
include("blockedpermutation.jl")
724
include("BaseExtensions/BaseExtensions.jl")

0 commit comments

Comments
 (0)