Skip to content

Commit 8ec33f0

Browse files
committed
[CIR] Add 'core-flat' as an option value for '-emit-mlir='
- Both 'func' and 'scf' have structure control flow constraints, e.g., 'func.return' needs to have 'func.func' parent op. The alternative approach to lower CIR into MLIR standard dialects is to lower CIR into 'cf' dialect and then perform structurization on all or selected regions. - Add 'core-flat' option to enable CFG flattening when lowering CIR into MLIR standard dialects. - Enhance 'cir-flatten-cfg' pass to unify returns into branches to a dedicated return block. - Fix 'cir.br' and 'cir.return' lowering to MLIR and allow function declarations.
1 parent 79d0d74 commit 8ec33f0

File tree

11 files changed

+258
-40
lines changed

11 files changed

+258
-40
lines changed

clang/include/clang/CIR/Dialect/Passes.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ std::unique_ptr<Pass> createIdiomRecognizerPass(clang::ASTContext *astCtx);
3737
std::unique_ptr<Pass> createLibOptPass();
3838
std::unique_ptr<Pass> createLibOptPass(clang::ASTContext *astCtx);
3939
std::unique_ptr<Pass> createFlattenCFGPass();
40+
std::unique_ptr<Pass> createFlattenCFGPass(bool throughMLIR);
4041
std::unique_ptr<Pass> createHoistAllocasPass();
4142
std::unique_ptr<Pass> createGotoSolverPass();
4243

4344
/// Create a pass to lower ABI-independent function definitions/calls.
4445
std::unique_ptr<Pass> createCallConvLoweringPass();
4546

46-
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm, bool useCCLowering);
47+
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm, bool useCCLowering,
48+
bool emitCore);
4749

4850
//===----------------------------------------------------------------------===//
4951
// Registration

clang/include/clang/CIR/Dialect/Passes.td

+4
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def FlattenCFG : Pass<"cir-flatten-cfg"> {
127127
In other words, this pass removes such CIR operations like IfOp, LoopOp,
128128
ScopeOp and etc. and produces a flat CIR.
129129
}];
130+
let options = [
131+
Option<"throughMLIR", "through-mlir", "bool", /*default=*/"false",
132+
"Prepare the flat CIR for lowering through MLIR">,
133+
];
130134
let constructor = "mlir::createFlattenCFGPass()";
131135
let dependentDialects = ["cir::CIRDialect"];
132136
}

clang/include/clang/Driver/Options.td

+2-2
Original file line numberDiff line numberDiff line change
@@ -3067,9 +3067,9 @@ def emit_mlir : Flag<["-"], "emit-mlir">, Visibility<[ClangOption]>, Group<Actio
30673067
def emit_mlir_EQ : Joined<["-"], "emit-mlir=">, Visibility<[ClangOption, CC1Option]>, Group<Action_Group>,
30683068
HelpText<"Build ASTs and then lower through ClangIR to the selected MLIR dialect, emit the .mlir file. "
30693069
"Allowed values are `core` for MLIR standard dialects and `llvm` for the LLVM dialect.">,
3070-
Values<"core,llvm,cir,cir-flat">,
3070+
Values<"core,llvm,cir,cir-flat,core-flat">,
30713071
NormalizedValuesScope<"frontend">,
3072-
NormalizedValues<["MLIR_CORE", "MLIR_LLVM", "MLIR_CIR", "MLIR_CIR_FLAT"]>,
3072+
NormalizedValues<["MLIR_CORE", "MLIR_LLVM", "MLIR_CIR", "MLIR_CIR_FLAT", "MLIR_CORE_FLAT"]>,
30733073
MarshallingInfoEnum<FrontendOpts<"MLIRTargetDialect">, "MLIR_CORE">;
30743074
def emit_cir : Flag<["-"], "emit-cir">, Visibility<[ClangOption, CC1Option]>,
30753075
Group<Action_Group>, Alias<emit_mlir_EQ>, AliasArgs<["cir"]>,

