Skip to content

Commit 797dfaf

Browse files
committed
[CIR] [FlattenCFG] hoist allocas to entry block for funcOp in flattenCFG
1 parent db6b7c0 commit 797dfaf

File tree

10 files changed

+98
-33
lines changed

10 files changed

+98
-33
lines changed

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

+45-5
Original file line numberDiff line numberDiff line change
@@ -868,12 +868,51 @@ class CIRTernaryOpFlattening
868868
return mlir::success();
869869
}
870870
};
871+
class CIRFuncOpFlattening : public mlir::OpRewritePattern<mlir::cir::FuncOp> {
872+
public:
873+
using OpRewritePattern<mlir::cir::FuncOp>::OpRewritePattern;
874+
875+
mlir::LogicalResult
876+
matchAndRewrite(mlir::cir::FuncOp op,
877+
mlir::PatternRewriter &rewriter) const override {
878+
if (op.getRegion().empty())
879+
return mlir::failure();
880+
881+
// Hoist all static allocas to the entry block.
882+
mlir::Block &entryBlock = op.getRegion().front();
883+
884+
llvm::SmallVector<mlir::cir::AllocaOp> allocas;
885+
op.getBody().walk([&](mlir::cir::AllocaOp alloca) {
886+
if (alloca->getBlock() == &entryBlock)
887+
return;
888+
889+
// Don't hoist allocas with dynamic alloca size.
890+
if (alloca.getDynAllocSize() != mlir::Value())
891+
return;
892+
893+
allocas.push_back(alloca);
894+
});
895+
896+
if (allocas.empty())
897+
return mlir::failure();
898+
899+
rewriter.setInsertionPointToStart(&entryBlock);
900+
901+
for (auto alloca : allocas) {
902+
mlir::Operation *new_alloca = rewriter.insert(alloca.clone());
903+
alloca.replaceAllUsesWith(new_alloca);
904+
alloca.erase();
905+
}
906+
907+
return mlir::success();
908+
}
909+
};
871910

872911
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
873-
patterns
874-
.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
875-
CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>(
876-
patterns.getContext());
912+
patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
913+
CIRScopeOpFlattening, CIRSwitchOpFlattening,
914+
CIRTernaryOpFlattening, CIRTryOpFlattening, CIRFuncOpFlattening>(
915+
patterns.getContext());
877916
}
878917

879918
void FlattenCFGPass::runOnOperation() {
@@ -883,7 +922,8 @@ void FlattenCFGPass::runOnOperation() {
883922
// Collect operations to apply patterns.
884923
SmallVector<Operation *, 16> ops;
885924
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
886-
if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op))
925+
if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp, FuncOp>(
926+
op))
887927
ops.push_back(op);
888928
});
889929

