Skip to content

Commit d17b005

Browse files
[mlir][scf] Relax requirements for loops fusion (#79187)
Enable the fusion of parallel loops also when the 1st loop contains multiple write accesses to the same buffer, if the accesses are always on the same indices. Fix LIT test cases whose loops were not being fused. Signed-off-by: Fabrizio Indirli <[email protected]>
1 parent 036a20c commit d17b005

File tree

2 files changed

+174
-74
lines changed

2 files changed

+174
-74
lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
8383
if (write == bufferStores.end())
8484
return WalkResult::advance();
8585

86-
// Allow only single write access per buffer.
87-
if (write->second.size() != 1)
86+
// Check that at last one store was retrieved
87+
if (!write->second.size())
8888
return WalkResult::interrupt();
8989

90+
auto storeIndices = write->second.front();
91+
92+
// Multiple writes to the same memref are allowed only on the same indices
93+
for (const auto &othStoreIndices : write->second) {
94+
if (othStoreIndices != storeIndices)
95+
return WalkResult::interrupt();
96+
}
97+
9098
// Check that the load indices of secondPloop coincide with store indices of
9199
// firstPloop for the same memrefs.
92-
auto storeIndices = write->second.front();
93100
auto loadIndices = load.getIndices();
94101
if (storeIndices.size() != loadIndices.size())
95102
return WalkResult::interrupt();

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 164 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ func.func @fuse_empty_loops() {
1313
return
1414
}
1515
// CHECK-LABEL: func @fuse_empty_loops
16-
// CHECK: [[C2:%.*]] = arith.constant 2 : index
17-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
18-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
16+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
17+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
18+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
1919
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
2020
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
2121
// CHECK: scf.reduce
@@ -24,106 +24,106 @@ func.func @fuse_empty_loops() {
2424

2525
// -----
2626

27-
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
28-
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
27+
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
2928
%c2 = arith.constant 2 : index
3029
%c0 = arith.constant 0 : index
3130
%c1 = arith.constant 1 : index
31+
%c1fp = arith.constant 1.0 : f32
3232
%sum = memref.alloc() : memref<2x2xf32>
3333
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
3434
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
35-
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
36-
%sum_elem = arith.addf %B_elem, %C_elem : f32
35+
%sum_elem = arith.addf %B_elem, %c1fp : f32
3736
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
3837
scf.reduce
3938
}
4039
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
4140
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
4241
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
4342
%product_elem = arith.mulf %sum_elem, %A_elem : f32
44-
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
43+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
4544
scf.reduce
4645
}
4746
memref.dealloc %sum : memref<2x2xf32>
4847
return
4948
}
5049
// CHECK-LABEL: func @fuse_two
51-
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
52-
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
53-
// CHECK: [[C2:%.*]] = arith.constant 2 : index
54-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
55-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
50+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
51+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
52+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
53+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
54+
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
5655
// CHECK: [[SUM:%.*]] = memref.alloc()
5756
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
5857
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
5958
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
60-
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
61-
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
59+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
6260
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
61+
// CHECK-NOT: scf.parallel
6362
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
6463
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
6564
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
66-
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
65+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
6766
// CHECK: scf.reduce
6867
// CHECK: }
6968
// CHECK: memref.dealloc [[SUM]]
7069

7170
// -----
7271

73-
func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
74-
%result: memref<100x10xf32>) {
75-
%c100 = arith.constant 100 : index
76-
%c10 = arith.constant 10 : index
72+
func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
73+
%c2 = arith.constant 2 : index
7774
%c0 = arith.constant 0 : index
7875
%c1 = arith.constant 1 : index
79-
%broadcast_rhs = memref.alloc() : memref<100x10xf32>
80-
%diff = memref.alloc() : memref<100x10xf32>
81-
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
82-
%rhs_elem = memref.load %rhs[%i] : memref<100xf32>
83-
memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
76+
%c1fp = arith.constant 1.0 : f32
77+
%c2fp = arith.constant 2.0 : f32
78+
%sum = memref.alloc() : memref<2x2xf32>
79+
%prod = memref.alloc() : memref<2x2xf32>
80+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
81+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
82+
%sum_elem = arith.addf %B_elem, %c1fp : f32
83+
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
8484
scf.reduce
8585
}
86-
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
87-
%lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32>
88-
%broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32>
89-
%diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32
90-
memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
86+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
87+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
88+
%product_elem = arith.mulf %sum_elem, %c2fp : f32
89+
memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
9190
scf.reduce
9291
}
93-
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
94-
%diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32>
95-
%exp_elem = math.exp %diff_elem : f32
96-
memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32>
97-
scf.reduce
92+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
93+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
94+
%res_elem = arith.addf %A_elem, %c2fp : f32
95+
memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
9896
}
99-
memref.dealloc %broadcast_rhs : memref<100x10xf32>
100-
memref.dealloc %diff : memref<100x10xf32>
97+
memref.dealloc %sum : memref<2x2xf32>
98+
memref.dealloc %prod : memref<2x2xf32>
10199
return
102100
}
103101
// CHECK-LABEL: func @fuse_three
104-
// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
105-
// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) {
106-
// CHECK: [[C100:%.*]] = arith.constant 100 : index
107-
// CHECK: [[C10:%.*]] = arith.constant 10 : index
108-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
109-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
110-
// CHECK: [[BROADCAST_RHS:%.*]] = memref.alloc()
111-
// CHECK: [[DIFF:%.*]] = memref.alloc()
102+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
103+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
104+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
105+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
106+
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
107+
// CHECK-DAG: [[C2FP:%.*]] = arith.constant 2.
108+
// CHECK: [[SUM:%.*]] = memref.alloc()
109+
// CHECK: [[PROD:%.*]] = memref.alloc()
112110
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
113-
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
114-
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
115-
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
116-
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
117-
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
118-
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
119-
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
120-
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
121-
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
122-
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
111+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
112+
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
113+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
114+
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
115+
// CHECK-NOT: scf.parallel
116+
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
117+
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
118+
// CHECK: memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
119+
// CHECK-NOT: scf.parallel
120+
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
121+
// CHECK: [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
122+
// CHECK: memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
123123
// CHECK: scf.reduce
124124
// CHECK: }
125-
// CHECK: memref.dealloc [[BROADCAST_RHS]]
126-
// CHECK: memref.dealloc [[DIFF]]
125+
// CHECK: memref.dealloc [[SUM]]
126+
// CHECK: memref.dealloc [[PROD]]
127127

