Skip to content

Commit 384b276

Browse files
authored
define matricize and unmatricize (#50)
1 parent 3ee621a commit 384b276

22 files changed

+382
-472
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.10"
4+
version = "0.3.0"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -19,7 +19,7 @@ BlockArrays = "1.5.0"
1919
EllipsisNotation = "1.8.0"
2020
LinearAlgebra = "1.10"
2121
MatrixAlgebraKit = "0.1.1"
22-
TensorProducts = "0.1.0"
22+
TensorProducts = "0.1.5"
2323
TupleTools = "1.6.0"
2424
TypeParameterAccessors = "0.2.1, 0.3"
2525
julia = "1.10"

docs/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
66
[compat]
77
Documenter = "1.8.1"
88
Literate = "2.20.1"
9-
TensorAlgebra = "0.2.0"
9+
TensorAlgebra = "0.3.0"

examples/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33

44
[compat]
5-
TensorAlgebra = "0.2.0"
5+
TensorAlgebra = "0.3.0"

src/BaseExtensions/permutedims.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Workaround for https://github.com/JuliaLang/julia/issues/52615.
22
# Fixed by https://github.com/JuliaLang/julia/pull/52623.
3+
# TODO remove once support for Julia 1.10 is dropped
34
function _permutedims!(
45
a_dest::AbstractArray{<:Any,N}, a_src::AbstractArray{<:Any,N}, perm::Tuple{Vararg{Int,N}}
56
) where {N}

src/TensorAlgebra.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ include("MatrixAlgebra.jl")
2222
include("blockedtuple.jl")
2323
include("blockedpermutation.jl")
2424
include("BaseExtensions/BaseExtensions.jl")
25-
include("fusedims.jl")
26-
include("splitdims.jl")
25+
include("matricize.jl")
2726
include("contract/contract.jl")
2827
include("contract/output_labels.jl")
2928
include("contract/blockedperms.jl")

src/blockedpermutation.jl

+11-7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ function Base.invperm(bp::AbstractBlockPermutation)
4545
return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp)))
4646
end
4747

48+
# interface
49+
50+
# Bipartition a vector according to the
51+
# bipartitioned permutation.
52+
# Like `Base.permute!` block out-of-place and blocked.
53+
function blockpermute(v, blockedperm::AbstractBlockPermutation)
54+
return tuplemortar(map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm)))
55+
end
56+
57+
Base.getindex(v, perm::AbstractBlockPermutation) = blockpermute(v, perm)
58+
4859
#
4960
# Constructors
5061
#
@@ -53,13 +64,6 @@ function blockedperm(bt::AbstractBlockTuple)
5364
return permmortar(blocks(bt))
5465
end
5566

56-
# Bipartition a vector according to the
57-
# bipartitioned permutation.
58-
# Like `Base.permute!` block out-of-place and blocked.
59-
function blockpermute(v, blockedperm::AbstractBlockPermutation)
60-
return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))
61-
end
62-
6367
# blockedpermvcat((4, 3), (2, 1))
6468
function blockedpermvcat(
6569
permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing

src/contract/allocate_output.jl

+8-117
Original file line numberDiff line numberDiff line change
@@ -4,137 +4,28 @@ using Base.PermutedDimsArrays: genperm
44
# i.e. `ContractAdd`?
55
function output_axes(
66
::typeof(contract),
7-
biperm_dest::BlockedPermutation{2},
7+
biperm_dest::AbstractBlockPermutation{2},
88
a1::AbstractArray,
9-
biperm1::BlockedPermutation{2},
9+
biperm1::AbstractBlockPermutation{2},
1010
a2::AbstractArray,
11-
biperm2::BlockedPermutation{2},
11+
biperm2::AbstractBlockPermutation{2},
1212
α::Number=one(Bool),
1313
)
14-
axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1)
15-
axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2)
14+
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
15+
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
1616
@assert axes_contracted == axes_contracted2
1717
return genperm((axes_codomain..., axes_domain...), invperm(Tuple(biperm_dest)))
1818
end
1919

