Skip to content

Commit 07a6f80

Browse files
committed
[CIR][LowerToLLVM] Exceptions: lower cir.catch_param
1 parent a3c126c commit 07a6f80

File tree

3 files changed

+101
-34
lines changed

3 files changed

+101
-34
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3333,6 +3333,11 @@ def CatchParamOp : CIR_Op<"catch_param"> {
33333333
attr-dict
33343334
}];
33353335

3336+
let extraClassDeclaration = [{
3337+
bool isBegin() { return getKind() == mlir::cir::CatchParamKind::begin; }
3338+
bool isEnd() { return getKind() == mlir::cir::CatchParamKind::end; }
3339+
}];
3340+
33363341
let hasVerifier = 1;
33373342
}
33383343

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3579,6 +3579,61 @@ class CIREhTypeIdOpLowering
35793579
}
35803580
};
35813581

3582+
class CIRCatchParamOpLowering
3583+
: public mlir::OpConversionPattern<mlir::cir::CatchParamOp> {
3584+
public:
3585+
using OpConversionPattern<mlir::cir::CatchParamOp>::OpConversionPattern;
3586+
3587+
mlir::LogicalResult
3588+
matchAndRewrite(mlir::cir::CatchParamOp op, OpAdaptor adaptor,
3589+
mlir::ConversionPatternRewriter &rewriter) const override {
3590+
if (op.isBegin()) {
3591+
// Get or create `declare ptr @__cxa_begin_catch(ptr)`
3592+
llvm::StringRef cxaBeginCatch = "__cxa_begin_catch";
3593+
auto *sourceSymbol = mlir::SymbolTable::lookupSymbolIn(
3594+
op->getParentOfType<mlir::ModuleOp>(), cxaBeginCatch);
3595+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
3596+
if (!sourceSymbol) {
3597+
auto catchFnTy =
3598+
mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {llvmPtrTy},
3599+
/*isVarArg=*/false);
3600+
mlir::OpBuilder::InsertionGuard guard(rewriter);
3601+
rewriter.setInsertionPoint(
3602+
op->getParentOfType<mlir::LLVM::LLVMFuncOp>());
3603+
auto catchFn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
3604+
op.getLoc(), cxaBeginCatch, catchFnTy);
3605+
sourceSymbol = catchFn;
3606+
}
3607+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
3608+
op, mlir::TypeRange{llvmPtrTy}, cxaBeginCatch,
3609+
mlir::ValueRange{adaptor.getExceptionPtr()});
3610+
return mlir::success();
3611+
} else if (op.isEnd()) {
3612+
// Get or create `declare void @__cxa_end_catch()`
3613+
llvm::StringRef cxaEndCatch = "__cxa_end_catch";
3614+
auto *sourceSymbol = mlir::SymbolTable::lookupSymbolIn(
3615+
op->getParentOfType<mlir::ModuleOp>(), cxaEndCatch);
3616+
auto llvmVoidTy = mlir::LLVM::LLVMVoidType::get(rewriter.getContext());
3617+
if (!sourceSymbol) {
3618+
auto catchFnTy = mlir::LLVM::LLVMFunctionType::get(llvmVoidTy, {},
3619+
/*isVarArg=*/false);
3620+
mlir::OpBuilder::InsertionGuard guard(rewriter);
3621+
rewriter.setInsertionPoint(
3622+
op->getParentOfType<mlir::LLVM::LLVMFuncOp>());
3623+
auto catchFn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
3624+
op.getLoc(), cxaEndCatch, catchFnTy);
3625+
sourceSymbol = catchFn;
3626+
}
3627+
rewriter.create<mlir::LLVM::CallOp>(op.getLoc(), mlir::TypeRange{},
3628+
cxaEndCatch, mlir::ValueRange{});
3629+
rewriter.eraseOp(op);
3630+
return mlir::success();
3631+
}
3632+
llvm_unreachable("only begin/end supposed to make to lowering stage");
3633+
return mlir::failure();
3634+
}
3635+
};
3636+
35823637
void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
35833638
mlir::TypeConverter &converter) {
35843639
patterns.add<CIRReturnLowering>(patterns.getContext());
@@ -3614,8 +3669,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
36143669
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
36153670
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
36163671
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering,
3617-
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering>(
3618-
converter, patterns.getContext());
3672+
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering,
3673+
CIRCatchParamOpLowering>(converter, patterns.getContext());
36193674
}
36203675