clang/test/CIR/CodeGen/builtin-bit-cast.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ two_ints test_rvalue_aggregate() {
130130
// CIR: }
131131

132132
// LLVM-LABEL: define dso_local %struct.two_ints @_Z21test_rvalue_aggregatev
133-
// LLVM: %[[#SRC_SLOT:]] = alloca i64, i64 1, align 8
134-
// LLVM-NEXT: store i64 42, ptr %[[#SRC_SLOT]], align 8
133+
// LLVM: %[[#SRC_SLOT:]] = alloca i64, i64 1, align 8
134+
// LLVM: store i64 42, ptr %[[#SRC_SLOT]], align 8
135135
// LLVM-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr %{{.+}}, ptr %[[#SRC_SLOT]], i64 8, i1 false)
136136
// LLVM: }

clang/test/CIR/CodeGen/initlist-ptr-ptr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ void test() {
6363
// LLVM: }
6464

6565
// LLVM: define dso_local void @_ZSt4testv()
66-
// LLVM: br label %[[SCOPE_START:.*]],
67-
// LLVM: [[SCOPE_START]]: ; preds = %0
6866
// LLVM: [[INIT_STRUCT:%.*]] = alloca %"class.std::initializer_list<const char *>", i64 1, align 8,
6967
// LLVM: [[ELEM_ARRAY_PTR:%.*]] = alloca [2 x ptr], i64 1, align 8,
68+
// LLVM: br label %[[SCOPE_START:.*]],
69+
// LLVM: [[SCOPE_START]]: ; preds = %0
7070
// LLVM: [[PTR_FIRST_ELEM:%.*]] = getelementptr ptr, ptr [[ELEM_ARRAY_PTR]], i32 0,
7171
// LLVM: store ptr @.str, ptr [[PTR_FIRST_ELEM]], align 8,
7272
// LLVM: [[PTR_SECOND_ELEM:%.*]] = getelementptr ptr, ptr [[PTR_FIRST_ELEM]], i64 1,

clang/test/CIR/CodeGen/initlist-ptr-unsigned.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ void test() {
4747
// LLVM: store %"class.std::initializer_list<int>" [[ARG]], ptr [[LOCAL]], align 8,
4848

4949
// LLVM: define dso_local void @_ZSt4testv()
50-
// LLVM: br label %[[SCOPE_START:.*]],
51-
// LLVM: [[SCOPE_START]]: ; preds = %0
5250
// LLVM: [[INIT_STRUCT:%.*]] = alloca %"class.std::initializer_list<int>", i64 1, align 8,
5351
// LLVM: [[ELEM_ARRAY:%.*]] = alloca [1 x i32], i64 1, align 4,
52+
// LLVM: br label %[[SCOPE_START:.*]],
53+
// LLVM: [[SCOPE_START]]: ; preds = %0
5454
// LLVM: [[PTR_FIRST_ELEM:%.*]] = getelementptr i32, ptr [[ELEM_ARRAY]], i32 0,
5555
// LLVM: store i32 7, ptr [[PTR_FIRST_ELEM]], align 4,
5656
// LLVM: [[ELEM_ARRAY_PTR:%.*]] = getelementptr %"class.std::initializer_list<int>", ptr [[INIT_STRUCT]], i32 0, i32 0,

clang/test/CIR/CodeGen/try-catch-dtors.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ void yo() {
4040

4141
// LLVM-LABEL: @_Z2yov()
4242

43-
// LLVM: 2:
4443
// LLVM: %[[Vec:.*]] = alloca %struct.Vec
4544
// LLVM: br label %[[INVOKE_BB:.*]],
4645

@@ -101,7 +100,7 @@ void yo2() {
101100
// CIR: }
102101

103102
// CIR_FLAT-LABEL: @_Z3yo2v
104-
// CIR_FLAT: cir.try_call @_ZN3VecC1Ev(%2) ^[[NEXT_CALL_PREP:.*]], ^[[PAD_NODTOR:.*]] : (!cir.ptr<![[VecTy]]>) -> ()
103+
// CIR_FLAT: cir.try_call @_ZN3VecC1Ev(%[[vec:.+]]) ^[[NEXT_CALL_PREP:.*]], ^[[PAD_NODTOR:.*]] : (!cir.ptr<![[VecTy]]>) -> ()
105104
// CIR_FLAT: ^[[NEXT_CALL_PREP]]:
106105
// CIR_FLAT: cir.br ^[[NEXT_CALL:.*]] loc
107106
// CIR_FLAT: ^[[NEXT_CALL]]:
@@ -117,7 +116,7 @@ void yo2() {
117116
// CIR_FLAT: cir.br ^[[CATCH_BEGIN:.*]](%exception_ptr : !cir.ptr<!void>)
118117
// CIR_FLAT: ^[[PAD_DTOR]]:
119118
// CIR_FLAT: %exception_ptr_0, %type_id_1 = cir.eh.inflight_exception
120-
// CIR_FLAT: cir.call @_ZN3VecD1Ev(%2) : (!cir.ptr<![[VecTy]]>) -> ()
119+
// CIR_FLAT: cir.call @_ZN3VecD1Ev(%[[vec]]) : (!cir.ptr<![[VecTy]]>) -> ()
121120
// CIR_FLAT: cir.br ^[[CATCH_BEGIN]](%exception_ptr_0 : !cir.ptr<!void>)
122121
// CIR_FLAT: ^[[CATCH_BEGIN]](
123122
// CIR_FLAT: cir.catch_param begin
@@ -169,7 +168,6 @@ void yo3(bool x) {
169168
// CIR: cir.return
170169

171170
// CIR_FLAT-LABEL: @_Z3yo3b
172-
// CIR_FLAT: ^bb1:
173171
// CIR_FLAT: %[[V1:.*]] = cir.alloca ![[VecTy]], !cir.ptr<![[VecTy]]>, ["v1"
174172
// CIR_FLAT: %[[V2:.*]] = cir.alloca ![[VecTy]], !cir.ptr<![[VecTy]]>, ["v2"
175173
// CIR_FLAT: %[[V3:.*]] = cir.alloca ![[VecTy]], !cir.ptr<![[VecTy]]>, ["v3"

clang/test/CIR/Lowering/OpenMP/parallel.cir

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ module {
2626
// CHECK: ret void
2727
// CHECK-NEXT: }
2828
// CHECK: define{{.*}} void @omp_parallel..omp_par(ptr
29+
// CHECK: %[[XVar:.*]] = load ptr, ptr %{{.*}}, align 8
2930
// CHECK: %[[YVar:.*]] = load ptr, ptr %{{.*}}, align 8
30-
// CHECK: %[[XVar:.*]] = alloca i32, i64 1, align 4
3131
// CHECK: store i32 1, ptr %[[XVar]], align 4
3232
// CHECK: %[[XVal:.*]] = load i32, ptr %[[XVar]], align 4
3333
// CHECK: %[[BinOp:.*]] = add i32 %[[XVal]], 1

clang/test/CIR/Lowering/dot.cir

+8-8
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ module {
5353
}
5454

5555
// MLIR-LABEL: llvm.func @dot(
56+
// MLIR: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64
57+
// MLIR: %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr
5658
// MLIR: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64
5759
// MLIR: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x !llvm.ptr {alignment = 8 : i64} : (i64) -> !llvm.ptr
5860
// MLIR: %[[VAL_5:.*]] = llvm.mlir.constant(1 : index) : i64
@@ -70,13 +72,11 @@ module {
7072
// MLIR: llvm.store %[[VAL_13]], %[[VAL_12]] {{.*}}: f64, !llvm.ptr
7173
// MLIR: llvm.br ^bb1
7274
// MLIR: ^bb1:
73-
// MLIR: %[[VAL_14:.*]] = llvm.mlir.constant(1 : index) : i64
74-
// MLIR: %[[VAL_15:.*]] = llvm.alloca %[[VAL_14]] x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr
7575
// MLIR: %[[VAL_16:.*]] = llvm.mlir.constant(0 : i32) : i32
76-
// MLIR: llvm.store %[[VAL_16]], %[[VAL_15]] {{.*}}: i32, !llvm.ptr
76+
// MLIR: llvm.store %[[VAL_16]], %[[VAL_2]] {{.*}}: i32, !llvm.ptr
7777
// MLIR: llvm.br ^bb2
7878
// MLIR: ^bb2:
79-
// MLIR: %[[VAL_17:.*]] = llvm.load %[[VAL_15]] {alignment = 4 : i64} : !llvm.ptr -> i32
79+
// MLIR: %[[VAL_17:.*]] = llvm.load %[[VAL_2]] {alignment = 4 : i64} : !llvm.ptr -> i32
8080
// MLIR: %[[VAL_18:.*]] = llvm.load %[[VAL_8]] {alignment = 4 : i64} : !llvm.ptr -> i32
8181
// MLIR: %[[VAL_19:.*]] = llvm.icmp "slt" %[[VAL_17]], %[[VAL_18]] : i32
8282
// MLIR: %[[VAL_20:.*]] = llvm.zext %[[VAL_19]] : i1 to i32
@@ -85,12 +85,12 @@ module {
8585
// MLIR: llvm.cond_br %[[VAL_22]], ^bb3, ^bb5
8686
// MLIR: ^bb3:
8787
// MLIR: %[[VAL_23:.*]] = llvm.load %[[VAL_4]] {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
88-
// MLIR: %[[VAL_24:.*]] = llvm.load %[[VAL_15]] {alignment = 4 : i64} : !llvm.ptr -> i32
88+
// MLIR: %[[VAL_24:.*]] = llvm.load %[[VAL_2]] {alignment = 4 : i64} : !llvm.ptr -> i32
8989
// MLIR: %[[VAL_25:.*]] = llvm.sext %[[VAL_24]] : i32 to i64
9090
// MLIR: %[[VAL_26:.*]] = llvm.getelementptr %[[VAL_23]]{{\[}}%[[VAL_25]]] : (!llvm.ptr, i64) -> !llvm.ptr, f64
9191
// MLIR: %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {alignment = 8 : i64} : !llvm.ptr -> f64
9292
// MLIR: %[[VAL_28:.*]] = llvm.load %[[VAL_6]] {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
93-
// MLIR: %[[VAL_29:.*]] = llvm.load %[[VAL_15]] {alignment = 4 : i64} : !llvm.ptr -> i32
93+
// MLIR: %[[VAL_29:.*]] = llvm.load %[[VAL_2]] {alignment = 4 : i64} : !llvm.ptr -> i32
9494
// MLIR: %[[VAL_30:.*]] = llvm.sext %[[VAL_29]] : i32 to i64
9595
// MLIR: %[[VAL_31:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f64
9696
// MLIR: %[[VAL_32:.*]] = llvm.load %[[VAL_31]] {alignment = 8 : i64} : !llvm.ptr -> f64
@@ -100,10 +100,10 @@ module {
100100
// MLIR: llvm.store %[[VAL_35]], %[[VAL_12]] {{.*}}: f64, !llvm.ptr
101101
// MLIR: llvm.br ^bb4
102102
// MLIR: ^bb4:
103-
// MLIR: %[[VAL_36:.*]] = llvm.load %[[VAL_15]] {alignment = 4 : i64} : !llvm.ptr -> i32
103+
// MLIR: %[[VAL_36:.*]] = llvm.load %[[VAL_2]] {alignment = 4 : i64} : !llvm.ptr -> i32
104104
// MLIR: %[[VAL_37:.*]] = llvm.mlir.constant(1 : i32) : i32
105105
// MLIR: %[[VAL_38:.*]] = llvm.add %[[VAL_36]], %[[VAL_37]] : i32
106-
// MLIR: llvm.store %[[VAL_38]], %[[VAL_15]] {{.*}}: i32, !llvm.ptr
106+
// MLIR: llvm.store %[[VAL_38]], %[[VAL_2]] {{.*}}: i32, !llvm.ptr
107107
// MLIR: llvm.br ^bb2
108108
// MLIR: ^bb5:
109109
// MLIR: llvm.br ^bb6
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o - | FileCheck %s
2+
3+
struct def;
4+
typedef struct def *decl;
5+
struct def {
6+
int index;
7+
};
8+
struct def d;
9+
int foo(unsigned char cond)
10+
{
11+
if (cond)
12+
goto label;
13+
14+
{
15+
decl b = &d;
16+
17+
label:
18+
return b->index;
19+
}
20+
21+
return 0;
22+
}
23+
24+
// It is fine enough to check the LLVM IR are generated succesfully.
25+
// CHECK: define {{.*}}i32 @foo
26+
// CHECK: alloca ptr
27+
// CHECK: alloca i8

clang/test/CIR/Lowering/scope.cir

+8-8
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,23 @@ module {
1414
}
1515

1616
// MLIR: llvm.func @foo()
17-
// MLIR-NEXT: llvm.br ^bb1
18-
// MLIR-NEXT: ^bb1:
17+
// MLIR: [[v2:%[0-9]]] = llvm.mlir.constant(1 : index) : i64
18+
// MLIR: [[v3:%[0-9]]] = llvm.alloca [[v2]] x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr
19+
// MLIR: llvm.br ^bb1
20+
// MLIR: ^bb1:
1921
// MLIR-DAG: [[v1:%[0-9]]] = llvm.mlir.constant(4 : i32) : i32
20-
// MLIR-DAG: [[v2:%[0-9]]] = llvm.mlir.constant(1 : index) : i64
21-
// MLIR-DAG: [[v3:%[0-9]]] = llvm.alloca [[v2]] x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr
2222
// MLIR-NEXT: llvm.store [[v1]], [[v3]] {{.*}}: i32, !llvm.ptr
2323
// MLIR-NEXT: llvm.br ^bb2
2424
// MLIR-NEXT: ^bb2:
2525
// MLIR-NEXT: llvm.return
2626

2727

2828
// LLVM: define void @foo()
29-
// LLVM-NEXT: br label %1
29+
// LLVM-NEXT: %1 = alloca i32, i64 1, align 4
30+
// LLVM-NEXT: br label %2
3031
// LLVM-EMPTY:
31-
// LLVM-NEXT: 1:
32-
// LLVM-NEXT: %2 = alloca i32, i64 1, align 4
33-
// LLVM-NEXT: store i32 4, ptr %2, align 4
32+
// LLVM-NEXT: 2:
33+
// LLVM-NEXT: store i32 4, ptr %1, align 4
3434
// LLVM-NEXT: br label %3
3535
// LLVM-EMPTY:
3636
// LLVM-NEXT: 3:

clang/test/CIR/Transforms/scope.cir

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ module {
1212
cir.return
1313
}
1414
// CHECK: cir.func @foo() {
15+
// CHECK: %0 = cir.alloca !u32i, !cir.ptr<!u32i>, ["a", init] {alignment = 4 : i64}
1516
// CHECK: cir.br ^bb1
1617
// CHECK: ^bb1: // pred: ^bb0
17-
// CHECK: %0 = cir.alloca !u32i, !cir.ptr<!u32i>, ["a", init] {alignment = 4 : i64}
1818
// CHECK: %1 = cir.const #cir.int<4> : !u32i
1919
// CHECK: cir.store %1, %0 : !u32i, !cir.ptr<!u32i>
2020
// CHECK: cir.br ^bb2

0 commit comments

Comments
 (0)