Skip to content

Commit 77efe55

Browse files
ogauthemtfishman
andauthored
define TensorAlgebra.matricize (#95)
Co-authored-by: Matt Fishman <[email protected]>
1 parent 57c083d commit 77efe55

File tree

5 files changed

+38
-44
lines changed

5 files changed

+38
-44
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -40,7 +40,7 @@ MacroTools = "0.5.13"
4040
MapBroadcast = "0.1.5"
4141
SparseArraysBase = "0.5"
4242
SplitApplyCombine = "1.2.3"
43-
TensorAlgebra = "0.2.4"
43+
TensorAlgebra = "0.3.2"
4444
Test = "1.10"
4545
TypeParameterAccessors = "0.2.0, 0.3"
4646
julia = "1.10"
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
module BlockSparseArraysTensorAlgebraExt
22

3-
using BlockArrays: AbstractBlockedUnitRange
43
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
5-
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
4+
using TensorAlgebra:
5+
TensorAlgebra,
6+
BlockedTrivialPermutation,
7+
BlockedTuple,
8+
FusionStyle,
9+
ReshapeFusion,
10+
fuseaxes
611

7-
TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
12+
struct BlockReshapeFusion <: FusionStyle end
813

9-
function TensorAlgebra.fusedims(
10-
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
14+
function TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray})
15+
return BlockReshapeFusion()
16+
end
17+
18+
function TensorAlgebra.matricize(
19+
::BlockReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
1120
)
12-
return blockreshape(a, axes)
21+
new_axes = fuseaxes(axes(a), biperm)
22+
return blockreshape(a, new_axes)
1323
end
1424

15-
function TensorAlgebra.splitdims(
16-
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
25+
function TensorAlgebra.unmatricize(
26+
::BlockReshapeFusion,
27+
m::AbstractMatrix,
28+
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
1729
)
18-
return blockreshape(a, axes)
30+
return blockreshape(m, Tuple(blocked_axes)...)
1931
end
2032

2133
end

test/Project.toml

+1-9
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,12 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
66
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
77
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
88
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
9-
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
109
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
11-
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
1210
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1411
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1512
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1613
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1714
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
18-
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
1915
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2016
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2117
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
@@ -29,17 +25,13 @@ BlockArrays = "1"
2925
BlockSparseArrays = "0.4"
3026
DiagonalArrays = "0.3"
3127
GPUArraysCore = "0.2"
32-
GradedUnitRanges = "0.2.2"
3328
JLArrays = "0.2"
34-
LabelledNumbers = "0.1"
3529
LinearAlgebra = "1"
36-
Pkg = "1"
3730
Random = "1"
3831
SafeTestsets = "0.1"
3932
SparseArraysBase = "0.5"
4033
Suppressor = "0.2"
41-
SymmetrySectors = "0.1.7"
42-
TensorAlgebra = "0.2.4"
34+
TensorAlgebra = "0.3.2"
4335
Test = "1"
4436
TestExtras = "0.3"
4537
TypeParameterAccessors = "0.3"

test/test_basics.jl

+2-19
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@ using BlockArrays:
44
BlockArrays,
55
Block,
66
BlockArray,
7-
BlockIndexRange,
87
BlockRange,
9-
BlockSlice,
108
BlockVector,
119
BlockedOneTo,
12-
BlockedUnitRange,
1310
BlockedArray,
1411
BlockedVector,
1512
blockedrange,
@@ -35,9 +32,8 @@ using BlockSparseArrays:
3532
view!
3633
using GPUArraysCore: @allowscalar
3734
using JLArrays: JLArray, JLMatrix
38-
using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
35+
using LinearAlgebra: Adjoint, Transpose, dot, norm
3936
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, storedlength
40-
using TensorAlgebra: contract
4137
using Test: @test, @test_broken, @test_throws, @testset, @inferred
4238
using TestExtras: @constinferred
4339
using TypeParameterAccessors: TypeParameterAccessors, Position
@@ -1120,20 +1116,7 @@ arrayts = (Array, JLArray)
11201116
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
11211117
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
11221118
end
1123-
@testset "TensorAlgebra" begin
1124-
a1 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
1125-
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
1126-
a2 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
1127-
a2[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
1128-
# TODO: Make this work, requires customization of `TensorAlgebra.fusedims` and
1129-
# `TensorAlgebra.splitdims` in terms of `BlockSparseArrays.blockreshape`,
1130-
# and customization of `TensorAlgebra.:⊗` in terms of `GradedUnitRanges.tensor_product`.
1131-
a_dest, dimnames_dest = contract(a1, (1, -1), a2, (-1, 2))
1132-
@allowscalar begin
1133-
a_dest_dense, dimnames_dest_dense = contract(Array(a1), (1, -1), Array(a2), (-1, 2))
1134-
@test a_dest a_dest_dense
1135-
end
1136-
end
1119+
11371120
@testset "blockreshape" begin
11381121
a = dev(BlockSparseArray{elt}(undef, ([3, 4], [2, 3])))
11391122
a[Block(1, 2)] = dev(randn(elt, size(@view(a[Block(1, 2)]))))

test/test_tensoralgebraext.jl

+12-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize
1+
using BlockArrays: Block, BlockArray, blockedrange, blocksize
22
using BlockSparseArrays: BlockSparseArray
33
using Random: randn!
44
using TensorAlgebra: contract
@@ -14,6 +14,13 @@ function randn_blockdiagonal(elt::Type, axes::Tuple)
1414
return a
1515
end
1616

17+
@testset "Regression test for BlockArrays" begin
18+
# test https://github.com/ITensor/BlockSparseArrays.jl/issues/57
19+
d = blockedrange([1, 1])
20+
a = BlockArray(ones((d, d, d, d)))
21+
@test contract((-1, -2, -3, -4), a, (1, -1, 2, -2), a, (2, -3, 1, -4)) isa BlockArray
22+
end
23+
1724
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
1825
@testset "`contract` `BlockSparseArray` (eltype=$elt)" for elt in elts
1926
@testset "BlockedOneTo" begin
@@ -36,14 +43,12 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3643
@test a_dest a_dest_dense
3744

3845
# matrix vector
39-
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
40-
#=
46+
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
4147
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
4248
@test dimnames_dest == dimnames_dest_dense
4349
@test size(a_dest) == size(a_dest_dense)
4450
@test a_dest isa BlockSparseArray
4551
@test a_dest a_dest_dense
46-
=#
4752

4853
# vector matrix
4954
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
@@ -54,12 +59,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
5459
@test a_dest a_dest_dense
5560

5661
# vector vector
62+
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
63+
#=
5764
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
58-
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
5965
@test dimnames_dest == dimnames_dest_dense
6066
@test size(a_dest) == size(a_dest_dense)
6167
@test a_dest isa BlockSparseArray{elt,0}
6268
@test a_dest ≈ a_dest_dense
69+
=#
6370

6471
# outer product
6572
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))

0 commit comments

Comments
 (0)