Skip to content

Commit 49a2b45

Browse files
committed
[CIR] Add 'core-flat' as an option value for '-emit-mlir='
- Both 'func' and 'scf' have structure control flow constraints, e.g., 'func.return' needs to have 'func.func' parent op. The alternative approach to lower CIR into MLIR standard dialects is to lower CIR into 'cf' dialect and perform structurization on all or selected regions. - Add 'core-flat' option to enable CFG flattening when lowering CIR into MLIR standard dialects. - Add 'cir-unify-func-return' pass to unify returns into branches to a trialing block dedicated for function return. - Fix 'cir.br' and 'cir.return' lowering to MLIR and allow function declarations.
1 parent fa5b07c commit 49a2b45

File tree

14 files changed

+265
-40
lines changed

14 files changed

+265
-40
lines changed

Diff for: clang/include/clang/CIR/CIRToCIRPasses.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ mlir::LogicalResult runCIRToCIRPasses(
3535
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
3636
llvm::StringRef libOptOpts, std::string &passOptParsingFailure,
3737
bool enableCIRSimplify, bool flattenCIR, bool emitMLIR,
38-
bool enableCallConvLowering, bool enableMem2reg);
38+
bool enableCallConvLowering, bool enableMem2reg, bool flattenCore);
3939

4040
} // namespace cir
4141

Diff for: clang/include/clang/CIR/Dialect/Passes.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ std::unique_ptr<Pass> createLibOptPass(clang::ASTContext *astCtx);
3939
std::unique_ptr<Pass> createFlattenCFGPass();
4040
std::unique_ptr<Pass> createHoistAllocasPass();
4141
std::unique_ptr<Pass> createGotoSolverPass();
42+
std::unique_ptr<Pass> createUnifyFuncReturnPass();
4243

4344
/// Create a pass to lower ABI-independent function definitions/calls.
4445
std::unique_ptr<Pass> createCallConvLoweringPass();
4546

46-
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm, bool useCCLowering);
47+
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm, bool useCCLowering,
48+
bool flattenCore);
4749

4850
//===----------------------------------------------------------------------===//
4951
// Registration

Diff for: clang/include/clang/CIR/Dialect/Passes.td

+10
Original file line numberDiff line numberDiff line change
@@ -191,4 +191,14 @@ def CallConvLowering : Pass<"cir-call-conv-lowering"> {
191191
let dependentDialects = ["cir::CIRDialect"];
192192
}
193193

194+
def UnifyFuncReturn : Pass<"cir-unify-func-return"> {
195+
let summary = "Unify function return";
196+
let description = [{
197+
This pass creates a dedicated block as the only function return point and
198+
unifies all returns as unconditional branches to that return block.
199+
}];
200+
let constructor = "mlir::createUnifyFuncReturnPass()";
201+
let dependentDialects = ["cir::CIRDialect"];
202+
}
203+
194204
#endif // MLIR_DIALECT_CIR_PASSES

Diff for: clang/include/clang/Driver/Options.td

+2-2
Original file line numberDiff line numberDiff line change
@@ -3120,9 +3120,9 @@ def emit_mlir : Flag<["-"], "emit-mlir">, Visibility<[ClangOption]>, Group<Actio
31203120
def emit_mlir_EQ : Joined<["-"], "emit-mlir=">, Visibility<[ClangOption, CC1Option]>, Group<Action_Group>,
31213121
HelpText<"Build ASTs and then lower through ClangIR to the selected MLIR dialect, emit the .mlir file. "
31223122
"Allowed values are `core` for MLIR standard dialects and `llvm` for the LLVM dialect.">,
3123-
Values<"core,llvm,cir,cir-flat">,
3123+
Values<"core,llvm,cir,cir-flat,core-flat">,
31243124
NormalizedValuesScope<"frontend">,
3125-
NormalizedValues<["MLIR_CORE", "MLIR_LLVM", "MLIR_CIR", "MLIR_CIR_FLAT"]>,
3125+
NormalizedValues<["MLIR_CORE", "MLIR_LLVM", "MLIR_CIR", "MLIR_CIR_FLAT", "MLIR_CORE_FLAT"]>,
31263126
MarshallingInfoEnum<FrontendOpts<"MLIRTargetDialect">, "MLIR_CORE">;
31273127
def emit_cir : Flag<["-"], "emit-cir">, Visibility<[ClangOption, CC1Option]>,
31283128
Group<Action_Group>, Alias<emit_mlir_EQ>, AliasArgs<["cir"]>,

