Skip to content

Commit f5f94fd

Browse files
xedinhamishknight
authored andcommitted
[CSClosure] Introduce SyntacticElementContext
Replace `AnyFunctionRef` as "context" for syntactic element with a custom `SyntacticElementContext` to support type-checking of other constructs in the future.
1 parent f2404fe commit f5f94fd

File tree

1 file changed

+121
-49
lines changed

1 file changed

+121
-49
lines changed

lib/Sema/CSSyntacticElement.cpp

Lines changed: 121 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -352,22 +352,95 @@ ElementInfo makeElement(ASTNode node, ConstraintLocator *locator,
352352
return std::make_tuple(node, context, isDiscarded, locator);
353353
}
354354

355+
struct SyntacticElementContext
356+
: public llvm::PointerUnion<AbstractFunctionDecl *, AbstractClosureExpr *> {
357+
// Inherit the constructors from PointerUnion.
358+
using PointerUnion::PointerUnion;
359+
360+
static SyntacticElementContext forFunctionRef(AnyFunctionRef ref) {
361+
if (auto *decl = ref.getAbstractFunctionDecl()) {
362+
return {decl};
363+
}
364+
365+
return {ref.getAbstractClosureExpr()};
366+
}
367+
368+
static SyntacticElementContext forClosure(ClosureExpr *closure) {
369+
return {closure};
370+
}
371+
372+
static SyntacticElementContext forFunction(AbstractFunctionDecl *func) {
373+
return {func};
374+
}
375+
376+
DeclContext *getAsDeclContext() const {
377+
if (auto *fn = this->dyn_cast<AbstractFunctionDecl *>()) {
378+
return fn;
379+
} else if (auto *closure = this->dyn_cast<AbstractClosureExpr *>()) {
380+
return closure;
381+
} else {
382+
llvm_unreachable("unsupported kind");
383+
}
384+
}
385+
386+
NullablePtr<AbstractClosureExpr> getAsAbstractClosureExpr() const {
387+
return this->dyn_cast<AbstractClosureExpr *>();
388+
}
389+
390+
NullablePtr<AbstractFunctionDecl> getAsAbstractFunctionDecl() const {
391+
return this->dyn_cast<AbstractFunctionDecl *>();
392+
}
393+
394+
Optional<AnyFunctionRef> getAsAnyFunctionRef() const {
395+
if (auto *fn = this->dyn_cast<AbstractFunctionDecl *>()) {
396+
return {fn};
397+
} else if (auto *closure = this->dyn_cast<AbstractClosureExpr *>()) {
398+
return {closure};
399+
} else {
400+
return None;
401+
}
402+
}
403+
404+
BraceStmt *getBody() const {
405+
if (auto *fn = this->dyn_cast<AbstractFunctionDecl *>()) {
406+
return fn->getBody();
407+
} else if (auto *closure = this->dyn_cast<AbstractClosureExpr *>()) {
408+
return closure->getBody();
409+
} else {
410+
llvm_unreachable("unsupported kind");
411+
}
412+
}
413+
414+
bool isSingleExpressionClosure(ConstraintSystem &cs) {
415+
if (auto ref = getAsAnyFunctionRef()) {
416+
if (cs.getAppliedResultBuilderTransform(*ref))
417+
return false;
418+
419+
if (auto *closure = ref->getAbstractClosureExpr())
420+
return closure->hasSingleExpressionBody();
421+
}
422+
423+
return false;
424+
}
425+
};
426+
355427
/// Statement visitor that generates constraints for a given closure body.
356428
class SyntacticElementConstraintGenerator
357429
: public StmtVisitor<SyntacticElementConstraintGenerator, void> {
358430
friend StmtVisitor<SyntacticElementConstraintGenerator, void>;
359431

360432
ConstraintSystem &cs;
361-
AnyFunctionRef context;
433+
SyntacticElementContext context;
362434
ConstraintLocator *locator;
363435

364436
public:
365437
/// Whether an error was encountered while generating constraints.
366438
bool hadError = false;
367439

368-
SyntacticElementConstraintGenerator(ConstraintSystem &cs, AnyFunctionRef fn,
440+
SyntacticElementConstraintGenerator(ConstraintSystem &cs,
441+
SyntacticElementContext context,
369442
ConstraintLocator *locator)
370-
: cs(cs), context(fn), locator(locator) {}
443+
: cs(cs), context(context), locator(locator) {}
371444

372445
void visitPattern(Pattern *pattern, ContextualTypeInfo context) {
373446
auto parentElement =
@@ -608,7 +681,7 @@ class SyntacticElementConstraintGenerator
608681
}
609682

610683
void visitDecl(Decl *decl) {
611-
if (!isInSingleExpressionClosure()) {
684+
if (!context.isSingleExpressionClosure(cs)) {
612685
if (auto patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
613686
if (locator->isLastElement<LocatorPathElt::PatternBindingElement>())
614687
visitPatternBindingElement(patternBinding);
@@ -832,7 +905,7 @@ class SyntacticElementConstraintGenerator
832905
}
833906

834907
void visitBraceStmt(BraceStmt *braceStmt) {
835-
if (isInSingleExpressionClosure()) {
908+
if (context.isSingleExpressionClosure(cs)) {
836909
for (auto node : braceStmt->getElements()) {
837910
if (auto expr = node.dyn_cast<Expr *>()) {
838911
auto generatedExpr = cs.generateConstraints(
@@ -889,7 +962,7 @@ class SyntacticElementConstraintGenerator
889962
// so let's give them a special locator as to indicate that.
890963
// Return statements might not have a result if we have a closure whose
891964
// implicit returned value is coerced to Void.
892-
if (isInSingleExpressionClosure() && returnStmt->hasResult()) {
965+
if (context.isSingleExpressionClosure(cs) && returnStmt->hasResult()) {
893966
auto *expr = returnStmt->getResult();
894967
assert(expr && "single expression closure without expression?");
895968

@@ -904,7 +977,7 @@ class SyntacticElementConstraintGenerator
904977
cs.addConstraint(ConstraintKind::Conversion, cs.getType(expr),
905978
contextualResultInfo.getType(),
906979
cs.getConstraintLocator(
907-
context.getAbstractClosureExpr(),
980+
context.getAsAbstractClosureExpr().get(),
908981
LocatorPathElt::ClosureBody(
909982
/*hasReturn=*/!returnStmt->isImplicit())));
910983
return;
@@ -940,26 +1013,19 @@ class SyntacticElementConstraintGenerator
9401013
cs.setSolutionApplicationTarget(returnStmt, target);
9411014
}
9421015

943-
bool isInSingleExpressionClosure() {
944-
if (!isExpr<ClosureExpr>(context.getAbstractClosureExpr()))
945-
return false;
946-
947-
// Result builder transformed bodies are never single-expression.
948-
if (cs.getAppliedResultBuilderTransform(context))
949-
return false;
950-
951-
return context.hasSingleExpressionBody();
952-
}
953-
9541016
ContextualTypeInfo getContextualResultInfo() const {
955-
if (auto transform = cs.getAppliedResultBuilderTransform(context))
1017+
auto funcRef = context.getAsAnyFunctionRef();
1018+
if (!funcRef)
1019+
return {Type(), CTP_Unused};
1020+
1021+
if (auto transform = cs.getAppliedResultBuilderTransform(*funcRef))
9561022
return {transform->bodyResultType, CTP_ReturnStmt};
9571023

9581024
if (auto *closure =
959-
getAsExpr<ClosureExpr>(context.getAbstractClosureExpr()))
1025+
getAsExpr<ClosureExpr>(funcRef->getAbstractClosureExpr()))
9601026
return {cs.getClosureType(closure)->getResult(), CTP_ClosureResult};
9611027

962-
return {context.getBodyResultType(), CTP_ReturnStmt};
1028+
return {funcRef->getBodyResultType(), CTP_ReturnStmt};
9631029
}
9641030

9651031
#define UNSUPPORTED_STMT(STMT) void visit##STMT##Stmt(STMT##Stmt *) { \
@@ -1061,7 +1127,8 @@ bool ConstraintSystem::generateConstraints(ClosureExpr *closure) {
10611127

10621128
if (participatesInInference(closure)) {
10631129
SyntacticElementConstraintGenerator generator(
1064-
*this, closure, getConstraintLocator(closure));
1130+
*this, SyntacticElementContext::forClosure(closure),
1131+
getConstraintLocator(closure));
10651132

10661133
generator.visit(closure->getBody());
10671134

@@ -1097,7 +1164,8 @@ bool ConstraintSystem::generateConstraints(AnyFunctionRef fn, BraceStmt *body) {
10971164
locator = getConstraintLocator(fn.getAbstractClosureExpr());
10981165
}
10991166

1100-
SyntacticElementConstraintGenerator generator(*this, fn, locator.get());
1167+
SyntacticElementConstraintGenerator generator(
1168+
*this, SyntacticElementContext::forFunctionRef(fn), locator.get());
11011169

11021170
generator.visit(body);
11031171

@@ -1148,23 +1216,22 @@ ConstraintSystem::simplifySyntacticElementConstraint(
11481216
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
11491217
auto anchor = locator.getAnchor();
11501218

1151-
DeclContext *context;
1219+
Optional<SyntacticElementContext> context;
11521220
if (auto *closure = getAsExpr<ClosureExpr>(anchor)) {
1153-
context = closure;
1221+
context = SyntacticElementContext::forClosure(closure);
11541222
} else if (auto *fn = getAsDecl<AbstractFunctionDecl>(anchor)) {
1155-
context = fn;
1223+
context = SyntacticElementContext::forFunction(fn);
11561224
} else {
11571225
return SolutionKind::Error;
11581226
}
11591227

1160-
AnyFunctionRef fn = AnyFunctionRef::fromFunctionDeclContext(context);
1161-
1162-
SyntacticElementConstraintGenerator generator(*this, fn,
1228+
SyntacticElementConstraintGenerator generator(*this, *context,
11631229
getConstraintLocator(locator));
11641230

11651231
if (auto *expr = element.dyn_cast<Expr *>()) {
1166-
SolutionApplicationTarget target(expr, context, contextInfo.purpose,
1167-
contextInfo.getType(), isDiscarded);
1232+
SolutionApplicationTarget target(expr, context->getAsDeclContext(),
1233+
contextInfo.purpose, contextInfo.getType(),
1234+
isDiscarded);
11681235

11691236
if (generateConstraints(target, FreeTypeVariableBinding::Disallow))
11701237
return SolutionKind::Error;
@@ -1174,7 +1241,7 @@ ConstraintSystem::simplifySyntacticElementConstraint(
11741241
} else if (auto *stmt = element.dyn_cast<Stmt *>()) {
11751242
generator.visit(stmt);
11761243
} else if (auto *cond = element.dyn_cast<StmtConditionElement *>()) {
1177-
if (generateConstraints({*cond}, context))
1244+
if (generateConstraints({*cond}, context->getAsDeclContext()))
11781245
return SolutionKind::Error;
11791246
} else if (auto *pattern = element.dyn_cast<Pattern *>()) {
11801247
generator.visitPattern(pattern, contextInfo);
@@ -1199,10 +1266,9 @@ class SyntacticElementSolutionApplication
11991266

12001267
protected:
12011268
Solution &solution;
1202-
AnyFunctionRef context;
1269+
SyntacticElementContext context;
12031270
Type resultType;
12041271
RewriteTargetFn rewriteTarget;
1205-
bool isSingleExpression;
12061272

12071273
/// All `func`s declared in the body of the closure.
12081274
SmallVector<FuncDecl *, 4> LocalFuncs;
@@ -1212,11 +1278,11 @@ class SyntacticElementSolutionApplication
12121278
bool hadError = false;
12131279

12141280
SyntacticElementSolutionApplication(Solution &solution,
1215-
AnyFunctionRef context, Type resultType,
1281+
SyntacticElementContext context,
1282+
Type resultType,
12161283
RewriteTargetFn rewriteTarget)
12171284
: solution(solution), context(context), resultType(resultType),
1218-
rewriteTarget(rewriteTarget),
1219-
isSingleExpression(context.hasSingleExpressionBody()) {}
1285+
rewriteTarget(rewriteTarget) {}
12201286

12211287
virtual ~SyntacticElementSolutionApplication() {}
12221288

@@ -1566,9 +1632,9 @@ class SyntacticElementSolutionApplication
15661632
// of the body if there is none. This wasn't needed before SE-0326
15671633
// because result type was (incorrectly) inferred as `Void` due to
15681634
// the body being skipped.
1569-
auto *closure = context.getAbstractClosureExpr();
1570-
if (closure && !closure->hasSingleExpressionBody() &&
1571-
closure->getBody() == braceStmt) {
1635+
auto closure = context.getAsAbstractClosureExpr();
1636+
if (closure && !closure.get()->hasSingleExpressionBody() &&
1637+
closure.get()->getBody() == braceStmt) {
15721638
if (resultType->getOptionalObjectType() &&
15731639
resultType->lookThroughAllOptionalTypes()->isVoid() &&
15741640
!braceStmt->getLastElement().isStmt(StmtKind::Return)) {
@@ -1650,7 +1716,8 @@ class SyntacticElementSolutionApplication
16501716

16511717
// A single-expression closure with a Never expression type
16521718
// coerces to any other function type.
1653-
} else if (isSingleExpression && resultExprType->isUninhabited()) {
1719+
} else if (context.isSingleExpressionClosure(cs) &&
1720+
resultExprType->isUninhabited()) {
16541721
mode = coerceFromNever;
16551722

16561723
// Normal rule is to coerce to the return expression to the closure type.
@@ -1665,7 +1732,7 @@ class SyntacticElementSolutionApplication
16651732
// Single-expression closures have to handle returns in a special
16661733
// way so the target has to be created for them during solution
16671734
// application based on the resolved type.
1668-
assert(isSingleExpression);
1735+
assert(context.isSingleExpressionClosure(cs));
16691736
resultTarget = SolutionApplicationTarget(
16701737
resultExpr, context.getAsDeclContext(),
16711738
mode == convertToResult ? CTP_ClosureResult : CTP_Unused,
@@ -1737,7 +1804,8 @@ class ResultBuilderRewriter : public SyntacticElementSolutionApplication {
17371804
const AppliedBuilderTransform &transform,
17381805
RewriteTargetFn rewriteTarget)
17391806
: SyntacticElementSolutionApplication(
1740-
solution, context, transform.bodyResultType, rewriteTarget),
1807+
solution, SyntacticElementContext::forFunctionRef(context),
1808+
transform.bodyResultType, rewriteTarget),
17411809
Transform(transform) {}
17421810

17431811
bool apply() {
@@ -1746,11 +1814,14 @@ class ResultBuilderRewriter : public SyntacticElementSolutionApplication {
17461814
if (!body || hadError)
17471815
return true;
17481816

1749-
context.setTypecheckedBody(castToStmt<BraceStmt>(body),
1750-
/*hasSingleExpression=*/false);
1817+
auto funcRef = context.getAsAnyFunctionRef();
1818+
assert(funcRef);
1819+
1820+
funcRef->setTypecheckedBody(castToStmt<BraceStmt>(body),
1821+
/*hasSingleExpression=*/false);
17511822

17521823
if (auto *closure =
1753-
getAsExpr<ClosureExpr>(context.getAbstractClosureExpr()))
1824+
getAsExpr<ClosureExpr>(funcRef->getAbstractClosureExpr()))
17541825
solution.setExprTypes(closure);
17551826

17561827
return false;
@@ -2015,7 +2086,7 @@ SolutionApplicationToFunctionResult ConstraintSystem::applySolution(
20152086
DeclContext *&currentDC,
20162087
RewriteTargetFn rewriteTarget) {
20172088
auto &cs = solution.getConstraintSystem();
2018-
auto closure = dyn_cast_or_null<ClosureExpr>(fn.getAbstractClosureExpr());
2089+
auto *closure = getAsExpr<ClosureExpr>(fn.getAbstractClosureExpr());
20192090
FunctionType *closureFnType = nullptr;
20202091
if (closure) {
20212092
// Update the closure's type.
@@ -2128,8 +2199,9 @@ bool ConstraintSystem::applySolutionToBody(Solution &solution,
21282199
resultTy = fn.getBodyResultType();
21292200
}
21302201

2131-
SyntacticElementSolutionApplication application(solution, fn, resultTy,
2132-
rewriteTarget);
2202+
SyntacticElementSolutionApplication application(
2203+
solution, SyntacticElementContext::forFunctionRef(fn), resultTy,
2204+
rewriteTarget);
21332205

21342206
auto body = application.apply();
21352207

0 commit comments

Comments
 (0)