Skip to content

Commit 4b43573

Browse files
committed
[Constraint system] Implement switch support for function builders.
Implement support for switch statements within function builders. Cases can perform arbitrary pattern matches, e.g., tuplify(true) { c in "testSwitchCombined" switch e { case .a: "a" case .b(let i, _?), .b(let i, nil): i + 17 } } subject to the normal rules of switch statements. Cases within function builders cannot, however, include “fallthrough” statements, because those (like “break” and “continue”) are control flow. The translation of performed for `switch` statements is similar to that of `if` statements, using `buildEither(first:)` and `buildEither(second:)` on the function builder type. This is the bulk of switch support, tracked by rdar://problem/50426203.
1 parent 13b200d commit 4b43573

9 files changed

+500
-7
lines changed

lib/Sema/BuilderTransform.cpp

+150-3
Original file line numberDiff line numberDiff line change
@@ -622,12 +622,104 @@ class BuilderClosureVisitor
622622
DeclNameLoc(endLoc), /*implicit=*/true);
623623
}
624624

625+
VarDecl *visitSwitchStmt(SwitchStmt *switchStmt) {
626+
// Generate constraints for the subject expression, and capture its
627+
// type for use in matching the various patterns.
628+
Expr *subjectExpr = switchStmt->getSubjectExpr();
629+
if (cs) {
630+
// FIXME: Add contextual type purpose for switch subjects?
631+
SolutionApplicationTarget target(subjectExpr, dc, CTP_Unused, Type(),
632+
/*isDiscarded=*/false);
633+
if (cs->generateConstraints(target, FreeTypeVariableBinding::Disallow)) {
634+
hadError = true;
635+
return nullptr;
636+
}
637+
638+
cs->setSolutionApplicationTarget(switchStmt, target);
639+
subjectExpr = target.getAsExpr();
640+
assert(subjectExpr && "Must have a subject expression here");
641+
}
642+
643+
// Generate constraints and capture variables for all of the cases.
644+
SmallVector<std::pair<CaseStmt *, VarDecl *>, 4> capturedCaseVars;
645+
for (auto *caseStmt : switchStmt->getCases()) {
646+
if (auto capturedCaseVar = visitCaseStmt(caseStmt, subjectExpr)) {
647+
capturedCaseVars.push_back({caseStmt, capturedCaseVar});
648+
}
649+
}
650+
651+
if (!cs)
652+
return nullptr;
653+
654+
// Form the expressions that inject the result of each case into the
655+
// appropriate
656+
llvm::TinyPtrVector<Expr *> injectedCaseExprs;
657+
SmallVector<std::pair<Type, ConstraintLocator *>, 4> injectedCaseTerms;
658+
for (unsigned idx : indices(capturedCaseVars)) {
659+
auto caseStmt = capturedCaseVars[idx].first;
660+
auto caseVar = capturedCaseVars[idx].second;
661+
662+
// Build the expression that injects the case variable into appropriate
663+
// buildEither(first:)/buildEither(second:) chain.
664+
Expr *caseVarRef = buildVarRef(caseVar, caseStmt->getEndLoc());
665+
Expr *injectedCaseExpr = buildWrappedChainPayload(
666+
caseVarRef, idx, capturedCaseVars.size(), /*isOptional=*/false);
667+
668+
// Generate constraints for this injected case result.
669+
injectedCaseExpr = cs->generateConstraints(injectedCaseExpr, dc);
670+
if (!injectedCaseExpr) {
671+
hadError = true;
672+
return nullptr;
673+
}
674+
675+
// Record this injected case expression.
676+
injectedCaseExprs.push_back(injectedCaseExpr);
677+
678+
// Record the type and locator for this injected case expression, to be
679+
// used in the "join" constraint later.
680+
injectedCaseTerms.push_back(
681+
{ cs->getType(injectedCaseExpr)->getRValueType(),
682+
cs->getConstraintLocator(injectedCaseExpr) });
683+
}
684+
685+
// Form the type of the switch itself.
686+
// FIXME: Need a locator for the "switch" statement.
687+
Type resultType = cs->addJoinConstraint(nullptr, injectedCaseTerms);
688+
if (!resultType) {
689+
hadError = true;
690+
return nullptr;
691+
}
692+
693+
// Create a variable to capture the result of evaluating the switch.
694+
auto switchVar = buildVar(switchStmt->getStartLoc());
695+
cs->setType(switchVar, resultType);
696+
applied.capturedStmts.insert(
697+
{switchStmt, { switchVar, std::move(injectedCaseExprs) } });
698+
return switchVar;
699+
}
700+
701+
VarDecl *visitCaseStmt(CaseStmt *caseStmt, Expr *subjectExpr) {
702+
// If needed, generate constraints for everything in the case statement.
703+
if (cs) {
704+
auto locator = cs->getConstraintLocator(
705+
subjectExpr, LocatorPathElt::ContextualType());
706+
Type subjectType = cs->getType(subjectExpr);
707+
708+
if (cs->generateConstraints(caseStmt, dc, subjectType, locator)) {
709+
hadError = true;
710+
return nullptr;
711+
}
712+
}
713+
714+
// Translate the body.
715+
return visit(caseStmt->getBody());
716+
}
717+
625718
CONTROL_FLOW_STMT(Guard)
626719
CONTROL_FLOW_STMT(While)
627720
CONTROL_FLOW_STMT(DoCatch)
628721
CONTROL_FLOW_STMT(RepeatWhile)
629722
CONTROL_FLOW_STMT(ForEach)
630-
CONTROL_FLOW_STMT(Switch)
631723
CONTROL_FLOW_STMT(Case)
632724
CONTROL_FLOW_STMT(Catch)
633725
CONTROL_FLOW_STMT(Break)
@@ -996,6 +1088,63 @@ class BuilderClosureRewriter
9961088
return doStmt;
9971089
}
9981090