Diff for: clang/include/clang/Frontend/FrontendOptions.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,13 @@ enum ActionKind {
154154
PrintDependencyDirectivesSourceMinimizerOutput
155155
};
156156

157-
enum MLIRDialectKind { MLIR_CORE, MLIR_LLVM, MLIR_CIR, MLIR_CIR_FLAT };
157+
enum MLIRDialectKind {
158+
MLIR_CORE,
159+
MLIR_LLVM,
160+
MLIR_CIR,
161+
MLIR_CIR_FLAT,
162+
MLIR_CORE_FLAT
163+
};
158164

159165
} // namespace frontend
160166

Diff for: clang/lib/CIR/CodeGen/CIRPasses.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ mlir::LogicalResult runCIRToCIRPasses(
2929
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
3030
llvm::StringRef libOptOpts, std::string &passOptParsingFailure,
3131
bool enableCIRSimplify, bool flattenCIR, bool emitCore,
32-
bool enableCallConvLowering, bool enableMem2Reg) {
32+
bool enableCallConvLowering, bool enableMem2Reg, bool flattenCore) {
3333

3434
llvm::TimeTraceScope scope("CIR To CIR Passes");
3535

@@ -75,8 +75,8 @@ mlir::LogicalResult runCIRToCIRPasses(
7575

7676
pm.addPass(mlir::createLoweringPreparePass(&astContext));
7777

78-
if (flattenCIR || enableMem2Reg)
79-
mlir::populateCIRPreLoweringPasses(pm, enableCallConvLowering);
78+
if (flattenCIR || flattenCore || enableMem2Reg)
79+
mlir::populateCIRPreLoweringPasses(pm, enableCallConvLowering, flattenCore);
8080

8181
if (enableMem2Reg)
8282
pm.addPass(mlir::createMem2Reg());
@@ -96,11 +96,14 @@ mlir::LogicalResult runCIRToCIRPasses(
9696

9797
namespace mlir {
9898

99-
void populateCIRPreLoweringPasses(OpPassManager &pm, bool useCCLowering) {
99+
void populateCIRPreLoweringPasses(OpPassManager &pm, bool useCCLowering,
100+
bool flattenCore) {
100101
if (useCCLowering)
101102
pm.addPass(createCallConvLoweringPass());
102103
pm.addPass(createHoistAllocasPass());
103104
pm.addPass(createFlattenCFGPass());
105+
if (flattenCore)
106+
pm.addPass(createUnifyFuncReturnPass());
104107
pm.addPass(createGotoSolverPass());
105108
}
106109

Diff for: clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_clang_library(MLIRCIRTransforms
1414
SCFPrepare.cpp
1515
CallConvLowering.cpp
1616
HoistAllocas.cpp
17+
UnifyFuncReturn.cpp
1718

1819
DEPENDS
1920
MLIRCIRPassIncGen

Diff for: clang/lib/CIR/Dialect/Transforms/UnifyFuncReturn.cpp

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//====- UnifyFuncReturn.cpp -------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Support/LogicalResult.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
14+
#include "clang/CIR/Dialect/Passes.h"
15+
#include "llvm/Support/TimeProfiler.h"
16+
17+
using namespace mlir;
18+
using namespace cir;
19+
20+
namespace {
21+
22+
struct UnifyFuncReturnPass : public UnifyFuncReturnBase<UnifyFuncReturnPass> {
23+
UnifyFuncReturnPass() = default;
24+
void runOnOperation() override;
25+
26+
private:
27+
void unifyReturn(FuncOp func);
28+
};
29+
30+
struct UnifyReturn : public OpRewritePattern<ReturnOp> {
31+
using OpRewritePattern<ReturnOp>::OpRewritePattern;
32+
33+
UnifyReturn(MLIRContext *context, cir::FuncOp func, Block *retBlock)
34+
: OpRewritePattern<ReturnOp>(context), func(func), retBlock(retBlock) {}
35+
36+
mlir::LogicalResult
37+
matchAndRewrite(cir::ReturnOp ret,
38+
mlir::PatternRewriter &rewriter) const override {
39+
mlir::OpBuilder::InsertionGuard guard(rewriter);
40+
auto fn = ret->getParentOfType<cir::FuncOp>();
41+
if (!fn || fn != func)
42+
return mlir::failure();
43+
// Replace 'return' with 'br <retBlock>'
44+
rewriter.replaceOpWithNewOp<cir::BrOp>(ret, ret.getInput(), retBlock);
45+
return mlir::success();
46+
}
47+
48+
private:
49+
cir::FuncOp func;
50+
Block *retBlock;
51+
};
52+
53+
} // namespace
54+
55+
void UnifyFuncReturnPass::unifyReturn(cir::FuncOp func) {
56+
if (func.getRegion().empty())
57+
return;
58+
59+
bool hasRetVals = func.getNumResults() > 0;
60+
auto *endBody = &func.getBody().back();
61+
auto *retBlock = endBody->splitBlock(endBody->end());
62+
if (hasRetVals)
63+
retBlock->addArguments(func.getResultTypes(), func.getLoc());
64+
65+
RewritePatternSet patterns(&getContext());
66+
patterns.add<UnifyReturn>(patterns.getContext(), func, retBlock);
67+
68+
// Collect operations to apply patterns.
69+
llvm::SmallVector<Operation *, 16> ops;
70+
func->walk([&](cir::ReturnOp op) { ops.push_back(op.getOperation()); });
71+
72+
// Apply patterns.
73+
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
74+
signalPassFailure();
75+
76+
auto builder = OpBuilder::atBlockBegin(retBlock);
77+
if (hasRetVals)
78+
builder.create<cir::ReturnOp>(func.getLoc(), retBlock->getArguments());
79+
else
80+
builder.create<cir::ReturnOp>(func.getLoc());
81+
}
82+
83+
void UnifyFuncReturnPass::runOnOperation() {
84+
llvm::TimeTraceScope scope("Unify function returns");
85+
86+
// Collect operations to apply patterns.
87+
llvm::SmallVector<Operation *, 16> ops;
88+
getOperation()->walk([&](cir::FuncOp op) { unifyReturn(op); });
89+
}
90+
91+
namespace mlir {
92+
std::unique_ptr<Pass> createUnifyFuncReturnPass() {
93+
return std::make_unique<UnifyFuncReturnPass>();
94+
}
95+
} // namespace mlir

Diff for: clang/lib/CIR/FrontendAction/CIRGenAction.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,13 @@ class CIRGenConsumer : public clang::ASTConsumer {
210210
action == CIRGenAction::OutputType::EmitMLIR &&
211211
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CIR_FLAT;
212212

213-
bool emitCore = action == CIRGenAction::OutputType::EmitMLIR &&
214-
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE;
213+
bool emitCore =
214+
action == CIRGenAction::OutputType::EmitMLIR &&
215+
(feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE ||
216+
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE_FLAT);
217+
bool flattenCore =
218+
action == CIRGenAction::OutputType::EmitMLIR &&
219+
feOptions.MLIRTargetDialect == clang::frontend::MLIR_CORE_FLAT;
215220

216221
// Setup and run CIR pipeline.
217222
std::string passOptParsingFailure;
@@ -221,7 +226,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
221226
feOptions.ClangIRIdiomRecognizer, idiomRecognizerOpts,
222227
feOptions.ClangIRLibOpt, libOptOpts, passOptParsingFailure,
223228
codeGenOptions.OptimizationLevel > 0, flattenCIR, emitCore,
224-
enableCCLowering, feOptions.ClangIREnableMem2Reg)
229+
enableCCLowering, feOptions.ClangIREnableMem2Reg, flattenCore)
225230
.failed()) {
226231
if (!passOptParsingFailure.empty())
227232
diagnosticsEngine.Report(diag::err_drv_cir_pass_opt_parsing)
@@ -283,6 +288,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
283288
case CIRGenAction::OutputType::EmitMLIR: {
284289
switch (feOptions.MLIRTargetDialect) {
285290
case clang::frontend::MLIR_CORE:
291+
case clang::frontend::MLIR_CORE_FLAT:
286292
// case for direct lowering is already checked in compiler invocation
287293
// no need to check here
288294
emitMLIR(lowerFromCIRToMLIR(mlirMod, mlirCtx.get()), false);

Diff for: clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4788,7 +4788,7 @@ std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
47884788
}
47894789

47904790
void populateCIRToLLVMPasses(mlir::OpPassManager &pm, bool useCCLowering) {
4791-
populateCIRPreLoweringPasses(pm, useCCLowering);
4791+
populateCIRPreLoweringPasses(pm, useCCLowering, /*flattenCore=*/false);
47924792
pm.addPass(createConvertCIRToLLVMPass());
47934793
}
47944794

Diff for: clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ add_clang_library(clangCIRLoweringThroughMLIR
4242
MLIRMemRefDialect
4343
MLIROpenMPDialect
4444
MLIROpenMPToLLVMIRTranslation
45+
MLIRControlFlowToSCF
4546
)

Diff for: clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+32-27
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class CIRReturnLowering : public mlir::OpConversionPattern<cir::ReturnOp> {
6969
mlir::LogicalResult
7070
matchAndRewrite(cir::ReturnOp op, OpAdaptor adaptor,
7171
mlir::ConversionPatternRewriter &rewriter) const override {
72+
assert(isa<mlir::FunctionOpInterface>(op->getParentOp()) &&
73+
"'func.return' op expects parent op 'func.func'");
7274
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op,
7375
adaptor.getOperands());
7476
return mlir::LogicalResult::success();
@@ -662,10 +664,13 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> {
662664
resultType ? mlir::TypeRange(resultType)
663665
: mlir::TypeRange()));
664666

665-
if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter,
666-
&signatureConversion)))
667-
return mlir::failure();
668-
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
667+
if (!op.getBody().empty()) {
668+
if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter,
669+
&signatureConversion)))
670+
return mlir::failure();
671+
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
672+
} else
673+
fn.setPrivate();
669674

