Skip to content

Commit 28577a3

Browse files
authored
Add tensor factorizations through MatrixAlgebraKit (#36)
1 parent 2c46a8e commit 28577a3

11 files changed

+523
-73
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name: "Integration Test Request"
2+
3+
on:
4+
issue_comment:
5+
types: [created]
6+
7+
jobs:
8+
integrationrequest:
9+
if: |
10+
github.event.issue.pull_request &&
11+
contains(fromJSON('["OWNER", "COLLABORATOR", "MEMBER"]'), github.event.comment.author_association)
12+
uses: ITensor/ITensorActions/.github/workflows/IntegrationTestRequest.yml@main
13+
with:
14+
localregistry: https://github.com/ITensor/ITensorRegistry.git

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
88
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
99
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1112
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1213
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1314

@@ -23,6 +24,7 @@ BlockArrays = "1.2.0"
2324
EllipsisNotation = "1.8.0"
2425
GradedUnitRanges = "0.1.0"
2526
LinearAlgebra = "1.10"
27+
MatrixAlgebraKit = "0.1.1"
2628
TupleTools = "1.6.0"
2729
TypeParameterAccessors = "0.2.1, 0.3"
2830
julia = "1.10"

docs/make.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ makedocs(;
1414
edit_link="main",
1515
assets=String[],
1616
),
17-
pages=["Home" => "index.md"],
17+
pages=["Home" => "index.md", "Reference" => "reference.md"],
1818
)
1919