20-
# Inner-product contraction.
21-
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
22-
# i.e. `ContractAdd`?
23-
function output_axes(
24-
::typeof(contract),
25-
perm_dest::BlockedPermutation{0},
26-
a1::AbstractArray,
27-
perm1::BlockedPermutation{1},
28-
a2::AbstractArray,
29-
perm2::BlockedPermutation{1},
30-
α::Number=one(Bool),
31-
)
32-
axes_contracted = blockpermute(axes(a1), perm1)
33-
axes_contracted′ = blockpermute(axes(a2), perm2)
34-
@assert axes_contracted == axes_contracted′
35-
return ()
36-
end
37-
38-
# Vec-mat.
39-
function output_axes(
40-
::typeof(contract),
41-
perm_dest::BlockedPermutation{1},
42-
a1::AbstractArray,
43-
perm1::BlockedPermutation{1},
44-
a2::AbstractArray,
45-
biperm2::BlockedPermutation{2},
46-
α::Number=one(Bool),
47-
)
48-
(axes_contracted,) = blockpermute(axes(a1), perm1)
49-
axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
50-
@assert axes_contracted == axes_contracted′
51-
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
52-
end
53-
54-
# Mat-vec.
55-
function output_axes(
56-
::typeof(contract),
57-
perm_dest::BlockedPermutation{1},
58-
a1::AbstractArray,
59-
perm1::BlockedPermutation{2},
60-
a2::AbstractArray,
61-
biperm2::BlockedPermutation{1},
62-
α::Number=one(Bool),
63-
)
64-
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
65-
(axes_contracted′,) = blockpermute(axes(a2), biperm2)
66-
@assert axes_contracted == axes_contracted′
67-
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
68-
end
69-
70-
# Outer product.
71-
function output_axes(
72-
::typeof(contract),
73-
biperm_dest::BlockedPermutation{2},
74-
a1::AbstractArray,
75-
perm1::BlockedPermutation{1},
76-
a2::AbstractArray,
77-
perm2::BlockedPermutation{1},
78-
α::Number=one(Bool),
79-
)
80-
@assert istrivialperm(Tuple(perm1))
81-
@assert istrivialperm(Tuple(perm2))
82-
axes_dest = (axes(a1)..., axes(a2)...)
83-
return genperm(axes_dest, invperm(Tuple(biperm_dest)))
84-
end
85-
86-
# Array-scalar contraction.
87-
function output_axes(
88-
::typeof(contract),
89-
perm_dest::BlockedPermutation{1},
90-
a1::AbstractArray,
91-
perm1::BlockedPermutation{1},
92-
a2::AbstractArray,
93-
perm2::BlockedPermutation{0},
94-
α::Number=one(Bool),
95-
)
96-
@assert istrivialperm(Tuple(perm1))
97-
axes_dest = axes(a1)
98-
return genperm(axes_dest, invperm(Tuple(perm_dest)))
99-
end
100-
101-
# Scalar-array contraction.
102-
function output_axes(
103-
::typeof(contract),
104-
perm_dest::BlockedPermutation{1},
105-
a1::AbstractArray,
106-
perm1::BlockedPermutation{0},
107-
a2::AbstractArray,
108-
perm2::BlockedPermutation{1},
109-
α::Number=one(Bool),
110-
)
111-
@assert istrivialperm(Tuple(perm2))
112-
axes_dest = axes(a2)
113-
return genperm(axes_dest, invperm(Tuple(perm_dest)))
114-
end
115-
116-
# Scalar-scalar contraction.
117-
function output_axes(
118-
::typeof(contract),
119-
perm_dest::BlockedPermutation{0},
120-
a1::AbstractArray,
121-
perm1::BlockedPermutation{0},
122-
a2::AbstractArray,
123-
perm2::BlockedPermutation{0},
124-
α::Number=one(Bool),
125-
)
126-
return ()
127-
end
128-
12920
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
13021
# i.e. `ContractAdd`?
13122
function allocate_output(
13223
::typeof(contract),
133-
biperm_dest::BlockedPermutation,
24+
biperm_dest::AbstractBlockPermutation,
13425
a1::AbstractArray,
135-
biperm1::BlockedPermutation,
26+
biperm1::AbstractBlockPermutation,
13627
a2::AbstractArray,
137-
biperm2::BlockedPermutation,
28+
biperm2::AbstractBlockPermutation,
13829
α::Number=one(Bool),
13930
)
14031
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)

src/contract/blockedperms.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
2222
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)
2323

2424
permblocks_dest = (perm_codomain_dest, perm_domain_dest)
25-
biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...)
25+
biperm_dest = blockedpermvcat(permblocks_dest...)
2626
permblocks1 = (perm_codomain1, perm_domain1)
27-
biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...)
27+
biperm1 = blockedpermvcat(permblocks1...)
2828
permblocks2 = (perm_codomain2, perm_domain2)
29-
biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...)
29+
biperm2 = blockedpermvcat(permblocks2...)
3030
return biperm_dest, biperm1, biperm2
3131
end

