Skip to content

Commit fc954d2

Browse files
authored
Merge pull request #62096 from hamishknight/multifunction
2 parents 48987de + f5f94fd commit fc954d2

File tree

6 files changed

+264
-162
lines changed

6 files changed

+264
-162
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,9 @@ class BindingSet {
480480
/// if it has only concrete types or would resolve a closure.
481481
bool favoredOverDisjunction(Constraint *disjunction) const;
482482

483+
/// Check if this binding is favored over a conjunction.
484+
bool favoredOverConjunction(Constraint *conjunction) const;
485+
483486
/// Detect `subtype` relationship between two type variables and
484487
/// attempt to infer supertype bindings transitively e.g.
485488
///

lib/Sema/BuilderTransform.cpp

Lines changed: 66 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,73 +1000,85 @@ class ResultBuilderTransform
10001000
return failTransform(stmt); \
10011001
}
10021002

1003-
std::pair<NullablePtr<VarDecl>, Optional<UnsupportedElt>>
1004-
transform(BraceStmt *braceStmt, SmallVectorImpl<ASTNode> &newBody) {
1005-
SmallVector<Expr *, 4> buildBlockArguments;
1006-
1007-
auto failTransform = [&](UnsupportedElt unsupported) {
1008-
return std::make_pair(nullptr, unsupported);
1009-
};
1010-
1011-
for (auto element : braceStmt->getElements()) {
1012-
if (auto *returnStmt = getAsStmt<ReturnStmt>(element)) {
1013-
assert(returnStmt->isImplicit());
1014-
element = returnStmt->getResult();
1015-
}
1016-
1017-
if (auto *decl = element.dyn_cast<Decl *>()) {
1018-
switch (decl->getKind()) {
1003+
/// Visit the element of a brace statement, returning \c None if the element
1004+
/// was transformed successfully, or an unsupported element if the element
1005+
/// cannot be handled.
1006+
Optional<UnsupportedElt>
1007+
transformBraceElement(ASTNode element, SmallVectorImpl<ASTNode> &newBody,
1008+
SmallVectorImpl<Expr *> &buildBlockArguments) {
1009+
if (auto *returnStmt = getAsStmt<ReturnStmt>(element)) {
1010+
assert(returnStmt->isImplicit());
1011+
element = returnStmt->getResult();
1012+
}
1013+
1014+
if (auto *decl = element.dyn_cast<Decl *>()) {
1015+
switch (decl->getKind()) {
10191016
// Just ignore #if; the chosen children should appear in
10201017
// the surrounding context. This isn't good for source
10211018
// tools but it at least works.
1022-
case DeclKind::IfConfig:
1019+
case DeclKind::IfConfig:
10231020
// Skip #warning/#error; we'll handle them when applying
10241021
// the builder.
1025-
case DeclKind::PoundDiagnostic:
1026-
case DeclKind::PatternBinding:
1027-
case DeclKind::Var:
1028-
case DeclKind::Param:
1029-
newBody.push_back(element);
1030-
break;
1031-
1032-
default:
1033-
return failTransform(decl);
1034-
}
1022+
case DeclKind::PoundDiagnostic:
1023+
case DeclKind::PatternBinding:
1024+
case DeclKind::Var:
1025+
case DeclKind::Param:
1026+
newBody.push_back(element);
1027+
return None;
1028+
1029+
default:
1030+
return UnsupportedElt(decl);
1031+
}
1032+
llvm_unreachable("Unhandled case in switch!");
1033+
}
10351034

1036-
continue;
1035+
if (auto *stmt = element.dyn_cast<Stmt *>()) {
1036+
// Throw is allowed as is.
1037+
if (auto *throwStmt = dyn_cast<ThrowStmt>(stmt)) {
1038+
newBody.push_back(throwStmt);
1039+
return None;
10371040
}
10381041

1039-
if (auto *stmt = element.dyn_cast<Stmt *>()) {
1040-
// Throw is allowed as is.
1041-
if (auto *throwStmt = dyn_cast<ThrowStmt>(stmt)) {
1042-
newBody.push_back(throwStmt);
1043-
continue;
1044-
}
1042+
// Allocate variable with a placeholder type
1043+
auto *resultVar = buildPlaceholderVar(stmt->getStartLoc(), newBody);
10451044

1046-
// Allocate variable with a placeholder type
1047-
auto *resultVar = buildPlaceholderVar(stmt->getStartLoc(), newBody);
1045+
auto result = visit(stmt, resultVar);
1046+
if (!result)
1047+
return UnsupportedElt(stmt);
10481048

1049-
auto result = visit(stmt, resultVar);
1050-
if (!result)
1051-
return failTransform(stmt);
1049+
newBody.push_back(result.get());
1050+
buildBlockArguments.push_back(
1051+
builder.buildVarRef(resultVar, stmt->getStartLoc()));
1052+
return None;
1053+
}
10521054

1053-
newBody.push_back(result.get());
1054-
buildBlockArguments.push_back(
1055-
builder.buildVarRef(resultVar, stmt->getStartLoc()));
1056-
continue;
1057-
}
1055+
auto *expr = element.get<Expr *>();
1056+
if (builder.supports(ctx.Id_buildExpression)) {
1057+
expr = builder.buildCall(expr->getLoc(), ctx.Id_buildExpression, {expr},
1058+
{Identifier()});
1059+
}
10581060

1059-
auto *expr = element.get<Expr *>();
1060-
if (builder.supports(ctx.Id_buildExpression)) {
1061-
expr = builder.buildCall(expr->getLoc(), ctx.Id_buildExpression, {expr},
1062-
{Identifier()});
1063-
}
1061+
auto *capture = captureExpr(expr, newBody);
1062+
// A reference to the synthesized variable is passed as an argument
1063+
// to buildBlock.
1064+
buildBlockArguments.push_back(
1065+
builder.buildVarRef(capture, element.getStartLoc()));
1066+
return None;
1067+
}
10641068

1065-
auto *capture = captureExpr(expr, newBody);
1066-
// A reference to the synthesized variable is passed as an argument
1067-
// to buildBlock.
1068-
buildBlockArguments.push_back(
1069-
builder.buildVarRef(capture, element.getStartLoc()));
1069+
std::pair<NullablePtr<VarDecl>, Optional<UnsupportedElt>>
1070+
transform(BraceStmt *braceStmt, SmallVectorImpl<ASTNode> &newBody) {
1071+
SmallVector<Expr *, 4> buildBlockArguments;
1072+
1073+
auto failTransform = [&](UnsupportedElt unsupported) {
1074+
return std::make_pair(nullptr, unsupported);
1075+
};
1076+
1077+
for (auto element : braceStmt->getElements()) {
1078+
if (auto unsupported =
1079+
transformBraceElement(element, newBody, buildBlockArguments)) {
1080+
return failTransform(*unsupported);
1081+
}
10701082
}
10711083

10721084
// Synthesize `buildBlock` or `buildPartial` based on captured arguments.

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ add_swift_host_library(swiftSema STATIC
33
BuilderTransform.cpp
44
CSApply.cpp
55
CSBindings.cpp
6-
CSClosure.cpp
6+
CSSyntacticElement.cpp
77
CSGen.cpp
88
CSRanking.cpp
99
CSSimplify.cpp

lib/Sema/CSBindings.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,14 @@ bool BindingSet::favoredOverDisjunction(Constraint *disjunction) const {
10651065
return !involvesTypeVariables();
10661066
}
10671067

1068+
bool BindingSet::favoredOverConjunction(Constraint *conjunction) const {
1069+
if (CS.shouldAttemptFixes() && isHole()) {
1070+
if (forClosureResult() || forGenericParameter())
1071+
return false;
1072+
}
1073+
return true;
1074+
}
1075+
10681076
BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar,
10691077
bool finalize) {
10701078
assert(typeVar->getImpl().getRepresentative(nullptr) == typeVar &&

lib/Sema/CSStep.cpp

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ StepResult ComponentStep::take(bool prevFailed) {
359359
});
360360

361361
auto *disjunction = CS.selectDisjunction();
362+
auto *conjunction = CS.selectConjunction();
362363

363364
if (CS.isDebugMode()) {
364365
if (!potentialBindings.empty()) {
@@ -393,33 +394,45 @@ StepResult ComponentStep::take(bool prevFailed) {
393394
}
394395
}
395396

396-
if (CS.shouldAttemptFixes()) {
397-
if ((bestBindings &&
398-
(bestBindings->forClosureResult() ||
399-
bestBindings->forGenericParameter()) &&
400-
bestBindings->isHole()) &&
401-
!disjunction) {
402-
if (auto *conjunction = CS.selectConjunction()) {
403-
return suspend(
404-
std::make_unique<ConjunctionStep>(CS, conjunction, Solutions));
405-
}
397+
enum class StepKind { Binding, Disjunction, Conjunction };
398+
399+
auto chooseStep = [&]() -> Optional<StepKind> {
400+
// Bindings usually happen first, but sometimes we want to prioritize a
401+
// disjunction or conjunction.
402+
if (bestBindings) {
403+
if (disjunction && !bestBindings->favoredOverDisjunction(disjunction))
404+
return StepKind::Disjunction;
405+
406+
if (conjunction && !bestBindings->favoredOverConjunction(conjunction))
407+
return StepKind::Conjunction;
408+
409+
return StepKind::Binding;
410+
}
411+
if (disjunction)
412+
return StepKind::Disjunction;
413+
414+
if (conjunction)
415+
return StepKind::Conjunction;
416+
417+
return None;
418+
};
419+
420+
if (auto step = chooseStep()) {
421+
switch (*step) {
422+
case StepKind::Binding:
423+
return suspend(
424+
std::make_unique<TypeVariableStep>(*bestBindings, Solutions));
425+
case StepKind::Disjunction:
426+
return suspend(
427+
std::make_unique<DisjunctionStep>(CS, disjunction, Solutions));
428+
case StepKind::Conjunction:
429+
return suspend(
430+
std::make_unique<ConjunctionStep>(CS, conjunction, Solutions));
406431
}
432+
llvm_unreachable("Unhandled case in switch!");
407433
}
408434

409-
if (bestBindings &&
410-
(!disjunction || bestBindings->favoredOverDisjunction(disjunction))) {
411-
// Produce a type variable step.
412-
return suspend(
413-
std::make_unique<TypeVariableStep>(*bestBindings, Solutions));
414-
} else if (disjunction) {
415-
// Produce a disjunction step.
416-
return suspend(
417-
std::make_unique<DisjunctionStep>(CS, disjunction, Solutions));
418-
} else if (auto *conjunction = CS.selectConjunction()) {
419-
return suspend(
420-
std::make_unique<ConjunctionStep>(CS, conjunction, Solutions));
421-
} else if (!CS.solverState->allowsFreeTypeVariables() &&
422-
CS.hasFreeTypeVariables()) {
435+
if (!CS.solverState->allowsFreeTypeVariables() && CS.hasFreeTypeVariables()) {
423436
// If there are no disjunctions or type variables to bind
424437
// we can't solve this system unless we have free type variables
425438
// allowed in the solution.

0 commit comments

Comments
 (0)