Skip to content

Commit 0a1eb05

Browse files
bcardosolopeslanza
authored andcommitted
[CIR][LowerToLLVM] Exceptions: lower cir.catch_param
1 parent 66b3f31 commit 0a1eb05

File tree

3 files changed

+101
-34
lines changed

3 files changed

+101
-34
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,6 +3342,11 @@ def CatchParamOp : CIR_Op<"catch_param"> {
33423342
attr-dict
33433343
}];
33443344

3345+
let extraClassDeclaration = [{
3346+
bool isBegin() { return getKind() == mlir::cir::CatchParamKind::begin; }
3347+
bool isEnd() { return getKind() == mlir::cir::CatchParamKind::end; }
3348+
}];
3349+
33453350
let hasVerifier = 1;
33463351
}
33473352

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3580,6 +3580,61 @@ class CIREhTypeIdOpLowering
35803580
}
35813581
};
35823582

3583+
class CIRCatchParamOpLowering
3584+
: public mlir::OpConversionPattern<mlir::cir::CatchParamOp> {
3585+
public:
3586+
using OpConversionPattern<mlir::cir::CatchParamOp>::OpConversionPattern;
3587+
3588+
mlir::LogicalResult
3589+
matchAndRewrite(mlir::cir::CatchParamOp op, OpAdaptor adaptor,
3590+
mlir::ConversionPatternRewriter &rewriter) const override {
3591+
if (op.isBegin()) {
3592+
// Get or create `declare ptr @__cxa_begin_catch(ptr)`
3593+
llvm::StringRef cxaBeginCatch = "__cxa_begin_catch";
3594+
auto *sourceSymbol = mlir::SymbolTable::lookupSymbolIn(
3595+
op->getParentOfType<mlir::ModuleOp>(), cxaBeginCatch);
3596+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
3597+
if (!sourceSymbol) {
3598+
auto catchFnTy =
3599+
mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {llvmPtrTy},
3600+
/*isVarArg=*/false);
3601+
mlir::OpBuilder::InsertionGuard guard(rewriter);
3602+
rewriter.setInsertionPoint(
3603+
op->getParentOfType<mlir::LLVM::LLVMFuncOp>());
3604+
auto catchFn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
3605+
op.getLoc(), cxaBeginCatch, catchFnTy);
3606+
sourceSymbol = catchFn;
3607+
}
3608+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
3609+
op, mlir::TypeRange{llvmPtrTy}, cxaBeginCatch,
3610+
mlir::ValueRange{adaptor.getExceptionPtr()});
3611+
return mlir::success();
3612+
} else if (op.isEnd()) {
3613+
// Get or create `declare void @__cxa_end_catch()`
3614+
llvm::StringRef cxaEndCatch = "__cxa_end_catch";
3615+
auto *sourceSymbol = mlir::SymbolTable::lookupSymbolIn(
3616+
op->getParentOfType<mlir::ModuleOp>(), cxaEndCatch);
3617+
auto llvmVoidTy = mlir::LLVM::LLVMVoidType::get(rewriter.getContext());
3618+
if (!sourceSymbol) {
3619+
auto catchFnTy = mlir::LLVM::LLVMFunctionType::get(llvmVoidTy, {},
3620+
/*isVarArg=*/false);
3621+
mlir::OpBuilder::InsertionGuard guard(rewriter);
3622+
rewriter.setInsertionPoint(
3623+
op->getParentOfType<mlir::LLVM::LLVMFuncOp>());
3624+
auto catchFn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
3625+
op.getLoc(), cxaEndCatch, catchFnTy);
3626+
sourceSymbol = catchFn;
3627+
}
3628+
rewriter.create<mlir::LLVM::CallOp>(op.getLoc(), mlir::TypeRange{},
3629+
cxaEndCatch, mlir::ValueRange{});
3630+
rewriter.eraseOp(op);
3631+
return mlir::success();
3632+
}
3633+
llvm_unreachable("only begin/end supposed to make to lowering stage");
3634+
return mlir::failure();
3635+
}
3636+
};
3637+
35833638
void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
35843639
mlir::TypeConverter &converter) {
35853640
patterns.add<CIRReturnLowering>(patterns.getContext());
@@ -3615,8 +3670,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
36153670
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
36163671
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
36173672
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering,
3618-
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering>(
3619-
converter, patterns.getContext());
3673+
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering,
3674+
CIRCatchParamOpLowering>(converter, patterns.getContext());
36203675
}
36213676