1091+
Stmt *visitSwitchStmt(SwitchStmt *switchStmt, FunctionBuilderTarget target) {
1092+
// Translate the subject expression.
1093+
ConstraintSystem &cs = solution.getConstraintSystem();
1094+
auto subjectTarget =
1095+
rewriteTarget(*cs.getSolutionApplicationTarget(switchStmt));
1096+
if (!subjectTarget)
1097+
return nullptr;
1098+
1099+
switchStmt->setSubjectExpr(subjectTarget->getAsExpr());
1100+
1101+
// Handle any declaration nodes within the case list first; we'll
1102+
// handle the cases in a second pass.
1103+
for (auto child : switchStmt->getRawCases()) {
1104+
if (auto decl = child.dyn_cast<Decl *>()) {
1105+
TypeChecker::typeCheckDecl(decl);
1106+
}
1107+
}
1108+
1109+
// Translate all of the cases.
1110+
assert(target.kind == FunctionBuilderTarget::TemporaryVar);
1111+
auto temporaryVar = target.captured.first;
1112+
unsigned caseIndex = 0;
1113+
for (auto caseStmt : switchStmt->getCases()) {
1114+
if (!visitCaseStmt(
1115+
caseStmt,
1116+
FunctionBuilderTarget::forAssign(
1117+
temporaryVar, {target.captured.second[caseIndex]})))
1118+
return nullptr;
1119+
1120+
++caseIndex;
1121+
}
1122+
1123+
return switchStmt;
1124+
}
1125+
1126+
Stmt *visitCaseStmt(CaseStmt *caseStmt, FunctionBuilderTarget target) {
1127+
// Translate the patterns and guard expressions for each case label item.
1128+
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
1129+
SolutionApplicationTarget caseLabelTarget(&caseLabelItem, dc);
1130+
if (!rewriteTarget(caseLabelTarget))
1131+
return nullptr;
1132+
}
1133+
1134+
// Transform the body of the case.
1135+
auto body = cast<BraceStmt>(caseStmt->getBody());
1136+
auto captured = takeCapturedStmt(body);
1137+
auto newInnerBody = cast<BraceStmt>(
1138+
visitBraceStmt(
1139+
body,
1140+
target,
1141+
FunctionBuilderTarget::forAssign(
1142+
captured.first, {captured.second.front()})));
1143+
caseStmt->setBody(newInnerBody);
1144+
1145+
return caseStmt;
1146+
}
1147+
9991148
#define UNHANDLED_FUNCTION_BUILDER_STMT(STMT) \
10001149
Stmt *visit##STMT##Stmt(STMT##Stmt *stmt, FunctionBuilderTarget target) { \
10011150
llvm_unreachable("Function builders do not allow statement of kind " \
@@ -1010,8 +1159,6 @@ class BuilderClosureRewriter
10101159
UNHANDLED_FUNCTION_BUILDER_STMT(DoCatch)
10111160
UNHANDLED_FUNCTION_BUILDER_STMT(RepeatWhile)
10121161
UNHANDLED_FUNCTION_BUILDER_STMT(ForEach)
1013-
UNHANDLED_FUNCTION_BUILDER_STMT(Switch)
1014-
UNHANDLED_FUNCTION_BUILDER_STMT(Case)
10151162
UNHANDLED_FUNCTION_BUILDER_STMT(Catch)
10161163
UNHANDLED_FUNCTION_BUILDER_STMT(Break)
10171164
UNHANDLED_FUNCTION_BUILDER_STMT(Continue)

lib/Sema/CSApply.cpp

+51
Original file line numberDiff line numberDiff line change
@@ -7474,6 +7474,45 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
74747474
}
74757475
}
74767476

