Skip to content

Commit a4aa65b

Browse files
committed
[CIR][LowerToLLVM] Exceptions: more lowering work for cir.try_call and cir.eh.inflight_exception
- Fix parser problems that were preventing testing and fix additional lowering missing for `cir.try_call`. - Add lowering from scratch for `cir.eh.inflight_exception`. End-to-end requires full exception support (still more lowering TBD to get there).
1 parent 7150a05 commit a4aa65b

File tree

3 files changed

+258
-14
lines changed

3 files changed

+258
-14
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,11 +2349,60 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
23492349
return success();
23502350
}
23512351

2352-
static ::mlir::ParseResult
2353-
parseTryCallBranches(::mlir::OpAsmParser &parser,
2354-
::mlir::OperationState &result) {
2355-
parser.emitError(parser.getCurrentLocation(), "NYI");
2356-
return failure();
2352+
static mlir::ParseResult
2353+
parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
2354+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
2355+
&continueOperands,
2356+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
2357+
&landingPadOperands,
2358+
llvm::SmallVectorImpl<mlir::Type> &continueTypes,
2359+
llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
2360+
llvm::SMLoc &continueOperandsLoc,
2361+
llvm::SMLoc &landingPadOperandsLoc) {
2362+
mlir::Block *continueSuccessor = nullptr;
2363+
mlir::Block *landingPadSuccessor = nullptr;
2364+
2365+
if (parser.parseSuccessor(continueSuccessor))
2366+
return mlir::failure();
2367+
if (mlir::succeeded(parser.parseOptionalLParen())) {
2368+
continueOperandsLoc = parser.getCurrentLocation();
2369+
if (parser.parseOperandList(continueOperands))
2370+
return mlir::failure();
2371+
if (parser.parseColon())
2372+
return mlir::failure();
2373+
2374+
if (parser.parseTypeList(continueTypes))
2375+
return mlir::failure();
2376+
if (parser.parseRParen())
2377+
return mlir::failure();
2378+
}
2379+
if (parser.parseComma())
2380+
return mlir::failure();
2381+
2382+
if (parser.parseSuccessor(landingPadSuccessor))
2383+
return mlir::failure();
2384+
if (mlir::succeeded(parser.parseOptionalLParen())) {
2385+
2386+
landingPadOperandsLoc = parser.getCurrentLocation();
2387+
if (parser.parseOperandList(landingPadOperands))
2388+
return mlir::failure();
2389+
if (parser.parseColon())
2390+
return mlir::failure();
2391+
2392+
if (parser.parseTypeList(landingPadTypes))
2393+
return mlir::failure();
2394+
if (parser.parseRParen())
2395+
return mlir::failure();
2396+
}
2397+
{
2398+
auto loc = parser.getCurrentLocation();
2399+
(void)loc;
2400+
if (parser.parseOptionalAttrDict(result.attributes))
2401+
return mlir::failure();
2402+
}
2403+
result.addSuccessors(continueSuccessor);
2404+
result.addSuccessors(landingPadSuccessor);
2405+
return mlir::success();
23572406
}
23582407

23592408
static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser,
@@ -2367,6 +2416,14 @@ static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser,
23672416
llvm::ArrayRef<::mlir::Type> operandsTypes;
23682417
llvm::ArrayRef<::mlir::Type> allResultTypes;
23692418

2419+
// Control flow related
2420+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
2421+
llvm::SMLoc continueOperandsLoc;
2422+
llvm::SmallVector<mlir::Type, 1> continueTypes;
2423+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
2424+
llvm::SMLoc landingPadOperandsLoc;
2425+
llvm::SmallVector<mlir::Type, 1> landingPadTypes;
2426+
23702427
if (::mlir::succeeded(parser.parseOptionalKeyword("exception")))
23712428
result.addAttribute("exception", parser.getBuilder().getUnitAttr());
23722429

@@ -2390,7 +2447,10 @@ static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser,
23902447
return ::mlir::failure();
23912448

23922449
if (hasDestinationBlocks)
2393-
if (parseTryCallBranches(parser, result).failed())
2450+
if (parseTryCallBranches(parser, result, continueOperands,
2451+
landingPadOperands, continueTypes, landingPadTypes,
2452+
continueOperandsLoc, landingPadOperandsLoc)
2453+
.failed())
23942454
return ::mlir::failure();
23952455

23962456
auto &builder = parser.getBuilder();
@@ -2423,6 +2483,23 @@ static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser,
24232483

