@@ -4,137 +4,28 @@ using Base.PermutedDimsArrays: genperm
4
4
# i.e. `ContractAdd`?
5
5
function output_axes (
6
6
:: typeof (contract),
7
- biperm_dest:: BlockedPermutation {2} ,
7
+ biperm_dest:: AbstractBlockPermutation {2} ,
8
8
a1:: AbstractArray ,
9
- biperm1:: BlockedPermutation {2} ,
9
+ biperm1:: AbstractBlockPermutation {2} ,
10
10
a2:: AbstractArray ,
11
- biperm2:: BlockedPermutation {2} ,
11
+ biperm2:: AbstractBlockPermutation {2} ,
12
12
α:: Number = one (Bool),
13
13
)
14
- axes_codomain, axes_contracted = blockpermute (axes (a1), biperm1)
15
- axes_contracted2, axes_domain = blockpermute (axes (a2), biperm2)
14
+ axes_codomain, axes_contracted = blocks (axes (a1)[ biperm1] )
15
+ axes_contracted2, axes_domain = blocks (axes (a2)[ biperm2] )
16
16
@assert axes_contracted == axes_contracted2
17
17
return genperm ((axes_codomain... , axes_domain... ), invperm (Tuple (biperm_dest)))
18
18
end
19
19
20
- # Inner-product contraction.
21
- # TODO : Use `ArrayLayouts`-like `MulAdd` object,
22
- # i.e. `ContractAdd`?
23
- function output_axes (
24
- :: typeof (contract),
25
- perm_dest:: BlockedPermutation{0} ,
26
- a1:: AbstractArray ,
27
- perm1:: BlockedPermutation{1} ,
28
- a2:: AbstractArray ,
29
- perm2:: BlockedPermutation{1} ,
30
- α:: Number = one (Bool),
31
- )
32
- axes_contracted = blockpermute (axes (a1), perm1)
33
- axes_contracted′ = blockpermute (axes (a2), perm2)
34
- @assert axes_contracted == axes_contracted′
35
- return ()
36
- end
37
-
38
- # Vec-mat.
39
- function output_axes (
40
- :: typeof (contract),
41
- perm_dest:: BlockedPermutation{1} ,
42
- a1:: AbstractArray ,
43
- perm1:: BlockedPermutation{1} ,
44
- a2:: AbstractArray ,
45
- biperm2:: BlockedPermutation{2} ,
46
- α:: Number = one (Bool),
47
- )
48
- (axes_contracted,) = blockpermute (axes (a1), perm1)
49
- axes_contracted′, axes_dest = blockpermute (axes (a2), biperm2)
50
- @assert axes_contracted == axes_contracted′
51
- return genperm ((axes_dest... ,), invperm (Tuple (perm_dest)))
52
- end
53
-
54
- # Mat-vec.
55
- function output_axes (
56
- :: typeof (contract),
57
- perm_dest:: BlockedPermutation{1} ,
58
- a1:: AbstractArray ,
59
- perm1:: BlockedPermutation{2} ,
60
- a2:: AbstractArray ,
61
- biperm2:: BlockedPermutation{1} ,
62
- α:: Number = one (Bool),
63
- )
64
- axes_dest, axes_contracted = blockpermute (axes (a1), perm1)
65
- (axes_contracted′,) = blockpermute (axes (a2), biperm2)
66
- @assert axes_contracted == axes_contracted′
67
- return genperm ((axes_dest... ,), invperm (Tuple (perm_dest)))
68
- end
69
-
70
- # Outer product.
71
- function output_axes (
72
- :: typeof (contract),
73
- biperm_dest:: BlockedPermutation{2} ,
74
- a1:: AbstractArray ,
75
- perm1:: BlockedPermutation{1} ,
76
- a2:: AbstractArray ,
77
- perm2:: BlockedPermutation{1} ,
78
- α:: Number = one (Bool),
79
- )
80
- @assert istrivialperm (Tuple (perm1))
81
- @assert istrivialperm (Tuple (perm2))
82
- axes_dest = (axes (a1)... , axes (a2)... )
83
- return genperm (axes_dest, invperm (Tuple (biperm_dest)))
84
- end
85
-
86
- # Array-scalar contraction.
87
- function output_axes (
88
- :: typeof (contract),
89
- perm_dest:: BlockedPermutation{1} ,
90
- a1:: AbstractArray ,
91
- perm1:: BlockedPermutation{1} ,
92
- a2:: AbstractArray ,
93
- perm2:: BlockedPermutation{0} ,
94
- α:: Number = one (Bool),
95
- )
96
- @assert istrivialperm (Tuple (perm1))
97
- axes_dest = axes (a1)
98
- return genperm (axes_dest, invperm (Tuple (perm_dest)))
99
- end
100
-
101
- # Scalar-array contraction.
102
- function output_axes (
103
- :: typeof (contract),
104
- perm_dest:: BlockedPermutation{1} ,
105
- a1:: AbstractArray ,
106
- perm1:: BlockedPermutation{0} ,
107
- a2:: AbstractArray ,
108
- perm2:: BlockedPermutation{1} ,
109
- α:: Number = one (Bool),
110
- )
111
- @assert istrivialperm (Tuple (perm2))
112
- axes_dest = axes (a2)
113
- return genperm (axes_dest, invperm (Tuple (perm_dest)))
114
- end
115
-
116
- # Scalar-scalar contraction.
117
- function output_axes (
118
- :: typeof (contract),
119
- perm_dest:: BlockedPermutation{0} ,
120
- a1:: AbstractArray ,
121
- perm1:: BlockedPermutation{0} ,
122
- a2:: AbstractArray ,
123
- perm2:: BlockedPermutation{0} ,
124
- α:: Number = one (Bool),
125
- )
126
- return ()
127
- end
128
-
129
20
# TODO : Use `ArrayLayouts`-like `MulAdd` object,
130
21
# i.e. `ContractAdd`?
131
22
function allocate_output (
132
23
:: typeof (contract),
133
- biperm_dest:: BlockedPermutation ,
24
+ biperm_dest:: AbstractBlockPermutation ,
134
25
a1:: AbstractArray ,
135
- biperm1:: BlockedPermutation ,
26
+ biperm1:: AbstractBlockPermutation ,
136
27
a2:: AbstractArray ,
137
- biperm2:: BlockedPermutation ,
28
+ biperm2:: AbstractBlockPermutation ,
138
29
α:: Number = one (Bool),
139
30
)
140
31
axes_dest = output_axes (contract, biperm_dest, a1, biperm1, a2, biperm2, α)
0 commit comments