128128
// -----
129129

@@ -310,49 +310,48 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
310310

311311
// -----
312312

313-
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
314-
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
313+
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
315314
%c2 = arith.constant 2 : index
316315
%c0 = arith.constant 0 : index
317316
%c1 = arith.constant 1 : index
317+
%c1fp = arith.constant 1.0 : f32
318318
%sum = memref.alloc() : memref<2x2xf32>
319319
scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
320320
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
321321
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
322-
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
323-
%sum_elem = arith.addf %B_elem, %C_elem : f32
322+
%sum_elem = arith.addf %B_elem, %c1fp : f32
324323
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
325324
scf.reduce
326325
}
327326
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
328327
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
329328
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
330329
%product_elem = arith.mulf %sum_elem, %A_elem : f32
331-
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
330+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
332331
scf.reduce
333332
}
334333
}
335334
memref.dealloc %sum : memref<2x2xf32>
336335
return
337336
}
338337
// CHECK-LABEL: func @nested_fuse
339-
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
340-
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
341-
// CHECK: [[C2:%.*]] = arith.constant 2 : index
342-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
343-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
338+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
339+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
340+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
341+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
342+
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
344343
// CHECK: [[SUM:%.*]] = memref.alloc()
345344
// CHECK: scf.parallel
346345
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
347346
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
348347
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
349-
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
350-
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
348+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
351349
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
350+
// CHECK-NOT: scf.parallel
352351
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
353352
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
354353
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
355-
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
354+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
356355
// CHECK: scf.reduce
357356
// CHECK: }
358357
// CHECK: }
@@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
382381
}
383382
return
384383
}
385-
386384
// %sum and %result may alias with other args, do not fuse loops
387385
// CHECK-LABEL: func @do_not_fuse_alias
388386
// CHECK: scf.parallel
389387
// CHECK: scf.parallel
388+
389+
// -----
390+
391+
func.func @fuse_when_1st_has_multiple_stores(
392+
%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
393+
%c0 = arith.constant 0 : index
394+
%c1 = arith.constant 1 : index
395+
%c2 = arith.constant 2 : index
396+
%c0fp = arith.constant 0.0 : f32
397+
%sum = memref.alloc() : memref<2x2xf32>
398+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
399+
memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
400+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
401+
%sum_elem = arith.addf %B_elem, %B_elem : f32
402+
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
403+
scf.reduce
404+
}
405+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
406+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
407+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
408+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
409+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
410+
scf.reduce
411+
}
412+
memref.dealloc %sum : memref<2x2xf32>
413+
return
414+
}
415+
// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
416+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
417+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
418+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
419+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
420+
// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
421+
// CHECK: [[SUM:%.*]] = memref.alloc()
422+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
423+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
424+
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
425+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
426+
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
427+
// CHECK-NOT: scf.parallel
428+
// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
429+
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
430+
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
431+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
432+
// CHECK: scf.reduce
433+
// CHECK: }
434+
// CHECK: memref.dealloc [[SUM]]
435+
436+
// -----
437+
438+
func.func @do_not_fuse_multiple_stores_on_diff_indices(
439+
%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
440+
%c0 = arith.constant 0 : index
441+
%c1 = arith.constant 1 : index
442+
%c2 = arith.constant 2 : index
443+
%c0fp = arith.constant 0.0 : f32
444+
%sum = memref.alloc() : memref<2x2xf32>
445+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
446+
memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
447+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
448+
%sum_elem = arith.addf %B_elem, %B_elem : f32
449+
memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
450+
scf.reduce
451+
}
452+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
453+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
454+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
455+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
456+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
457+
scf.reduce
458+
}
459+
memref.dealloc %sum : memref<2x2xf32>
460+
return
461+
}
462+
// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
463+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
464+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
465+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
466+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
467+
// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
468+
// CHECK: [[SUM:%.*]] = memref.alloc()
469+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
470+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
471+
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
472+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
473+
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
474+
// CHECK: scf.reduce
475+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
476+
// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
477+
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
478+
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
479+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
480+
// CHECK: scf.reduce
481+
// CHECK: }
482+
// CHECK: memref.dealloc [[SUM]]

0 commit comments

Comments
 (0)