Skip to content

Commit f709642

Browse files
[mlir][GPU] Add RecursiveMemoryEffects to gpu.launch (#75315)
Infer the side effects of `gpu.launch` from its body.
1 parent 3903438 commit f709642

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,8 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
672672

673673
def GPU_LaunchOp : GPU_Op<"launch", [
674674
AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface,
675-
DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
675+
DeclareOpInterfaceMethods<InferIntRangeInterface>,
676+
RecursiveMemoryEffects]>,
676677
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
677678
Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
678679
Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,

mlir/test/Dialect/GPU/canonicalize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ func.func @fold_wait_op_test1() {
1111
}
1212
// CHECK-NOT: gpu.wait
1313

14+
// -----
15+
1416
// Erase duplicate barriers.
1517
// CHECK-LABEL: func @erase_barriers
1618
// CHECK-NEXT: gpu.barrier
@@ -21,6 +23,8 @@ func.func @erase_barriers() {
2123
return
2224
}
2325

26+
// -----
27+
2428
// Replace uses of gpu.wait op with its async dependency.
2529
// CHECK-LABEL: func @fold_wait_op_test2
2630
func.func @fold_wait_op_test2(%arg0: i1) -> (memref<5xf16>, memref<5xf16>) {
@@ -38,6 +42,8 @@ func.func @fold_wait_op_test2(%arg0: i1) -> (memref<5xf16>, memref<5xf16>) {
3842
// CHECK-NEXT: gpu.alloc async [%[[TOKEN1]]] ()
3943
// CHECK-NEXT: return
4044

45+
// -----
46+
4147
// CHECK-LABEL: func @fold_memcpy_op
4248
func.func @fold_memcpy_op(%arg0: i1) {
4349
%cst = arith.constant 0.000000e+00 : f16
@@ -60,6 +66,8 @@ func.func @fold_memcpy_op(%arg0: i1) {
6066
}
6167
// CHECK-NOT: gpu.memcpy
6268

69+
// -----
70+
6371
// We cannot fold memcpy here as dest is a block argument.
6472
// CHECK-LABEL: func @do_not_fold_memcpy_op1
6573
func.func @do_not_fold_memcpy_op1(%arg0: i1, %arg1: memref<2xf16>) {
@@ -75,6 +83,8 @@ func.func @do_not_fold_memcpy_op1(%arg0: i1, %arg1: memref<2xf16>) {
7583
}
7684
// CHECK: gpu.memcpy
7785

86+
// -----
87+
7888
// We cannot fold gpu.memcpy as it is used by an op having read effect on dest.
7989
// CHECK-LABEL: func @do_not_fold_memcpy_op2
8090
func.func @do_not_fold_memcpy_op2(%arg0: i1, %arg1: index) -> f16 {
@@ -92,6 +102,8 @@ func.func @do_not_fold_memcpy_op2(%arg0: i1, %arg1: index) -> f16 {
92102
}
93103
// CHECK: gpu.memcpy
94104

105+
// -----
106+
95107
// We cannot fold gpu.memcpy, as the defining op of dest is not an alloc-like op.
96108
// CHECK-LABEL: func @do_not_fold_memcpy_op3
97109
func.func @do_not_fold_memcpy_op3(%arg0: memref<1xi8>, %arg1: memref<i1>) {
@@ -102,6 +114,8 @@ func.func @do_not_fold_memcpy_op3(%arg0: memref<1xi8>, %arg1: memref<i1>) {
102114
}
103115
// CHECK: gpu.memcpy
104116

117+
// -----
118+
105119
// CHECK-LABEL: @memcpy_after_cast
106120
func.func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
107121
// CHECK-NOT: memref.cast
@@ -112,6 +126,8 @@ func.func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
112126
return
113127
}
114128

129+
// -----
130+
115131
// CHECK-LABEL: @memset_after_cast
116132
func.func @memset_after_cast(%arg0: memref<10xf32>, %arg1: f32) {
117133
// CHECK-NOT: memref.cast
@@ -227,3 +243,20 @@ func.func @make_subgroup_reduce_uniform() {
227243
}
228244
return
229245
}
246+
247+
// -----
248+
249+
// The GPU kernel does not have any side-effecting ops, so the entire
250+
// gpu.launch op can fold away.
251+
252+
// CHECK-LABEL: func @gpu_launch_without_side_effects
253+
// CHECK-NOT: gpu.launch
254+
func.func @gpu_launch_without_side_effects() {
255+
%0:6 = "test.test1"() : () -> (index, index, index, index, index, index)
256+
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
257+
threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
258+
%1 = arith.addi %arg0, %arg1 : index
259+
gpu.terminator
260+
}
261+
return
262+
}

0 commit comments

Comments
 (0)