@@ -13,9 +13,9 @@ func.func @fuse_empty_loops() {
13
13
return
14
14
}
15
15
// CHECK-LABEL: func @fuse_empty_loops
16
- // CHECK: [[C2:%.*]] = arith.constant 2 : index
17
- // CHECK: [[C0:%.*]] = arith.constant 0 : index
18
- // CHECK: [[C1:%.*]] = arith.constant 1 : index
16
+ // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
17
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
18
+ // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
19
19
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
20
20
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
21
21
// CHECK: scf.reduce
@@ -24,106 +24,106 @@ func.func @fuse_empty_loops() {
24
24
25
25
// -----
26
26
27
- func.func @fuse_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >,
28
- %C: memref <2 x2 xf32 >, %result: memref <2 x2 xf32 >) {
27
+ func.func @fuse_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
29
28
%c2 = arith.constant 2 : index
30
29
%c0 = arith.constant 0 : index
31
30
%c1 = arith.constant 1 : index
31
+ %c1fp = arith.constant 1.0 : f32
32
32
%sum = memref.alloc () : memref <2 x2 xf32 >
33
33
scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
34
34
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
35
- %C_elem = memref.load %C [%i , %j ] : memref <2 x2 xf32 >
36
- %sum_elem = arith.addf %B_elem , %C_elem : f32
35
+ %sum_elem = arith.addf %B_elem , %c1fp : f32
37
36
memref.store %sum_elem , %sum [%i , %j ] : memref <2 x2 xf32 >
38
37
scf.reduce
39
38
}
40
39
scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
41
40
%sum_elem = memref.load %sum [%i , %j ] : memref <2 x2 xf32 >
42
41
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
43
42
%product_elem = arith.mulf %sum_elem , %A_elem : f32
44
- memref.store %product_elem , %result [%i , %j ] : memref <2 x2 xf32 >
43
+ memref.store %product_elem , %B [%i , %j ] : memref <2 x2 xf32 >
45
44
scf.reduce
46
45
}
47
46
memref.dealloc %sum : memref <2 x2 xf32 >
48
47
return
49
48
}
50
49
// CHECK-LABEL: func @fuse_two
51
- // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
52
- // CHECK-SAME : [[RESULT :%.*]]: {{.*}}) {
53
- // CHECK: [[C2 :%.*]] = arith.constant 2 : index
54
- // CHECK: [[C0 :%.*]] = arith.constant 0 : index
55
- // CHECK: [[C1 :%.*]] = arith.constant 1 : index
50
+ // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
51
+ // CHECK-DAG : [[C2 :%.*]] = arith.constant 2 : index
52
+ // CHECK-DAG : [[C0 :%.*]] = arith.constant 0 : index
53
+ // CHECK-DAG : [[C1 :%.*]] = arith.constant 1 : index
54
+ // CHECK-DAG : [[C1FP :%.*]] = arith.constant 1.
56
55
// CHECK: [[SUM:%.*]] = memref.alloc()
57
56
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
58
57
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
59
58
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
60
- // CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
61
- // CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
59
+ // CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
62
60
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
61
+ // CHECK-NOT: scf.parallel
63
62
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
64
63
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
65
64
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
66
- // CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT ]]{{\[}}[[I]], [[J]]]
65
+ // CHECK: memref.store [[PRODUCT_ELEM]], [[B ]]{{\[}}[[I]], [[J]]]
67
66
// CHECK: scf.reduce
68
67
// CHECK: }
69
68
// CHECK: memref.dealloc [[SUM]]
70
69
71
70
// -----
72
71
73
- func.func @fuse_three (%lhs: memref <100 x10 xf32 >, %rhs: memref <100 xf32 >,
74
- %result: memref <100 x10 xf32 >) {
75
- %c100 = arith.constant 100 : index
76
- %c10 = arith.constant 10 : index
72
+ func.func @fuse_three (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
73
+ %c2 = arith.constant 2 : index
77
74
%c0 = arith.constant 0 : index
78
75
%c1 = arith.constant 1 : index
79
- %broadcast_rhs = memref.alloc () : memref <100 x10 xf32 >
80
- %diff = memref.alloc () : memref <100 x10 xf32 >
81
- scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c100 , %c10 ) step (%c1 , %c1 ) {
82
- %rhs_elem = memref.load %rhs [%i ] : memref <100 xf32 >
83
- memref.store %rhs_elem , %broadcast_rhs [%i , %j ] : memref <100 x10 xf32 >
76
+ %c1fp = arith.constant 1.0 : f32
77
+ %c2fp = arith.constant 2.0 : f32
78
+ %sum = memref.alloc () : memref <2 x2 xf32 >
79
+ %prod = memref.alloc () : memref <2 x2 xf32 >
80
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
81
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
82
+ %sum_elem = arith.addf %B_elem , %c1fp : f32
83
+ memref.store %sum_elem , %sum [%i , %j ] : memref <2 x2 xf32 >
84
84
scf.reduce
85
85
}
86
- scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c100 , %c10 ) step (%c1 , %c1 ) {
87
- %lhs_elem = memref.load %lhs [%i , %j ] : memref <100 x10 xf32 >
88
- %broadcast_rhs_elem = memref.load %broadcast_rhs [%i , %j ] : memref <100 x10 xf32 >
89
- %diff_elem = arith.subf %lhs_elem , %broadcast_rhs_elem : f32
90
- memref.store %diff_elem , %diff [%i , %j ] : memref <100 x10 xf32 >
86
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
87
+ %sum_elem = memref.load %sum [%i , %j ] : memref <2 x2 xf32 >
88
+ %product_elem = arith.mulf %sum_elem , %c2fp : f32
89
+ memref.store %product_elem , %prod [%i , %j ] : memref <2 x2 xf32 >
91
90
scf.reduce
92
91
}
93
- scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c100 , %c10 ) step (%c1 , %c1 ) {
94
- %diff_elem = memref.load %diff [%i , %j ] : memref <100 x10 xf32 >
95
- %exp_elem = math.exp %diff_elem : f32
96
- memref.store %exp_elem , %result [%i , %j ] : memref <100 x10 xf32 >
97
- scf.reduce
92
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
93
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
94
+ %res_elem = arith.addf %A_elem , %c2fp : f32
95
+ memref.store %res_elem , %B [%i , %j ] : memref <2 x2 xf32 >
98
96
}
99
- memref.dealloc %broadcast_rhs : memref <100 x 10 x f32 >
100
- memref.dealloc %diff : memref <100 x 10 x f32 >
97
+ memref.dealloc %sum : memref <2 x 2 x f32 >
98
+ memref.dealloc %prod : memref <2 x 2 x f32 >
101
99
return
102
100
}
103
101
// CHECK-LABEL: func @fuse_three
104
- // CHECK-SAME: ([[LHS :%.*]]: memref<100x10xf32> , [[RHS :%.*]]: memref<100xf32>,
105
- // CHECK-SAME : [[RESULT :%.*]]: memref<100x10xf32>) {
106
- // CHECK: [[C100 :%.*]] = arith.constant 100 : index
107
- // CHECK: [[C10 :%.*]] = arith.constant 10 : index
108
- // CHECK: [[C0 :%.*]] = arith.constant 0 : index
109
- // CHECK: [[C1 :%.*]] = arith.constant 1 : index
110
- // CHECK: [[BROADCAST_RHS :%.*]] = memref.alloc()
111
- // CHECK: [[DIFF :%.*]] = memref.alloc()
102
+ // CHECK-SAME: ([[A :%.*]]: {{.*}} , [[B :%.*]]: {{.*}}) {
103
+ // CHECK-DAG : [[C2 :%.*]] = arith.constant 2 : index
104
+ // CHECK-DAG : [[C0 :%.*]] = arith.constant 0 : index
105
+ // CHECK-DAG : [[C1 :%.*]] = arith.constant 1 : index
106
+ // CHECK-DAG : [[C1FP :%.*]] = arith.constant 1.
107
+ // CHECK-DAG : [[C2FP :%.*]] = arith.constant 2.
108
+ // CHECK: [[SUM :%.*]] = memref.alloc()
109
+ // CHECK: [[PROD :%.*]] = memref.alloc()
112
110
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
113
- // CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
114
- // CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
115
- // CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
116
- // CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
117
- // CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
118
- // CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
119
- // CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
120
- // CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
121
- // CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
122
- // CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
111
+ // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
112
+ // CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
113
+ // CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
114
+ // CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
115
+ // CHECK-NOT: scf.parallel
116
+ // CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
117
+ // CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
118
+ // CHECK: memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
119
+ // CHECK-NOT: scf.parallel
120
+ // CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
121
+ // CHECK: [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
122
+ // CHECK: memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
123
123
// CHECK: scf.reduce
124
124
// CHECK: }
125
- // CHECK: memref.dealloc [[BROADCAST_RHS ]]
126
- // CHECK: memref.dealloc [[DIFF ]]
125
+ // CHECK: memref.dealloc [[SUM ]]
126
+ // CHECK: memref.dealloc [[PROD ]]
127
127
128
128
// -----
129
129
@@ -310,49 +310,48 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
310
310
311
311
// -----
312
312
313
- func.func @nested_fuse (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >,
314
- %C: memref <2 x2 xf32 >, %result: memref <2 x2 xf32 >) {
313
+ func.func @nested_fuse (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
315
314
%c2 = arith.constant 2 : index
316
315
%c0 = arith.constant 0 : index
317
316
%c1 = arith.constant 1 : index
317
+ %c1fp = arith.constant 1.0 : f32
318
318
%sum = memref.alloc () : memref <2 x2 xf32 >
319
319
scf.parallel (%k ) = (%c0 ) to (%c2 ) step (%c1 ) {
320
320
scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
321
321
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
322
- %C_elem = memref.load %C [%i , %j ] : memref <2 x2 xf32 >
323
- %sum_elem = arith.addf %B_elem , %C_elem : f32
322
+ %sum_elem = arith.addf %B_elem , %c1fp : f32
324
323
memref.store %sum_elem , %sum [%i , %j ] : memref <2 x2 xf32 >
325
324
scf.reduce
326
325
}
327
326
scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
328
327
%sum_elem = memref.load %sum [%i , %j ] : memref <2 x2 xf32 >
329
328
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
330
329
%product_elem = arith.mulf %sum_elem , %A_elem : f32
331
- memref.store %product_elem , %result [%i , %j ] : memref <2 x2 xf32 >
330
+ memref.store %product_elem , %B [%i , %j ] : memref <2 x2 xf32 >
332
331
scf.reduce
333
332
}
334
333
}
335
334
memref.dealloc %sum : memref <2 x2 xf32 >
336
335
return
337
336
}
338
337
// CHECK-LABEL: func @nested_fuse
339
- // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
340
- // CHECK-SAME : [[RESULT :%.*]]: {{.*}}) {
341
- // CHECK: [[C2 :%.*]] = arith.constant 2 : index
342
- // CHECK: [[C0 :%.*]] = arith.constant 0 : index
343
- // CHECK: [[C1 :%.*]] = arith.constant 1 : index
338
+ // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
339
+ // CHECK-DAG : [[C2 :%.*]] = arith.constant 2 : index
340
+ // CHECK-DAG : [[C0 :%.*]] = arith.constant 0 : index
341
+ // CHECK-DAG : [[C1 :%.*]] = arith.constant 1 : index
342
+ // CHECK-DAG : [[C1FP :%.*]] = arith.constant 1.
344
343
// CHECK: [[SUM:%.*]] = memref.alloc()
345
344
// CHECK: scf.parallel
346
345
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
347
346
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
348
347
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
349
- // CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
350
- // CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
348
+ // CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
351
349
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
350
+ // CHECK-NOT: scf.parallel
352
351
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
353
352
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
354
353
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
355
- // CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT ]]{{\[}}[[I]], [[J]]]
354
+ // CHECK: memref.store [[PRODUCT_ELEM]], [[B ]]{{\[}}[[I]], [[J]]]
356
355
// CHECK: scf.reduce
357
356
// CHECK: }
358
357
// CHECK: }
@@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
382
381
}
383
382
return
384
383
}
385
-
386
384
// %sum and %result may alias with other args, do not fuse loops
387
385
// CHECK-LABEL: func @do_not_fuse_alias
388
386
// CHECK: scf.parallel
389
387
// CHECK: scf.parallel
388
+
389
+ // -----
390
+
391
+ func.func @fuse_when_1st_has_multiple_stores (
392
+ %A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
393
+ %c0 = arith.constant 0 : index
394
+ %c1 = arith.constant 1 : index
395
+ %c2 = arith.constant 2 : index
396
+ %c0fp = arith.constant 0.0 : f32
397
+ %sum = memref.alloc () : memref <2 x2 xf32 >
398
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
399
+ memref.store %c0fp , %sum [%i , %j ] : memref <2 x2 xf32 >
400
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
401
+ %sum_elem = arith.addf %B_elem , %B_elem : f32
402
+ memref.store %sum_elem , %sum [%i , %j ] : memref <2 x2 xf32 >
403
+ scf.reduce
404
+ }
405
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
406
+ %sum_elem = memref.load %sum [%i , %j ] : memref <2 x2 xf32 >
407
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
408
+ %product_elem = arith.mulf %sum_elem , %A_elem : f32
409
+ memref.store %product_elem , %B [%i , %j ] : memref <2 x2 xf32 >
410
+ scf.reduce
411
+ }
412
+ memref.dealloc %sum : memref <2 x2 xf32 >
413
+ return
414
+ }
415
+ // CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
416
+ // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
417
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
418
+ // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
419
+ // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
420
+ // CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
421
+ // CHECK: [[SUM:%.*]] = memref.alloc()
422
+ // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
423
+ // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
424
+ // CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
425
+ // CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
426
+ // CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
427
+ // CHECK-NOT: scf.parallel
428
+ // CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
429
+ // CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
430
+ // CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
431
+ // CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
432
+ // CHECK: scf.reduce
433
+ // CHECK: }
434
+ // CHECK: memref.dealloc [[SUM]]
435
+
436
+ // -----
437
+
438
+ func.func @do_not_fuse_multiple_stores_on_diff_indices (
439
+ %A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
440
+ %c0 = arith.constant 0 : index
441
+ %c1 = arith.constant 1 : index
442
+ %c2 = arith.constant 2 : index
443
+ %c0fp = arith.constant 0.0 : f32
444
+ %sum = memref.alloc () : memref <2 x2 xf32 >
445
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
446
+ memref.store %c0fp , %sum [%i , %j ] : memref <2 x2 xf32 >
447
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
448
+ %sum_elem = arith.addf %B_elem , %B_elem : f32
449
+ memref.store %sum_elem , %sum [%c0 , %j ] : memref <2 x2 xf32 >
450
+ scf.reduce
451
+ }
452
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
453
+ %sum_elem = memref.load %sum [%i , %j ] : memref <2 x2 xf32 >
454
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
455
+ %product_elem = arith.mulf %sum_elem , %A_elem : f32
456
+ memref.store %product_elem , %B [%i , %j ] : memref <2 x2 xf32 >
457
+ scf.reduce
458
+ }
459
+ memref.dealloc %sum : memref <2 x2 xf32 >
460
+ return
461
+ }
462
+ // CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
463
+ // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
464
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
465
+ // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
466
+ // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
467
+ // CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
468
+ // CHECK: [[SUM:%.*]] = memref.alloc()
469
+ // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
470
+ // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
471
+ // CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
472
+ // CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
473
+ // CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
474
+ // CHECK: scf.reduce
475
+ // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
476
+ // CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
477
+ // CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
478
+ // CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
479
+ // CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
480
+ // CHECK: scf.reduce
481
+ // CHECK: }
482
+ // CHECK: memref.dealloc [[SUM]]
0 commit comments