src/contract/contract.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ default_contract_alg() = Matricize()
1313
function contract!(
1414
alg::Algorithm,
1515
a_dest::AbstractArray,
16-
biperm_dest::BlockedPermutation,
16+
biperm_dest::AbstractBlockPermutation,
1717
a1::AbstractArray,
18-
biperm1::BlockedPermutation,
18+
biperm1::AbstractBlockPermutation,
1919
a2::AbstractArray,
20-
biperm2::BlockedPermutation,
20+
biperm2::AbstractBlockPermutation,
2121
α::Number,
2222
β::Number,
2323
)
@@ -110,11 +110,11 @@ end
110110

111111
function contract(
112112
alg::Algorithm,
113-
biperm_dest::BlockedPermutation,
113+
biperm_dest::AbstractBlockPermutation,
114114
a1::AbstractArray,
115-
biperm1::BlockedPermutation,
115+
biperm1::AbstractBlockPermutation,
116116
a2::AbstractArray,
117-
biperm2::BlockedPermutation,
117+
biperm2::AbstractBlockPermutation,
118118
α::Number;
119119
kwargs...,
120120
)
+9-92
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,20 @@
11
using LinearAlgebra: mul!
22

33
function contract!(
4-
alg::Matricize,
4+
::Matricize,
55
a_dest::AbstractArray,
6-
biperm_dest::BlockedPermutation,
6+
biperm_dest::AbstractBlockPermutation{2},
77
a1::AbstractArray,
8-
biperm1::BlockedPermutation,
8+
biperm1::AbstractBlockPermutation{2},
99
a2::AbstractArray,
10-
biperm2::BlockedPermutation,
10+
biperm2::AbstractBlockPermutation{2},
1111
α::Number,
1212
β::Number,
1313
)
14-
a_dest_mat = fusedims(a_dest, biperm_dest)
15-
a1_mat = fusedims(a1, biperm1)
16-
a2_mat = fusedims(a2, biperm2)
17-
_mul!(a_dest_mat, a1_mat, a2_mat, α, β)
18-
splitdims!(a_dest, a_dest_mat, biperm_dest)
19-
return a_dest
20-
end
21-
22-
# Matrix multiplication.
23-
function _mul!(
24-
a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number
25-
)
26-
mul!(a_dest, a1, a2, α, β)
27-
return a_dest
28-
end
29-
30-
# Inner product.
31-
function _mul!(
32-
a_dest::AbstractArray{<:Any,0},
33-
a1::AbstractVector,
34-
a2::AbstractVector,
35-
α::Number,
36-
β::Number,
37-
)
38-
a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
39-
return a_dest
40-
end
41-
42-
# Vec-mat.
43-
function _mul!(
44-
a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
45-
)
46-
mul!(transpose(a_dest), transpose(a1), a2, α, β)
47-
return a_dest
48-
end
49-
50-
# Mat-vec.
51-
function _mul!(
52-
a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
53-
)
54-
mul!(a_dest, a1, a2, α, β)
55-
return a_dest
56-
end
57-
58-
# Outer product.
59-
function _mul!(
60-
a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number
61-
)
62-
mul!(a_dest, a1, transpose(a2), α, β)
63-
return a_dest
64-
end
65-
66-
# Array-scalar contraction.
67-
function _mul!(
68-
a_dest::AbstractVector,
69-
a1::AbstractVector,
70-
a2::AbstractArray{<:Any,0},
71-
α::Number,
72-
β::Number,
73-
)
74-
α′ = a2[] * α
75-
a_dest .= a1 .* α′ .+ a_dest .* β
76-
return a_dest
77-
end
78-
79-
# Scalar-array contraction.
80-
function _mul!(
81-
a_dest::AbstractVector,
82-
a1::AbstractArray{<:Any,0},
83-
a2::AbstractVector,
84-
α::Number,
85-
β::Number,
86-
)
87-
# Preserve the ordering in case of non-commutative algebra.
88-
a_dest .= a1[] .* a2 .* α .+ a_dest .* β
89-
return a_dest
90-
end
91-
92-
# Scalar-scalar contraction.
93-
function _mul!(
94-
a_dest::AbstractArray{<:Any,0},
95-
a1::AbstractArray{<:Any,0},
96-
a2::AbstractArray{<:Any,0},
97-
α::Number,
98-
β::Number,
99-
)
100-
# Preserve the ordering in case of non-commutative algebra.
101-
a_dest[] = a1[] * a2[] * α + a_dest[] * β
14+
a_dest_mat = matricize(a_dest, biperm_dest)
15+
a1_mat = matricize(a1, biperm1)
16+
a2_mat = matricize(a2, biperm2)
17+
mul!(a_dest_mat, a1_mat, a2_mat, α, β)
18+
unmatricize!(a_dest, a_dest_mat, biperm_dest)
10219
return a_dest
10320
end

0 commit comments

Comments
 (0)