clang/include/clang/Frontend/FrontendOptions.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,13 @@ enum ActionKind {
151151
PrintDependencyDirectivesSourceMinimizerOutput
152152
};
153153

154-
enum MLIRDialectKind { MLIR_CORE, MLIR_LLVM, MLIR_CIR, MLIR_CIR_FLAT };
154+
enum MLIRDialectKind {
155+
MLIR_CORE,
156+
MLIR_LLVM,
157+
MLIR_CIR,
158+
MLIR_CIR_FLAT,
159+
MLIR_CORE_FLAT
160+
};
155161

156162
} // namespace frontend
157163

clang/lib/CIR/CodeGen/CIRPasses.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ mlir::LogicalResult runCIRToCIRPasses(
7676
pm.addPass(mlir::createLoweringPreparePass(&astContext));
7777

7878
if (flattenCIR || enableMem2Reg)
79-
mlir::populateCIRPreLoweringPasses(pm, enableCallConvLowering);
79+
mlir::populateCIRPreLoweringPasses(pm, enableCallConvLowering, emitCore);
8080

8181
if (enableMem2Reg)
8282
pm.addPass(mlir::createMem2Reg());
@@ -96,11 +96,12 @@ mlir::LogicalResult runCIRToCIRPasses(
9696

9797
namespace mlir {
9898

99-
void populateCIRPreLoweringPasses(OpPassManager &pm, bool useCCLowering) {
99+
void populateCIRPreLoweringPasses(OpPassManager &pm, bool useCCLowering,
100+
bool emitCore) {
100101
if (useCCLowering)
101102
pm.addPass(createCallConvLoweringPass());
102103
pm.addPass(createHoistAllocasPass());
103-
pm.addPass(createFlattenCFGPass());
104+
pm.addPass(createFlattenCFGPass(/*throughMLIR=*/emitCore));
104105
pm.addPass(createGotoSolverPass());
105106
}
106107

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

+76-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// function region.
1111
//
1212
//===----------------------------------------------------------------------===//
13+
1314
#include "PassDetail.h"
1415
#include "mlir/Dialect/Func/IR/FuncOps.h"
1516
#include "mlir/IR/PatternMatch.h"
@@ -50,6 +51,8 @@ struct FlattenCFGPass : public FlattenCFGBase<FlattenCFGPass> {
5051

5152
FlattenCFGPass() = default;
5253
void runOnOperation() override;
54+
55+
void setThroughMLIR(bool v) { throughMLIR = v; }
5356
};
5457

5558
struct CIRIfFlattening : public OpRewritePattern<IfOp> {
@@ -855,6 +858,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
855858
return mlir::success();
856859
}
857860
};
861+
858862
class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
859863
public:
860864
using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
@@ -906,23 +910,87 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
906910
}
907911
};
908912

909-
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
913+
/// Rewrite 'cir.return' within the specified 'cir.func' as a branch to a
914+
/// dedicated return block when the CIR needs lowering through MLIR, where
915+
/// 'func.return' must have 'func.func' as its parent. If the structured
916+
/// control flow is preferred, that implies a function could only have a single
917+
/// (unified) 'func.return' in that MLIR.
918+
class CIRReturnUnifying : public mlir::OpRewritePattern<cir::FuncOp> {
919+
using OpRewritePattern<cir::FuncOp>::OpRewritePattern;
920+
921+
mlir::LogicalResult
922+
matchAndRewrite(cir::FuncOp func,
923+
mlir::PatternRewriter &rewriter) const override {
924+
mlir::OpBuilder::InsertionGuard guard(rewriter);
925+
926+
if (func.getRegion().empty())
927+
return mlir::success();
928+
929+
// Collect operations to apply patterns.
930+
llvm::SmallVector<cir::ReturnOp, 4> returnOps;
931+
bool hasReturnOpWithFuncOpParentOnly = true;
932+
func->walk([&](cir::ReturnOp ret) {
933+
returnOps.push_back(ret);
934+
// Check any 'cir.return' without 'cir.func' as the parent. Such
935+
// 'cir.return' needs unifying even when there is just one.
936+
if (!isa<cir::FuncOp>(ret->getParentOp()))
937+
hasReturnOpWithFuncOpParentOnly = false;
938+
});
939+
940+
// Skip unifying if there is only one 'cir.return' and it has 'cir.func' as
941+
// the parent.
942+
if (hasReturnOpWithFuncOpParentOnly && returnOps.size() < 2)
943+
return mlir::success();
944+
945+
bool hasRetVals = (func.getNumResults() > 0);
946+
auto *endBody = &func.getBody().back();
947+
// Create a dedicated return block at the end of the function.
948+
auto *retBlock = endBody->splitBlock(endBody->end());
949+
if (hasRetVals)
950+
retBlock->addArguments(func.getResultTypes(), func.getLoc());
951+
952+
// Rewrite all 'cir.return's as branches to that return block.
953+
for (ReturnOp ret : returnOps) {
954+
rewriter.setInsertionPoint(ret);
955+
rewriter.replaceOpWithNewOp<cir::BrOp>(ret, retBlock, ret.getInput());
956+
}
957+
958+
// Finally, add returnOp in that dedicated return block.
959+
auto builder = OpBuilder::atBlockBegin(retBlock);
960+
if (hasRetVals)
961+
builder.create<cir::ReturnOp>(func.getLoc(), retBlock->getArguments());
962+
else
963+
builder.create<cir::ReturnOp>(func.getLoc());
964+
return mlir::success();
965+
}
966+
};
967+
968+
void populateFlattenCFGPatterns(RewritePatternSet &patterns, bool throughMLIR) {
910969
patterns
911970
.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
912971
CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>(
913972
patterns.getContext());
973+
if (throughMLIR)
974+
patterns.add<CIRReturnUnifying>(patterns.getContext());
914975
}
915976

916977
void FlattenCFGPass::runOnOperation() {
917978
RewritePatternSet patterns(&getContext());
918-
populateFlattenCFGPatterns(patterns);
979+
populateFlattenCFGPatterns(patterns, throughMLIR);
919980

920981
// Collect operations to apply patterns.
921982
llvm::SmallVector<Operation *, 16> ops;
922983
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
923984
if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op))
924985
ops.push_back(op);
925986
});
987+
// If CIR is to be lowered into MLIR, unify all 'cir.return's as
988+
// 'func.return' needs to has 'func.func' as the parent.
989+
if (throughMLIR)
990+
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
991+
if (isa<FuncOp>(op))
992+
ops.push_back(op);
993+
});
926994

