diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 9a2f269e106a..8156d8fad059 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -478,6 +478,13 @@ class CIRGenFunction : public CIRGenTypeCache { // applies to. nullptr if there is no 'musttail' on the current statement. const clang::CallExpr *MustTailCall = nullptr; + /// The attributes of cases collected during emitting the body of a switch + /// stmt. + llvm::SmallVector, 2> caseAttrsStack; + + /// The type of the condition for the emitting switch statement. + llvm::SmallVector condTypeStack; + clang::ASTContext &getContext() const; CIRGenBuilderTy &getBuilder() { return builder; } @@ -1210,13 +1217,9 @@ class CIRGenFunction : public CIRGenTypeCache { buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType, SmallVector &caseAttrs); - mlir::LogicalResult - buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType, - SmallVector &caseAttrs); + mlir::LogicalResult buildSwitchCase(const clang::SwitchCase &S); - mlir::LogicalResult - buildSwitchBody(const clang::Stmt *S, mlir::Type condType, - SmallVector &caseAttrs); + mlir::LogicalResult buildSwitchBody(const clang::Stmt *S); mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn, const CIRGenFunctionInfo &FnInfo); diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 426da35b5238..1b0829c8e8bb 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -303,8 +303,7 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S, case Stmt::CaseStmtClass: case Stmt::DefaultStmtClass: - assert(0 && - "Should not get here, currently handled directly from SwitchStmt"); + return buildSwitchCase(cast(*S)); break; case Stmt::BreakStmtClass: @@ -715,14 +714,19 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType, return buildCaseDefaultCascade(&S, condType, caseAttrs); } -mlir::LogicalResult -CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType, - SmallVector &caseAttrs) { +mlir::LogicalResult CIRGenFunction::buildSwitchCase(const SwitchCase &S) { + assert(!caseAttrsStack.empty() && + "build switch case without seeting case attrs"); + assert(!condTypeStack.empty() && + "build switch case without specifying the type of the condition"); + if (S.getStmtClass() == Stmt::CaseStmtClass) - return buildCaseStmt(cast(S), condType, caseAttrs); + return buildCaseStmt(cast(S), condTypeStack.back(), + caseAttrsStack.back()); if (S.getStmtClass() == Stmt::DefaultStmtClass) - return buildDefaultStmt(cast(S), condType, caseAttrs); + return buildDefaultStmt(cast(S), condTypeStack.back(), + caseAttrsStack.back()); llvm_unreachable("expect case or default stmt"); } @@ -987,15 +991,13 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) { return mlir::success(); } -mlir::LogicalResult CIRGenFunction::buildSwitchBody( - const Stmt *S, mlir::Type condType, - llvm::SmallVector &caseAttrs) { +mlir::LogicalResult CIRGenFunction::buildSwitchBody(const Stmt *S) { if (auto *compoundStmt = dyn_cast(S)) { mlir::Block *lastCaseBlock = nullptr; auto res = mlir::success(); for (auto *c : compoundStmt->body()) { if (auto *switchCase = dyn_cast(c)) { - res = buildSwitchCase(*switchCase, condType, caseAttrs); + res = buildSwitchCase(*switchCase); lastCaseBlock = builder.getBlock(); } else if (lastCaseBlock) { // This means it's a random stmt following up a case, just @@ -1045,12 +1047,16 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) { [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) { currLexScope->setAsSwitch(); - llvm::SmallVector caseAttrs; + caseAttrsStack.push_back({}); + condTypeStack.push_back(condV.getType()); - res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs); + res = buildSwitchBody(S.getBody()); os.addRegions(currLexScope->getSwitchRegions()); - os.addAttribute("cases", builder.getArrayAttr(caseAttrs)); + os.addAttribute("cases", builder.getArrayAttr(caseAttrsStack.back())); + + caseAttrsStack.pop_back(); + condTypeStack.pop_back(); }); if (res.failed()) diff --git a/clang/test/CIR/CodeGen/goto.cpp b/clang/test/CIR/CodeGen/goto.cpp index 81eb4ec43e65..2200fc98cfac 100644 --- a/clang/test/CIR/CodeGen/goto.cpp +++ b/clang/test/CIR/CodeGen/goto.cpp @@ -310,3 +310,51 @@ extern "C" void multiple_non_case(int v) { // NOFLAT: cir.label // NOFLAT: cir.call @action2() // NOFLAT: cir.break + +extern "C" void case_follow_label(int v) { + switch (v) { + case 1: + label: + case 2: + action1(); + break; + default: + action2(); + goto label; + } +} + +// NOFLAT: cir.func @case_follow_label +// NOFLAT: cir.switch +// NOFLAT: case (equal, 1) +// NOFLAT: cir.label "label" +// NOFLAT: cir.yield +// NOFLAT: case (equal, 2) +// NOFLAT: cir.call @action1() +// NOFLAT: cir.break +// NOFLAT: case (default) +// NOFLAT: cir.call @action2() +// NOFLAT: cir.goto "label" + +extern "C" void default_follow_label(int v) { + switch (v) { + case 1: + case 2: + action1(); + break; + label: + default: + action2(); + goto label; + } +} + +// NOFLAT: cir.func @default_follow_label +// NOFLAT: cir.switch +// NOFLAT: case (anyof, [1, 2] : !s32i) +// NOFLAT: cir.call @action1() +// NOFLAT: cir.break +// NOFLAT: cir.label "label" +// NOFLAT: case (default) +// NOFLAT: cir.call @action2() +// NOFLAT: cir.goto "label"