7477+
return target;
7478+
} else if (auto caseLabelItem = target.getAsCaseLabelItem()) {
7479+
ConstraintSystem &cs = solution.getConstraintSystem();
7480+
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
7481+
7482+
// Figure out the pattern type.
7483+
Type patternType = solution.simplifyType(solution.getType(info.pattern));
7484+
patternType = patternType->reconstituteSugar(/*recursive=*/false);
7485+
7486+
// Coerce the pattern to its appropriate type.
7487+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
7488+
patternOptions |= TypeResolutionFlags::OverrideType;
7489+
auto contextualPattern =
7490+
ContextualPattern::forRawPattern(info.pattern,
7491+
target.getDeclContext());
7492+
if (auto coercedPattern = TypeChecker::coercePatternToType(
7493+
contextualPattern, patternType, patternOptions)) {
7494+
(*caseLabelItem)->setPattern(coercedPattern);
7495+
} else {
7496+
return None;
7497+
}
7498+
7499+
// If there is a guard expression, coerce that.
7500+
if (auto guardExpr = info.guardExpr) {
7501+
guardExpr = guardExpr->walk(*this);
7502+
if (!guardExpr)
7503+
return None;
7504+
7505+
// FIXME: Feels like we could leverage existing code more.
7506+
Type boolType = cs.getASTContext().getBoolDecl()->getDeclaredType();
7507+
guardExpr = solution.coerceToType(
7508+
guardExpr, boolType, cs.getConstraintLocator(info.guardExpr));
7509+
if (!guardExpr)
7510+
return None;
7511+
7512+
(*caseLabelItem)->setGuardExpr(guardExpr);
7513+
solution.setExprTypes(guardExpr);
7514+
}
7515+
74777516
return target;
74787517
} else {
74797518
auto fn = *target.getAsFunction();
@@ -7747,5 +7786,17 @@ SolutionApplicationTarget SolutionApplicationTarget::walk(ASTWalker &walker) {
77477786
condElement = *condElement.walk(walker);
77487787
}
77497788
return *this;
7789+
7790+
case Kind::caseLabelItem:
7791+
if (auto newPattern =
7792+
caseLabelItem.caseLabelItem->getPattern()->walk(walker)) {
7793+
caseLabelItem.caseLabelItem->setPattern(newPattern);
7794+
}
7795+
if (auto guardExpr = caseLabelItem.caseLabelItem->getGuardExpr()) {
7796+
if (auto newGuardExpr = guardExpr->walk(walker))
7797+
caseLabelItem.caseLabelItem->setGuardExpr(newGuardExpr);
7798+
}
7799+
7800+
return *this;
77507801
}
77517802
}

lib/Sema/CSGen.cpp

+56
Original file line numberDiff line numberDiff line change
@@ -4223,6 +4223,62 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition,
42234223
return false;
42244224
}
42254225