927995
// Apply patterns.
928996
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
@@ -937,4 +1005,10 @@ std::unique_ptr<Pass> createFlattenCFGPass() {
9371005
return std::make_unique<FlattenCFGPass>();
9381006
}
9391007

1008+
std::unique_ptr<Pass> createFlattenCFGPass(bool throughMLIR) {
1009+
auto flatten = std::make_unique<FlattenCFGPass>();
1010+
flatten->setThroughMLIR(throughMLIR);
1011+
return std::move(flatten);
1012+
}
1013+
9401014
} // namespace mlir

clang/lib/CIR/FrontendAction/CIRGenAction.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,16 @@ class CIRGenConsumer : public clang::ASTConsumer {
206206
feOptions.ClangIRCallConvLowering &&
207207
!(action == CIRGenAction::OutputType::EmitMLIR &&
208208
feOptions.MLIRTargetDialect == frontend::MLIR_CIR);
209+
209210
bool flattenCIR =
210211
action == CIRGenAction::OutputType::EmitMLIR &&
211-
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CIR_FLAT;
212+
(feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE_FLAT ||
213+
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CIR_FLAT);
212214

213-
bool emitCore = action == CIRGenAction::OutputType::EmitMLIR &&
214-
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE;
215+
bool emitCore =
216+
action == CIRGenAction::OutputType::EmitMLIR &&
217+
(feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE ||
218+
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE_FLAT);
215219

216220
// Setup and run CIR pipeline.
217221
std::string passOptParsingFailure;
@@ -283,6 +287,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
283287
case CIRGenAction::OutputType::EmitMLIR: {
284288
switch (feOptions.MLIRTargetDialect) {
285289
case clang::frontend::MLIR_CORE:
290+
case clang::frontend::MLIR_CORE_FLAT:
286291
// case for direct lowering is already checked in compiler invocation
287292
// no need to check here
288293
emitMLIR(lowerFromCIRToMLIR(mlirMod, mlirCtx.get()), false);

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4780,7 +4780,7 @@ std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
47804780
}
47814781

47824782
void populateCIRToLLVMPasses(mlir::OpPassManager &pm, bool useCCLowering) {
4783-
populateCIRPreLoweringPasses(pm, useCCLowering);
4783+
populateCIRPreLoweringPasses(pm, useCCLowering, /*emitCore=*/false);
47844784
pm.addPass(createConvertCIRToLLVMPass());
47854785
}
47864786

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+32-27
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class CIRReturnLowering : public mlir::OpConversionPattern<cir::ReturnOp> {
6969
mlir::LogicalResult
7070
matchAndRewrite(cir::ReturnOp op, OpAdaptor adaptor,
7171
mlir::ConversionPatternRewriter &rewriter) const override {
72+
assert(isa<mlir::FunctionOpInterface>(op->getParentOp()) &&
73+
"'func.return' op expects parent op 'func.func'");
7274
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op,
7375
adaptor.getOperands());
7476
return mlir::LogicalResult::success();
@@ -660,10 +662,13 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> {
660662
resultType ? mlir::TypeRange(resultType)
661663
: mlir::TypeRange()));
662664

663-
if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter,
664-
&signatureConversion)))
665-
return mlir::failure();
666-
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
665+
if (!op.getBody().empty()) {
666+
if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter,
667+
&signatureConversion)))
668+
return mlir::failure();
669+
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
670+
} else
671+
fn.setPrivate();
667672