36213676
namespace {

clang/test/CIR/Lowering/exceptions.cir

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,44 +50,51 @@ module @"try-catch.cpp" attributes {cir.lang = #cir.lang<cxx>, cir.sob = #cir.si
5050
// LLVM: catch ptr @_ZTIi
5151
// LLVM: catch ptr @_ZTIPKc
5252
%exception_ptr, %type_id = cir.eh.inflight_exception [@_ZTIi, @_ZTIPKc]
53-
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 0, !dbg !29
54-
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 1, !dbg !29
53+
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 0
54+
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 1
5555
cir.br ^bb5(%exception_ptr, %type_id : !cir.ptr<!void>, !u32i)
5656
^bb5(%16: !cir.ptr<!void>, %17: !u32i):
5757
%18 = cir.eh.typeid @_ZTIi
5858
// LLVM: call i32 @llvm.eh.typeid.for.p0(ptr @_ZTIi)
5959
%19 = cir.cmp(eq, %17, %18) : !u32i, !cir.bool
60+
cir.brcond %19 ^bb6(%16 : !cir.ptr<!void>), ^bb7(%16, %17 : !cir.ptr<!void>, !u32i)
61+
^bb6(%20: !cir.ptr<!void>):
62+
%21 = cir.catch_param begin %20 -> !cir.ptr<!s32i>
63+
// LLVM: %[[EH_IDX:.*]] = phi ptr
64+
// LLVM: call ptr @__cxa_begin_catch(ptr %[[EH_IDX]])
65+
%22 = cir.load %21 : !cir.ptr<!s32i>, !s32i
66+
cir.store %22, %7 : !s32i, !cir.ptr<!s32i>
67+
%23 = cir.const #cir.int<98> : !s32i
68+
%24 = cir.cast(integral, %23 : !s32i), !u64i
69+
cir.store %24, %3 : !u64i, !cir.ptr<!u64i>
70+
%25 = cir.load %7 : !cir.ptr<!s32i>, !s32i
71+
%26 = cir.unary(inc, %25) : !s32i, !s32i
72+
cir.store %26, %7 : !s32i, !cir.ptr<!s32i>
73+
cir.catch_param end
74+
// LLVM: call void @__cxa_end_catch()
6075
cir.br ^bb10
61-
// TODO: TBD
62-
// cir.brcond %19 ^bb6(%16 : !cir.ptr<!void>), ^bb7(%16, %17 : !cir.ptr<!void>, !u32i)
63-
// ^bb6(%20: !cir.ptr<!void>):
64-
// %21 = cir.catch_param begin %20 -> !cir.ptr<!s32i>
65-
// %22 = cir.load %21 : !cir.ptr<!s32i>, !s32i
66-
// cir.store %22, %7 : !s32i, !cir.ptr<!s32i>
67-
// %23 = cir.const #cir.int<98> : !s32i
68-
// %24 = cir.cast(integral, %23 : !s32i), !u64i
69-
// cir.store %24, %3 : !u64i, !cir.ptr<!u64i>
70-
// %25 = cir.load %7 : !cir.ptr<!s32i>, !s32i
71-
// %26 = cir.unary(inc, %25) : !s32i, !s32i
72-
// cir.store %26, %7 : !s32i, !cir.ptr<!s32i>
73-
// cir.catch_param end
74-
// cir.br ^bb10
75-
// ^bb7(%27: !cir.ptr<!void>, %28: !u32i):
76-
// %29 = cir.eh.typeid @_ZTIPKc
77-
// %30 = cir.cmp(eq, %28, %29) : !u32i, !cir.bool
78-
// cir.brcond %30 ^bb8(%27 : !cir.ptr<!void>), ^bb9(%27, %28 : !cir.ptr<!void>, !u32i)
79-
// ^bb8(%31: !cir.ptr<!void>):
80-
// %32 = cir.catch_param begin %31 -> !cir.ptr<!s8i>
81-
// cir.store %32, %6 : !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>
82-
// %33 = cir.const #cir.int<99> : !s32i
83-
// %34 = cir.cast(integral, %33 : !s32i), !u64i
84-
// cir.store %34, %3 : !u64i, !cir.ptr<!u64i>
85-
// %35 = cir.load %6 : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i>
86-
// %36 = cir.const #cir.int<0> : !s32i
87-
// %37 = cir.ptr_stride(%35 : !cir.ptr<!s8i>, %36 : !s32i), !cir.ptr<!s8i>
88-
// cir.catch_param end
89-
// cir.br ^bb10
90-
// ^bb9(%38: !cir.ptr<!void>, %39: !u32i):
76+
^bb7(%27: !cir.ptr<!void>, %28: !u32i):
77+
%29 = cir.eh.typeid @_ZTIPKc
78+
// LLVM: call i32 @llvm.eh.typeid.for.p0(ptr @_ZTIPKc)
79+
%30 = cir.cmp(eq, %28, %29) : !u32i, !cir.bool
80+
cir.brcond %30 ^bb8(%27 : !cir.ptr<!void>), ^bb9(%27, %28 : !cir.ptr<!void>, !u32i)
81+
^bb8(%31: !cir.ptr<!void>):
82+
%32 = cir.catch_param begin %31 -> !cir.ptr<!s8i>
83+
// LLVM: %[[EH_MSG:.*]] = phi ptr
84+
// LLVM: call ptr @__cxa_begin_catch(ptr %[[EH_MSG]])
85+
cir.store %32, %6 : !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>
86+
%33 = cir.const #cir.int<99> : !s32i
87+
%34 = cir.cast(integral, %33 : !s32i), !u64i
88+
cir.store %34, %3 : !u64i, !cir.ptr<!u64i>
89+
%35 = cir.load %6 : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i>
90+
%36 = cir.const #cir.int<0> : !s32i
91+
%37 = cir.ptr_stride(%35 : !cir.ptr<!s8i>, %36 : !s32i), !cir.ptr<!s8i>
92+
cir.catch_param end
93+
// LLVM: call void @__cxa_end_catch()
94+
cir.br ^bb10
95+
^bb9(%38: !cir.ptr<!void>, %39: !u32i):
96+
cir.br ^bb10
97+
// TODO: support resume.
9198
// cir.resume
9299
^bb10:
93100
%40 = cir.load %3 : !cir.ptr<!u64i>, !u64i

0 commit comments

Comments
 (0)