4226+
bool ConstraintSystem::generateConstraints(
4227+
CaseStmt *caseStmt, DeclContext *dc, Type subjectType,
4228+
ConstraintLocator *locator) {
4229+
// Pre-bind all of the pattern variables within the case.
4230+
bindSwitchCasePatternVars(caseStmt);
4231+
4232+
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
4233+
// Resolve the pattern.
4234+
auto *pattern = TypeChecker::resolvePattern(
4235+
caseLabelItem.getPattern(), dc, /*isStmtCondition=*/false);
4236+
if (!pattern)
4237+
return true;
4238+
4239+
// Generate constraints for the pattern, including one-way bindings for
4240+
// any variables that show up in this pattern, because those variables
4241+
// can be referenced in the guard expressions and the body.
4242+
Type patternType = generateConstraints(
4243+
pattern, locator, /* bindPatternVarsOneWay=*/true);
4244+
4245+
// Convert the subject type to the pattern, which establishes the
4246+
// bindings.
4247+
addConstraint(
4248+
ConstraintKind::Conversion, subjectType, patternType, locator);
4249+
4250+
// Generate constraints for the guard expression, if there is one.
4251+
Expr *guardExpr = caseLabelItem.getGuardExpr();
4252+
if (guardExpr) {
4253+
guardExpr = generateConstraints(guardExpr, dc);
4254+
if (!guardExpr)
4255+
return true;
4256+
}
4257+
4258+
// Save this info.
4259+
setCaseLabelItemInfo(&caseLabelItem, {pattern, guardExpr});
4260+
4261+
// For any pattern variable that has a parent variable (i.e., another
4262+
// pattern variable with the same name in the same case), require that
4263+
// the types be equivalent.
4264+
pattern->forEachVariable([&](VarDecl *var) {
4265+
if (auto parentVar = var->getParentVarDecl()) {
4266+
addConstraint(
4267+
ConstraintKind::Equal, getType(parentVar), getType(var), locator);
4268+
}
4269+
});
4270+
}
4271+
4272+
// Bind the types of the case body variables.
4273+
for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
4274+
auto parentVar = caseBodyVar->getParentVarDecl();
4275+
assert(parentVar && "Case body variables always have parents");
4276+
setType(caseBodyVar, getType(parentVar));
4277+
}
4278+
4279+
return false;
4280+
}
4281+
42264282
void ConstraintSystem::optimizeConstraints(Expr *e) {
42274283
if (getASTContext().TypeCheckerOpts.DisableConstraintSolverPerformanceHacks)
42284284
return;

lib/Sema/CSSolver.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ Solution ConstraintSystem::finalize() {
173173
contextualTypes.begin(), contextualTypes.end());
174174

175175
solution.solutionApplicationTargets = solutionApplicationTargets;
176+
solution.caseLabelItems = caseLabelItems;
176177

177178
for (auto &e : CheckedConformances)
178179
solution.Conformances.push_back({e.first, e.second});
@@ -250,6 +251,11 @@ void ConstraintSystem::applySolution(const Solution &solution) {
250251
if (!getSolutionApplicationTarget(target.first))
251252
setSolutionApplicationTarget(target.first, target.second);
252253
}
254+
255+
// Register the statement condition targets.
256+
for (const auto &info : solution.caseLabelItems) {
257+
if (!getCaseLabelItemInfo(info.first))
258+
setCaseLabelItemInfo(info.first, info.second);
253259
}
254260

255261
// Register the conformances checked along the way to arrive to solution.
@@ -357,6 +363,13 @@ void truncate(llvm::MapVector<K, V> &map, unsigned newSize) {
357363
map.pop_back();
358364
}
359365

366+
template <typename K, typename V, unsigned N>
367+
void truncate(llvm::SmallMapVector<K, V, N> &map, unsigned newSize) {
368+
assert(newSize <= map.size() && "Not a truncation!");
369+
for (unsigned i = 0, n = map.size() - newSize; i != n; ++i)
370+
map.pop_back();
371+
}
372+
360373
} // end anonymous namespace
361374

362375
ConstraintSystem::SolverState::SolverState(
@@ -465,6 +478,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
465478
numInferredClosureTypes = cs.ClosureTypes.size();
466479
numContextualTypes = cs.contextualTypes.size();
467480
numSolutionApplicationTargets = cs.solutionApplicationTargets.size();
481+
numCaseLabelItems = cs.caseLabelItems.size();
468482

469483
PreviousScore = cs.CurrentScore;
470484

@@ -545,6 +559,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
545559
// Remove any solution application targets.
546560
truncate(cs.solutionApplicationTargets, numSolutionApplicationTargets);
547561

562+
// Remove any case label item infos.
563+
truncate(cs.caseLabelItems, numCaseLabelItems);
564+
548565
// Reset the previous score.
549566
cs.CurrentScore = PreviousScore;
550567

0 commit comments

Comments
 (0)