36223677
namespace {

clang/test/CIR/Lowering/exceptions.cir

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,44 +50,51 @@ module @"try-catch.cpp" attributes {cir.lang = #cir.lang<cxx>, cir.sob = #cir.si
5050
// LLVM: catch ptr @_ZTIi
5151
// LLVM: catch ptr @_ZTIPKc
5252
%exception_ptr, %type_id = cir.eh.inflight_exception [@_ZTIi, @_ZTIPKc]
53-
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 0, !dbg !29
54-
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 1, !dbg !29
53+
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 0
54+
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 1
5555
cir.br ^bb5(%exception_ptr, %type_id : !cir.ptr<!void>, !u32i)
5656
^bb5(%16: !cir.ptr<!void>, %17: !u32i):
5757
%18 = cir.eh.typeid @_ZTIi
5858
// LLVM: call i32 @llvm.eh.typeid.for.p0(ptr @_ZTIi)
5959
%19 = cir.cmp(eq, %17, %18) : !u32i, !cir.bool
60+
cir.brcond %19 ^bb6(%16 : !cir.ptr<!void>), ^bb7(%16, %17 : !cir.ptr<!void>, !u32i)
61+
^bb6(%20: !cir.ptr<!void>):
62+
%21 = cir.catch_param begin %20 -> !cir.ptr<!s32i>
63+
// LLVM: %[[EH_IDX:.*]] = phi ptr
64+
// LLVM: call ptr @__cxa_begin_catch(ptr %[[EH_IDX]])
65+
%22 = cir.load %21 : !cir.ptr<!s32i>, !s32i
66+
cir.store %22, %7 : !s32i, !cir.ptr<!s32i>
67+
%23 = cir.const #cir.int<98> : !s32i
68+
%24 = cir.cast(integral, %23 : !s32i), !u64i
69+
cir.store %24, %3 : !u64i, !cir.ptr<!u64i>
70+
%25 = cir.load %7 : !cir.ptr<!s32i>, !s32i
71+
%26 = cir.unary(inc, %25) : !s32i, !s32i
72+
cir.store %26, %7 : !s32i, !cir.ptr<!s32i>
73+
cir.catch_param end
74+
// LLVM: call void @__cxa_end_catch()
6075
cir.br ^bb10
61-
// TODO: TBD
62-
// cir.brcond %19 ^bb6(%16 : !cir.ptr<!void>), ^bb7(%16, %17 : !cir.ptr<!void>, !u32i)
63-
// ^bb6(%20: !cir.ptr<!void>):
64-
// %21 = cir.catch_param begin %20 -> !cir.ptr<!s32i>
65-
// %22 = cir.load %21 : !cir.ptr<!s32i>, !s32i
66-
// cir.store %22, %7 : !s32i, !cir.ptr<!s32i>
67-
// %23 = cir.const #cir.int<98> : !s32i
68-
// %24 = cir.cast(integral, %23 : !s32i), !u64i
69-
// cir.store %24, %3 : !u64i, !cir.ptr<!u64i>
70-
// %25 = cir.load %7 : !cir.ptr<!s32i>, !s32i
71-
// %26 = cir.unary(inc, %25) : !s32i, !s32i
72-
// cir.store %26, %7 : !s32i, !cir.ptr<!s32i>
73-
// cir.catch_param end
74-
// cir.br ^bb10
75-
// ^bb7(%27: !cir.ptr<!void>, %28: !u32i):
76-
// %29 = cir.eh.typeid @_ZTIPKc
77-
// %30 = cir.cmp(eq, %28, %29) : !u32i, !cir.bool
78-
// cir.brcond %30 ^bb8(%27 : !cir.ptr<!void>), ^bb9(%27, %28 : !cir.ptr<!void>, !u32i)
79-
// ^bb8(%31: !cir.ptr<!void>):
80-
// %32 = cir.catch_param begin %31 -> !cir.ptr<!s8i>
81-
// cir.store %32, %6 : !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>
82-
// %33 = cir.const #cir.int<99> : !s32i
83-
// %34 = cir.cast(integral, %33 : !s32i), !u64i
84-
// cir.store %34, %3 : !u64i, !cir.ptr<!u64i>
85-
// %35 = cir.load %6 : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i>
86-
// %36 = cir.const #cir.int<0> : !s32i
87-
// %37 = cir.ptr_stride(%35 : !cir.ptr<!s8i>, %36 : !s32i), !cir.ptr<!s8i>
88-
// cir.catch_param end
89-
// cir.br ^bb10
90-
// ^bb9(%38: !cir.ptr<!void>, %39: !u32i):
76+
^bb7(%27: !cir.ptr<!void>, %28: !u32i):
77+
%29 = cir.eh.typeid @_ZTIPKc
78+
// LLVM: call i32 @llvm.eh.typeid.for.p0(ptr @_ZTIPKc)
79+
%30 = cir.cmp(eq, %28, %29) : !u32i, !cir.bool
80+
cir.brcond %30 ^bb8(%27 : !cir.ptr<!void>), ^bb9(%27, %28 : !cir.ptr<!void>, !u32i)
81+
^bb8(%31: !cir.ptr<!void>):
82+
%32 = cir.catch_param begin %31 -> !cir.ptr<!s8i>
83+
// LLVM: %[[EH_MSG:.*]] = phi ptr
84+
// LLVM: call ptr @__cxa_begin_catch(ptr %[[EH_MSG]])
85+
cir.store %32, %6 : !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>
86+
%33 = cir.const #cir.int<99> : !s32i
87+
%34 = cir.cast(integral, %33 : !s32i), !u64i
88+
cir.store %34, %3 : !u64i, !cir.ptr<!u64i>
89+
%35 = cir.load %6 : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i>
90+
%36 = cir.const #cir.int<0> : !s32i
91+
%37 = cir.ptr_stride(%35 : !cir.ptr<!s8i>, %36 : !s32i), !cir.ptr<!s8i>
92+
cir.catch_param end
93+
// LLVM: call void @__cxa_end_catch()
94+
cir.br ^bb10
95+
^bb9(%38: !cir.ptr<!void>, %39: !u32i):
96+
cir.br ^bb10
97+
// TODO: support resume.
9198
// cir.resume
9299
^bb10:
93100
%40 = cir.load %3 : !cir.ptr<!u64i>, !u64i

0 commit comments

Comments
 (0)