24242484
if (parser.resolveOperands(ops, operandsTypes, opsLoc, result.operands))
24252485
return ::mlir::failure();
2486+
2487+
if (hasDestinationBlocks) {
2488+
// The TryCall ODS layout is: cont, landing_pad, operands.
2489+
llvm::copy(::llvm::ArrayRef<int32_t>(
2490+
{static_cast<int32_t>(continueOperands.size()),
2491+
static_cast<int32_t>(landingPadOperands.size()),
2492+
static_cast<int32_t>(ops.size())}),
2493+
result.getOrAddProperties<TryCallOp::Properties>()
2494+
.operandSegmentSizes.begin());
2495+
if (parser.resolveOperands(continueOperands, continueTypes,
2496+
continueOperandsLoc, result.operands))
2497+
return ::mlir::failure();
2498+
if (parser.resolveOperands(landingPadOperands, landingPadTypes,
2499+
landingPadOperandsLoc, result.operands))
2500+
return ::mlir::failure();
2501+
}
2502+
24262503
return ::mlir::success();
24272504
}
24282505

@@ -2553,7 +2630,8 @@ cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
25532630
::mlir::ParseResult TryCallOp::parse(::mlir::OpAsmParser &parser,
25542631
::mlir::OperationState &result) {
25552632

2556-
return parseCallCommon(parser, result, getExtraAttrsAttrName(result.name));
2633+
return parseCallCommon(parser, result, getExtraAttrsAttrName(result.name),
2634+
/*hasDestinationBlocks=*/true);
25572635
}
25582636