670675
rewriter.eraseOp(op);
671676
return mlir::LogicalResult::success();
@@ -835,14 +840,15 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
835840
}
836841
};
837842

838-
class CIRBrOpLowering : public mlir::OpRewritePattern<cir::BrOp> {
843+
class CIRBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
839844
public:
840-
using OpRewritePattern<cir::BrOp>::OpRewritePattern;
845+
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;
841846

842847
mlir::LogicalResult
843-
matchAndRewrite(cir::BrOp op,
844-
mlir::PatternRewriter &rewriter) const override {
845-
rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest());
848+
matchAndRewrite(cir::BrOp op, OpAdaptor adaptor,
849+
mlir::ConversionPatternRewriter &rewriter) const override {
850+
rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(),
851+
adaptor.getDestOperands());
846852
return mlir::LogicalResult::success();
847853
}
848854
};
@@ -1356,24 +1362,23 @@ class CIRPtrStrideOpLowering
13561362

13571363
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13581364
mlir::TypeConverter &converter) {
1359-
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
1360-
1361-
patterns.add<
1362-
CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1363-
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1364-
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1365-
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1366-
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1367-
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1368-
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1369-
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1370-
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1371-
CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1372-
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1373-
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1374-
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1375-
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1376-
patterns.getContext());
1365+
patterns.add<CIRBrOpLowering, CIRReturnLowering, CIRCmpOpLowering,
1366+
CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1367+
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1368+
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1369+
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1370+
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1371+
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1372+
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1373+
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1374+
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1375+
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1376+
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1377+
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
1378+
CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering,
1379+
CIRVectorCreateLowering, CIRVectorInsertLowering,
1380+
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(
1381+
converter, patterns.getContext());
13771382
}
13781383

13791384
static mlir::TypeConverter prepareTypeConverter() {

0 commit comments

Comments
 (0)