668673
rewriter.eraseOp(op);
669674
return mlir::LogicalResult::success();
@@ -833,14 +838,15 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
833838
}
834839
};
835840

836-
class CIRBrOpLowering : public mlir::OpRewritePattern<cir::BrOp> {
841+
class CIRBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
837842
public:
838-
using OpRewritePattern<cir::BrOp>::OpRewritePattern;
843+
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;
839844

840845
mlir::LogicalResult
841-
matchAndRewrite(cir::BrOp op,
842-
mlir::PatternRewriter &rewriter) const override {
843-
rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest());
846+
matchAndRewrite(cir::BrOp op, OpAdaptor adaptor,
847+
mlir::ConversionPatternRewriter &rewriter) const override {
848+
rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
849+
adaptor.getDestOperands());
844850
return mlir::LogicalResult::success();
845851
}
846852
};
@@ -1354,24 +1360,23 @@ class CIRPtrStrideOpLowering
13541360

13551361
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13561362
mlir::TypeConverter &converter) {
1357-
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
1358-
1359-
patterns.add<
1360-
CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1361-
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1362-
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1363-
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1364-
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1365-
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1366-
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1367-
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1368-
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1369-
CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1370-
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1371-
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1372-
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1373-
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1374-
patterns.getContext());
1363+
patterns.add<CIRBrOpLowering, CIRReturnLowering, CIRCmpOpLowering,
1364+
CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1365+
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1366+
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1367+
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1368+
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1369+
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1370+
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1371+
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1372+
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1373+
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1374+
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1375+
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
1376+
CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering,
1377+
CIRVectorCreateLowering, CIRVectorInsertLowering,
1378+
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(
1379+
converter, patterns.getContext());
13751380
}
13761381

13771382
static mlir::TypeConverter prepareTypeConverter() {

0 commit comments

Comments
 (0)