25592637
void TryCallOp::print(::mlir::OpAsmPrinter &state) {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ mlir::LogicalResult
834834
rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
835835
mlir::ConversionPatternRewriter &rewriter,
836836
const mlir::TypeConverter *converter,
837-
mlir::FlatSymbolRefAttr calleeAttr, bool invoke = false,
837+
mlir::FlatSymbolRefAttr calleeAttr,
838838
mlir::Block *continueBlock = nullptr,
839839
mlir::Block *landingPadBlock = nullptr) {
840840
llvm::SmallVector<mlir::Type, 8> llvmResults;
@@ -844,7 +844,7 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
844844
return mlir::failure();
845845

846846
if (calleeAttr) { // direct call
847-
if (invoke)
847+
if (landingPadBlock)
848848
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
849849
op, llvmResults, calleeAttr, callOperands, continueBlock,
850850
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
@@ -860,7 +860,7 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
860860
auto ftyp = dyn_cast<mlir::cir::FuncType>(ptyp.getPointee());
861861
assert(ftyp && "expected a pointer to a function as the first operand");
862862

863-
if (invoke) {
863+
if (landingPadBlock) {
864864
auto llvmFnTy =
865865
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp));
866866
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
@@ -896,9 +896,9 @@ class CIRTryCallLowering
896896
mlir::LogicalResult
897897
matchAndRewrite(mlir::cir::TryCallOp op, OpAdaptor adaptor,
898898
mlir::ConversionPatternRewriter &rewriter) const override {
899-
return rewriteToCallOrInvoke(op.getOperation(), adaptor.getOperands(),
900-
rewriter, getTypeConverter(),
901-
op.getCalleeAttr());
899+
return rewriteToCallOrInvoke(
900+
op.getOperation(), adaptor.getOperands(), rewriter, getTypeConverter(),
901+
op.getCalleeAttr(), op.getCont(), op.getLandingPad());
902902
}
903903
};
904904

@@ -910,7 +910,76 @@ class CIREhInflightOpLowering
910910
mlir::LogicalResult
911911
matchAndRewrite(mlir::cir::EhInflightOp op, OpAdaptor adaptor,
912912
mlir::ConversionPatternRewriter &rewriter) const override {
913-
return mlir::failure();
913+
mlir::Location loc = op.getLoc();
914+
// Create the landing pad type: struct { ptr, i32 }
915+
mlir::MLIRContext *ctx = rewriter.getContext();
916+
auto llvmPtr = mlir::LLVM::LLVMPointerType::get(ctx);
917+
llvm::SmallVector<mlir::Type> structFields;
918+
structFields.push_back(llvmPtr);
919+
structFields.push_back(rewriter.getI32Type());
920+
921+
auto llvmLandingPadStructTy =
922+
mlir::LLVM::LLVMStructType::getLiteral(ctx, structFields);
923+
mlir::ArrayAttr symListAttr = op.getSymTypeListAttr();
924+
mlir::SmallVector<mlir::Value, 4> symAddrs;
925+
926+
auto llvmFn = op->getParentOfType<mlir::LLVM::LLVMFuncOp>();
927+
assert(llvmFn && "expected LLVM function parent");
928+
mlir::Block *entryBlock = &llvmFn.getRegion().front();
929+
assert(entryBlock->isEntryBlock());
930+
931+
// %x = landingpad { ptr, i32 }
932+
if (symListAttr) {
933+
// catch ptr @_ZTIi
934+
// catch ptr @_ZTIPKc
935+
for (mlir::Attribute attr : op.getSymTypeListAttr()) {
936+
auto symAttr = cast<mlir::FlatSymbolRefAttr>(attr);
937+
// Generate `llvm.mlir.addressof` for each symbol, and place those
938+
// operations in the LLVM function entry basic block.
939+
mlir::OpBuilder::InsertionGuard guard(rewriter);
940+
rewriter.setInsertionPointToStart(entryBlock);
941+
mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
942+
loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()),
943+
symAttr.getValue());
944+
symAddrs.push_back(addrOp);
945+
}
946+
} else {
947+
// catch ptr null
948+
mlir::Value nullOp = rewriter.create<mlir::LLVM::ZeroOp>(
949+
loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()));
950+
symAddrs.push_back(nullOp);
951+
}
952+
953+
// %slot = extractvalue { ptr, i32 } %x, 0
954+
// %selector = extractvalue { ptr, i32 } %x, 1
955+
auto padOp = rewriter.create<mlir::LLVM::LandingpadOp>(
956+
loc, llvmLandingPadStructTy, symAddrs);
957+
SmallVector<int64_t> slotIdx = {0};
958+
SmallVector<int64_t> selectorIdx = {1};
959+
960+
mlir::Value slot =
961+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, padOp, slotIdx);
962+
mlir::Value selector =
963+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, padOp, selectorIdx);
964+
965+
rewriter.replaceOp(op, mlir::ValueRange{slot, selector});
966+
967+
// Landing pads are required to be in LLVM functions with personality
968+
// attribute. FIXME: for now hardcode personality creation in order to start
969+
// adding exception tests, once we annotate CIR with such information,
970+
// change it to be in FuncOp lowering instead.
971+
{
972+
mlir::OpBuilder::InsertionGuard guard(rewriter);
973+
// Insert personality decl before the current function.
974+
rewriter.setInsertionPoint(llvmFn);
975+
auto personalityFnTy =
976+
mlir::LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {},
977+
/*isVarArg=*/true);
978+
auto personalityFn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
979+
loc, "__gxx_personality_v0", personalityFnTy);
980+
llvmFn.setPersonality(personalityFn.getName());
981+
}
982+
return mlir::success();
914983
}
915984
};
916985

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// RUN: cir-translate %s -cir-to-llvmir -o %t.ll
2+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM
3+
4+
!s32i = !cir.int<s, 32>
5+
!s8i = !cir.int<s, 8>
6+
!u32i = !cir.int<u, 32>
7+
!u64i = !cir.int<u, 64>
8+
!u8i = !cir.int<u, 8>
9+
!void = !cir.void
10+
11+
module @"try-catch.cpp" attributes {cir.lang = #cir.lang<cxx>, cir.sob = #cir.signed_overflow_behavior<undefined>, cir.triple = "x86_64-unknown-linux-gnu", dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>, #dlti.dl_entry<"dlti.endianness", "little">>, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"} {
12+
cir.global "private" constant external @_ZTIi : !cir.ptr<!u8i>
13+
cir.global "private" constant external @_ZTIPKc : !cir.ptr<!u8i>
14+
cir.func private @_Z8divisionii(!s32i, !s32i) -> !cir.double
15+
// LLVM: @_Z2tcv() personality ptr @__gxx_personality_v0
16+
cir.func @_Z2tcv() -> !u64i {
17+
%0 = cir.alloca !u64i, !cir.ptr<!u64i>, ["__retval"] {alignment = 8 : i64}
18+
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
19+
%2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
20+
%3 = cir.alloca !u64i, !cir.ptr<!u64i>, ["z"] {alignment = 8 : i64}
21+
%4 = cir.const #cir.int<50> : !s32i
22+
cir.store %4, %1 : !s32i, !cir.ptr<!s32i>
23+
%5 = cir.const #cir.int<3> : !s32i
24+
cir.store %5, %2 : !s32i, !cir.ptr<!s32i>
25+
cir.br ^bb1
26+
^bb1:
27+
%6 = cir.alloca !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>, ["msg"] {alignment = 8 : i64}
28+
%7 = cir.alloca !s32i, !cir.ptr<!s32i>, ["idx"] {alignment = 4 : i64}
29+
cir.br ^bb2
30+
^bb2:
31+
%8 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
32+
%9 = cir.const #cir.int<4> : !s32i
33+
cir.store %9, %8 : !s32i, !cir.ptr<!s32i>
34+
%10 = cir.load %1 : !cir.ptr<!s32i>, !s32i
35+
%11 = cir.load %2 : !cir.ptr<!s32i>, !s32i
36+
%12 = cir.try_call @_Z8divisionii(%10, %11) ^bb3, ^bb4 : (!s32i, !s32i) -> !cir.double
37+
// LLVM: invoke double @_Z8divisionii
38+
// LLVM: to label %[[CONT:.*]] unwind label %[[UNWIND:.*]],
39+
^bb3:
40+
// LLVM: [[CONT]]:
41+
%13 = cir.cast(float_to_int, %12 : !cir.double), !u64i
42+
cir.store %13, %3 : !u64i, !cir.ptr<!u64i>
43+
%14 = cir.load %8 : !cir.ptr<!s32i>, !s32i
44+
%15 = cir.unary(inc, %14) : !s32i, !s32i
45+
cir.store %15, %8 : !s32i, !cir.ptr<!s32i>
46+
cir.br ^bb10
47+
^bb4:
48+
// LLVM: [[UNWIND]]:
49+
// LLVM: %[[EHINFO:.*]] = landingpad { ptr, i32 }
50+
// LLVM: catch ptr @_ZTIi
51+
// LLVM: catch ptr @_ZTIPKc
52+
%exception_ptr, %type_id = cir.eh.inflight_exception [@_ZTIi, @_ZTIPKc]
53+
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 0, !dbg !29
54+
// LLVM: extractvalue { ptr, i32 } %[[EHINFO]], 1, !dbg !29
55+
cir.br ^bb10
56+
// TODO: TBD
57+
// cir.br ^bb5(%exception_ptr, %type_id : !cir.ptr<!void>, !u32i)
58+
// ^bb5(%16: !cir.ptr<!void>, %17: !u32):
59+
// %18 = cir.eh.typeid @_ZTIi
60+
// %19 = cir.cmp(eq, %17, %18) : !u32i, !cir.bool
61+
// cir.brcond %19 ^bb6(%16 : !cir.ptr<!void>), ^bb7(%16, %17 : !cir.ptr<!void>, !u32i)
62+
// ^bb6(%20: !cir.ptr<!void>):
63+
// %21 = cir.catch_param begin %20 -> !cir.ptr<!s32i>
64+
// %22 = cir.load %21 : !cir.ptr<!s32i>, !s32i
65+
// cir.store %22, %7 : !s32i, !cir.ptr<!s32i>
66+
// %23 = cir.const #cir.int<98> : !s32i
67+
// %24 = cir.cast(integral, %23 : !s32i), !u64i
68+
// cir.store %24, %3 : !u64i, !cir.ptr<!u64i>
69+
// %25 = cir.load %7 : !cir.ptr<!s32i>, !s32i
70+
// %26 = cir.unary(inc, %25) : !s32i, !s32i
71+
// cir.store %26, %7 : !s32i, !cir.ptr<!s32i>
72+
// cir.catch_param end
73+
// cir.br ^bb10
74+
// ^bb7(%27: !cir.ptr<!void>, %28: !u32i):
75+
// %29 = cir.eh.typeid @_ZTIPKc
76+
// %30 = cir.cmp(eq, %28, %29) : !u32i, !cir.bool
77+
// cir.brcond %30 ^bb8(%27 : !cir.ptr<!void>), ^bb9(%27, %28 : !cir.ptr<!void>, !u32i)
78+
// ^bb8(%31: !cir.ptr<!void>):
79+
// %32 = cir.catch_param begin %31 -> !cir.ptr<!s8i>
80+
// cir.store %32, %6 : !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>
81+
// %33 = cir.const #cir.int<99> : !s32i
82+
// %34 = cir.cast(integral, %33 : !s32i), !u64i
83+
// cir.store %34, %3 : !u64i, !cir.ptr<!u64i>
84+
// %35 = cir.load %6 : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i>
85+
// %36 = cir.const #cir.int<0> : !s32i
86+
// %37 = cir.ptr_stride(%35 : !cir.ptr<!s8i>, %36 : !s32i), !cir.ptr<!s8i>
87+
// cir.catch_param end
88+
// cir.br ^bb10
89+
// ^bb9(%38: !cir.ptr<!void>, %39: !u32i):
90+
// cir.resume
91+
^bb10:
92+
%40 = cir.load %3 : !cir.ptr<!u64i>, !u64i
93+
cir.store %40, %0 : !u64i, !cir.ptr<!u64i>
94+
%41 = cir.load %0 : !cir.ptr<!u64i>, !u64i
95+
cir.return %41 : !u64i
96+
}
97+
}

0 commit comments

Comments
 (0)