2020
deploydocs(;

docs/src/reference.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Reference
2+
3+
```@autodocs
4+
Modules = [TensorAlgebra]
5+
```

src/TensorAlgebra.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module TensorAlgebra
22

3-
export contract, contract!
3+
export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals
44

55
include("blockedtuple.jl")
66
include("blockedpermutation.jl")

src/factorizations.jl

+285-44
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,286 @@
1-
using ArrayLayouts: LayoutMatrix
2-
using LinearAlgebra: LinearAlgebra, Diagonal
3-
4-
function qr(a::AbstractArray, biperm::BlockedPermutation{2})
5-
a_matricized = fusedims(a, biperm)
6-
# TODO: Make this more generic, allow choosing thin or full,
7-
# make sure this works on GPU.
8-
q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
9-
q_matricized = typeof(a_matricized)(q_fact)
10-
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
11-
axes_q = (axes_codomain..., axes(q_matricized, 2))
12-
axes_r = (axes(r_matricized, 1), axes_domain...)
13-
q = splitdims(q_matricized, axes_q)
14-
r = splitdims(r_matricized, axes_r)
15-
return q, r
16-
end
17-
18-
function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain)
19-
# TODO: Generalize to conversion to `Tuple` isn't needed.
20-
return qr(
21-
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
22-
)
23-
end
24-
25-
function svd(a::AbstractArray, biperm::BlockedPermutation{2})
26-
a_matricized = fusedims(a, biperm)
27-
usv_matricized = LinearAlgebra.svd(a_matricized)
28-
u_matricized = usv_matricized.U
29-
s_diag = usv_matricized.S
30-
v_matricized = usv_matricized.Vt
31-
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
32-
axes_u = (axes_codomain..., axes(u_matricized, 2))
33-
axes_v = (axes(v_matricized, 1), axes_domain...)
34-
u = splitdims(u_matricized, axes_u)
35-
# TODO: Use `DiagonalArrays.diagonal` to make it more general.
36-
s = Diagonal(s_diag)
37-
v = splitdims(v_matricized, axes_v)
38-
return u, s, v
39-
end
40-
41-
function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain)
42-
return svd(
43-
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
44-
)
1+
using MatrixAlgebraKit:
2+
eig_full!,
3+
eig_trunc!,
4+
eig_vals!,
5+
eigh_full!,
6+
eigh_trunc!,
7+
eigh_vals!,
8+
left_null!,
9+
lq_full!,
10+
lq_compact!,
11+
qr_full!,
12+
qr_compact!,
13+
right_null!,
14+
svd_full!,
15+
svd_compact!,
16+
svd_trunc!,
17+
svd_vals!
18+
using LinearAlgebra: LinearAlgebra
19+
20+
"""
21+
qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R
22+
qr(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> Q, R
23+
24+
Compute the QR decomposition of a generic N-dimensional array, by interpreting it as
25+
a linear map from the domain to the codomain indices. These can be specified either via
26+
their labels, or directly through a `biperm`.
27+
28+
## Keyword arguments
29+
30+
- `full::Bool=false`: select between a "full" or a "compact" decomposition, where `Q` is unitary or `R` is square, respectively.
31+
- `positive::Bool=false`: specify if the diagonal of `R` should be positive, leading to a unique decomposition.
32+
- Other keywords are passed on directly to MatrixAlgebraKit.
33+
34+
See also `MatrixAlgebraKit.qr_full!` and `MatrixAlgebraKit.qr_compact!`.
35+
"""
36+
function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
37+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
38+
return qr(A, biperm; kwargs...)
39+
end
40+
function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...)
41+
# tensor to matrix
42+
A_mat = fusedims(A, biperm)
43+
44+
# factorization
45+
Q, R = full ? qr_full!(A_mat; kwargs...) : qr_compact!(A_mat; kwargs...)
46+
47+
# matrix to tensor
48+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
49+
axes_Q = (axes_codomain..., axes(Q, 2))
50+
axes_R = (axes(R, 1), axes_domain...)
51+
return splitdims(Q, axes_Q), splitdims(R, axes_R)
52+
end
53+
54+
"""
55+
lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> L, Q
56+
lq(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> L, Q
57+
58+
Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as
59+
a linear map from the domain to the codomain indices. These can be specified either via
60+
their labels, or directly through a `biperm`.
61+
62+
## Keyword arguments
63+
64+
- `full::Bool=false`: select between a "full" or a "compact" decomposition, where `Q` is unitary or `L` is square, respectively.
65+
- `positive::Bool=false`: specify if the diagonal of `L` should be positive, leading to a unique decomposition.
66+
- Other keywords are passed on directly to MatrixAlgebraKit.
67+
68+
See also `MatrixAlgebraKit.lq_full!` and `MatrixAlgebraKit.lq_compact!`.
69+
"""
70+
function lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
71+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
72+
return lq(A, biperm; kwargs...)
73+
end
74+
function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...)
75+
# tensor to matrix
76+
A_mat = fusedims(A, biperm)
77+
78+
# factorization
79+
L, Q = full ? lq_full!(A_mat; kwargs...) : lq_compact!(A_mat; kwargs...)
80+
81+
# matrix to tensor
82+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
83+
axes_L = (axes_codomain..., axes(L, ndims(L)))
84+
axes_Q = (axes(Q, 1), axes_domain...)
85+
return splitdims(L, axes_L), splitdims(Q, axes_Q)
86+
end
87+
88+
"""
89+
eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D, V
90+
eigen(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> D, V
91+
92+
Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as
93+
a linear map from the domain to the codomain indices. These can be specified either via
94+
their labels, or directly through a `biperm`.
95+
96+
## Keyword arguments
97+
98+
- `ishermitian::Bool`: specify if the matrix is Hermitian, which can be used to speed up the
99+
computation. If `false`, the output `eltype` will always be `<:Complex`.
100+
- `trunc`: Truncation keywords for `eig(h)_trunc`.
101+
- Other keywords are passed on directly to MatrixAlgebraKit.
102+
103+
See also `MatrixAlgebraKit.eig_full!`, `MatrixAlgebraKit.eig_trunc!`, `MatrixAlgebraKit.eig_vals!`,
104+
`MatrixAlgebraKit.eigh_full!`, `MatrixAlgebraKit.eigh_trunc!`, and `MatrixAlgebraKit.eigh_vals!`.
105+
"""
106+
function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
107+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
108+
return eigen(A, biperm; kwargs...)
109+
end
110+
function eigen(
111+
A::AbstractArray,
112+
biperm::BlockedPermutation{2};
113+
trunc=nothing,
114+
ishermitian=nothing,
115+
kwargs...,
116+
)
117+
# tensor to matrix
118+
A_mat = fusedims(A, biperm)
119+
120+
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat)
121+
122+
# factorization
123+
if !isnothing(trunc)
124+
D, V = (ishermitian ? eigh_trunc! : eig_trunc!)(A_mat; trunc, kwargs...)
125+
else
126+
D, V = (ishermitian ? eigh_full! : eig_full!)(A_mat; kwargs...)
127+
end
128+
129+
# matrix to tensor
130+
axes_codomain, = blockpermute(axes(A), biperm)
131+
axes_V = (axes_codomain..., axes(V, ndims(V)))
132+
return D, splitdims(V, axes_V)
133+
end
134+
135+
"""
136+
eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> D
137+
eigvals(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> D
138+
139+
Compute the eigenvalues of a generic N-dimensional array, by interpreting it as
140+
a linear map from the domain to the codomain indices. These can be specified either via
141+
their labels, or directly through a `biperm`. The output is a vector of eigenvalues.
142+
143+
## Keyword arguments
144+
145+
- `ishermitian::Bool`: specify if the matrix is Hermitian, which can be used to speed up the
146+
computation. If `false`, the output `eltype` will always be `<:Complex`.
147+
- Other keywords are passed on directly to MatrixAlgebraKit.
148+
149+
See also `MatrixAlgebraKit.eig_vals!` and `MatrixAlgebraKit.eigh_vals!`.
150+
"""
151+
function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
152+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
153+
return eigvals(A, biperm; kwargs...)
154+
end
155+
function eigvals(
156+
A::AbstractArray, biperm::BlockedPermutation{2}; ishermitian=nothing, kwargs...
157+
)
158+
A_mat = fusedims(A, biperm)
159+
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat)
160+
return (ishermitian ? eigh_vals! : eig_vals!)(A_mat; kwargs...)
161+
end
162+
163+
# TODO: separate out the algorithm selection step from the implementation
164+
"""
165+
svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ
166+
svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ
167+
168+
Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as
169+
a linear map from the domain to the codomain indices. These can be specified either via
170+
their labels, or directly through a `biperm`.
171+
172+
## Keyword arguments
173+
174+
- `full::Bool=false`: select between a "thick" or a "thin" decomposition, where both `U` and `Vᴴ`
175+
are unitary or isometric.
176+
- `trunc`: Truncation keywords for `svd_trunc`. Not compatible with `full=true`.
177+
- Other keywords are passed on directly to MatrixAlgebraKit.
178+
179+
See also `MatrixAlgebraKit.svd_full!`, `MatrixAlgebraKit.svd_compact!`, and `MatrixAlgebraKit.svd_trunc!`.
180+
"""
181+
function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
182+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
183+
return svd(A, biperm; kwargs...)
184+
end
185+
function svd(
186+
A::AbstractArray,
187+
biperm::BlockedPermutation{2};
188+
full::Bool=false,
189+
trunc=nothing,
190+
kwargs...,
191+
)
192+
# tensor to matrix
193+
A_mat = fusedims(A, biperm)
194+
195+
# factorization
196+
if !isnothing(trunc)
197+
@assert !full "Specified both full and truncation, currently not supported"
198+
U, S, Vᴴ = svd_trunc!(A_mat; trunc, kwargs...)
199+
else
200+
U, S, Vᴴ = full ? svd_full!(A_mat; kwargs...) : svd_compact!(A_mat; kwargs...)
201+
end
202+
203+
# matrix to tensor
204+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
205+
axes_U = (axes_codomain..., axes(U, 2))
206+
axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
207+
return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
208+
end
209+
210+
"""
211+
svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain) -> S
212+
svdvals(A::AbstractArray, biperm::BlockedPermutation{2}) -> S
213+
214+
Compute the singular values of a generic N-dimensional array, by interpreting it as
215+
a linear map from the domain to the codomain indices. These can be specified either via
216+
their labels, or directly through a `biperm`. The output is a vector of singular values.
217+
218+
See also `MatrixAlgebraKit.svd_vals!`.
219+
"""
220+
function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain)
221+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
222+
return svdvals(A, biperm)
223+
end
224+
function svdvals(A::AbstractArray, biperm::BlockedPermutation{2})
225+
A_mat = fusedims(A, biperm)
226+
return svd_vals!(A_mat)
227+
end
228+
229+
"""
230+
left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> N
231+
left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> N
232+
233+
Compute the left nullspace of a generic N-dimensional array, by interpreting it as
234+
a linear map from the domain to the codomain indices. These can be specified either via
235+
their labels, or directly through a `biperm`.
236+
The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`.
237+
238+
## Keyword arguments
239+
240+
- `atol::Real=0`: absolute tolerance for the nullspace computation.
241+
- `rtol::Real=0`: relative tolerance for the nullspace computation.
242+
- `kind::Symbol`: specify the kind of decomposition used to compute the nullspace.
243+
The options are `:qr`, `:qrpos` and `:svd`. The former two require `0 == atol == rtol`.
244+
The default is `:qrpos` if `atol == rtol == 0`, and `:svd` otherwise.
245+
"""
246+
function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
247+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
248+
return left_null(A, biperm; kwargs...)
249+
end
250+
function left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
251+
A_mat = fusedims(A, biperm)
252+
N = left_null!(A_mat; kwargs...)
253+
axes_codomain, _ = blockpermute(axes(A), biperm)
254+
axes_N = (axes_codomain..., axes(N, 2))
255+
N_tensor = splitdims(N, axes_N)
256+
return N_tensor
257+
end
258+
259+
"""
260+
right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Nᴴ
261+
right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> Nᴴ
262+
263+
Compute the right nullspace of a generic N-dimensional array, by interpreting it as
264+
a linear map from the domain to the codomain indices. These can be specified either via
265+
their labels, or directly through a `biperm`.
266+
The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`.
267+
268+
## Keyword arguments
269+
270+
- `atol::Real=0`: absolute tolerance for the nullspace computation.
271+
- `rtol::Real=0`: relative tolerance for the nullspace computation.
272+
- `kind::Symbol`: specify the kind of decomposition used to compute the nullspace.
273+
The options are `:lq`, `:lqpos` and `:svd`. The former two require `0 == atol == rtol`.
274+
The default is `:lqpos` if `atol == rtol == 0`, and `:svd` otherwise.
275+
"""
276+
function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
277+
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
278+
return right_null(A, biperm; kwargs...)
279+
end
280+
function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
281+
A_mat = fusedims(A, biperm)
282+
Nᴴ = right_null!(A_mat; kwargs...)
283+
_, axes_domain = blockpermute(axes(A), biperm)
284+
axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...)
285+
return splitdims(Nᴴ, axes_Nᴴ)
45286
end

0 commit comments

Comments
 (0)