Skip to content

Commit 2356c02

Browse files
committed
[CIR][MLIR] Add scf.scope lowering to standard dialects
1 parent beb3d98 commit 2356c02

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,12 +488,43 @@ class CIRBrOpLowering : public mlir::OpRewritePattern<mlir::cir::BrOp> {
488488
}
489489
};
490490

491+
class CIRScopeOpLowering : public mlir::OpRewritePattern<mlir::cir::ScopeOp> {
492+
using mlir::OpRewritePattern<mlir::cir::ScopeOp>::OpRewritePattern;
493+
494+
mlir::LogicalResult
495+
matchAndRewrite(mlir::cir::ScopeOp scopeOp,
496+
mlir::PatternRewriter &rewriter) const override {
497+
// Empty scope: just remove it.
498+
if (scopeOp.getRegion().empty()) {
499+
rewriter.eraseOp(scopeOp);
500+
return mlir::success();
501+
}
502+
503+
for (auto &block : scopeOp.getRegion()) {
504+
rewriter.setInsertionPointToEnd(&block);
505+
auto *terminator = block.getTerminator();
506+
rewriter.replaceOpWithNewOp<mlir::memref::AllocaScopeReturnOp>(
507+
terminator, terminator->getOperands());
508+
}
509+
510+
rewriter.setInsertionPoint(scopeOp);
511+
auto newScopeOp = rewriter.create<mlir::memref::AllocaScopeOp>(
512+
scopeOp.getLoc(), scopeOp.getResultTypes());
513+
rewriter.inlineRegionBefore(scopeOp.getScopeRegion(),
514+
newScopeOp.getBodyRegion(),
515+
newScopeOp.getBodyRegion().end());
516+
rewriter.replaceOp(scopeOp, newScopeOp);
517+
518+
return mlir::LogicalResult::success();
519+
}
520+
};
521+
491522
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
492523
mlir::TypeConverter &converter) {
493524
patterns.add<CIRAllocaLowering, CIRLoadLowering, CIRStoreLowering,
494525
CIRConstantLowering, CIRUnaryOpLowering, CIRBinOpLowering,
495526
CIRCmpOpLowering, CIRBrOpLowering, CIRCallLowering,
496-
CIRReturnLowering>(patterns.getContext());
527+
CIRReturnLowering, CIRScopeOpLowering>(patterns.getContext());
497528
patterns.add<CIRFuncLowering>(converter, patterns.getContext());
498529
}
499530

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
3+
4+
module {
5+
cir.func @foo() {
6+
cir.scope {
7+
%0 = cir.alloca i32, cir.ptr <i32>, ["a", init] {alignment = 4 : i64}
8+
%1 = cir.const(4 : i32) : i32
9+
cir.store %1, %0 : i32, cir.ptr <i32>
10+
}
11+
cir.return
12+
}
13+
14+
// MLIR: func.func @foo()
15+
// MLIR-NEXT: memref.alloca_scope
16+
// MLIR-NEXT: %alloca = memref.alloca() {alignment = 4 : i64} : memref<i32>
17+
// MLIR-NEXT: %c4_i32 = arith.constant 4 : i32
18+
// MLIR-NEXT: memref.store %c4_i32, %alloca[] : memref<i32>
19+
// MLIR-NEXT: }
20+
// MLIR-NEXT: return
21+
22+
23+
// LLVM: define void @foo()
24+
// LLVM-NEXT: %1 = call ptr @llvm.stacksave()
25+
// LLVM-NEXT: br label %2
26+
// LLVM-EMPTY:
27+
// LLVM-NEXT: 2:
28+
// LLVM-NEXT: %3 = alloca i32, i64 1, align 4
29+
// LLVM-NEXT: %4 = insertvalue { ptr, ptr, i64 } undef, ptr %3, 0
30+
// LLVM-NEXT: %5 = insertvalue { ptr, ptr, i64 } %4, ptr %3, 1
31+
// LLVM-NEXT: %6 = insertvalue { ptr, ptr, i64 } %5, i64 0, 2
32+
// LLVM-NEXT: %7 = extractvalue { ptr, ptr, i64 } %6, 1
33+
// LLVM-NEXT: store i32 4, ptr %7, align 4
34+
// LLVM-NEXT: call void @llvm.stackrestore(ptr %1)
35+
// LLVM-NEXT: br label %8
36+
// LLVM-EMPTY:
37+
// LLVM-NEXT: 8:
38+
// LLVM-NEXT: ret void
39+
// LLVM-NEXT: }
40+
41+
42+
// Should drop empty scopes.
43+
cir.func @empty_scope() {
44+
cir.scope {
45+
}
46+
cir.return
47+
}
48+
// MLIR: func.func @empty_scope()
49+
// MLIR-NEXT: return
50+
// MLIR-NEXT: }
51+
52+
}

0 commit comments

Comments
 (0)