@@ -352,22 +352,95 @@ ElementInfo makeElement(ASTNode node, ConstraintLocator *locator,
352
352
return std::make_tuple (node, context, isDiscarded, locator);
353
353
}
354
354
355
+ struct SyntacticElementContext
356
+ : public llvm::PointerUnion<AbstractFunctionDecl *, AbstractClosureExpr *> {
357
+ // Inherit the constructors from PointerUnion.
358
+ using PointerUnion::PointerUnion;
359
+
360
+ static SyntacticElementContext forFunctionRef (AnyFunctionRef ref) {
361
+ if (auto *decl = ref.getAbstractFunctionDecl ()) {
362
+ return {decl};
363
+ }
364
+
365
+ return {ref.getAbstractClosureExpr ()};
366
+ }
367
+
368
+ static SyntacticElementContext forClosure (ClosureExpr *closure) {
369
+ return {closure};
370
+ }
371
+
372
+ static SyntacticElementContext forFunction (AbstractFunctionDecl *func) {
373
+ return {func};
374
+ }
375
+
376
+ DeclContext *getAsDeclContext () const {
377
+ if (auto *fn = this ->dyn_cast <AbstractFunctionDecl *>()) {
378
+ return fn;
379
+ } else if (auto *closure = this ->dyn_cast <AbstractClosureExpr *>()) {
380
+ return closure;
381
+ } else {
382
+ llvm_unreachable (" unsupported kind" );
383
+ }
384
+ }
385
+
386
+ NullablePtr<AbstractClosureExpr> getAsAbstractClosureExpr () const {
387
+ return this ->dyn_cast <AbstractClosureExpr *>();
388
+ }
389
+
390
+ NullablePtr<AbstractFunctionDecl> getAsAbstractFunctionDecl () const {
391
+ return this ->dyn_cast <AbstractFunctionDecl *>();
392
+ }
393
+
394
+ Optional<AnyFunctionRef> getAsAnyFunctionRef () const {
395
+ if (auto *fn = this ->dyn_cast <AbstractFunctionDecl *>()) {
396
+ return {fn};
397
+ } else if (auto *closure = this ->dyn_cast <AbstractClosureExpr *>()) {
398
+ return {closure};
399
+ } else {
400
+ return None;
401
+ }
402
+ }
403
+
404
+ BraceStmt *getBody () const {
405
+ if (auto *fn = this ->dyn_cast <AbstractFunctionDecl *>()) {
406
+ return fn->getBody ();
407
+ } else if (auto *closure = this ->dyn_cast <AbstractClosureExpr *>()) {
408
+ return closure->getBody ();
409
+ } else {
410
+ llvm_unreachable (" unsupported kind" );
411
+ }
412
+ }
413
+
414
+ bool isSingleExpressionClosure (ConstraintSystem &cs) {
415
+ if (auto ref = getAsAnyFunctionRef ()) {
416
+ if (cs.getAppliedResultBuilderTransform (*ref))
417
+ return false ;
418
+
419
+ if (auto *closure = ref->getAbstractClosureExpr ())
420
+ return closure->hasSingleExpressionBody ();
421
+ }
422
+
423
+ return false ;
424
+ }
425
+ };
426
+
355
427
// / Statement visitor that generates constraints for a given closure body.
356
428
class SyntacticElementConstraintGenerator
357
429
: public StmtVisitor<SyntacticElementConstraintGenerator, void > {
358
430
friend StmtVisitor<SyntacticElementConstraintGenerator, void >;
359
431
360
432
ConstraintSystem &cs;
361
- AnyFunctionRef context;
433
+ SyntacticElementContext context;
362
434
ConstraintLocator *locator;
363
435
364
436
public:
365
437
// / Whether an error was encountered while generating constraints.
366
438
bool hadError = false ;
367
439
368
- SyntacticElementConstraintGenerator (ConstraintSystem &cs, AnyFunctionRef fn,
440
+ SyntacticElementConstraintGenerator (ConstraintSystem &cs,
441
+ SyntacticElementContext context,
369
442
ConstraintLocator *locator)
370
- : cs(cs), context(fn ), locator(locator) {}
443
+ : cs(cs), context(context ), locator(locator) {}
371
444
372
445
void visitPattern (Pattern *pattern, ContextualTypeInfo context) {
373
446
auto parentElement =
@@ -608,7 +681,7 @@ class SyntacticElementConstraintGenerator
608
681
}
609
682
610
683
void visitDecl (Decl *decl) {
611
- if (!isInSingleExpressionClosure ( )) {
684
+ if (!context. isSingleExpressionClosure (cs )) {
612
685
if (auto patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
613
686
if (locator->isLastElement <LocatorPathElt::PatternBindingElement>())
614
687
visitPatternBindingElement (patternBinding);
@@ -832,7 +905,7 @@ class SyntacticElementConstraintGenerator
832
905
}
833
906
834
907
void visitBraceStmt (BraceStmt *braceStmt) {
835
- if (isInSingleExpressionClosure ( )) {
908
+ if (context. isSingleExpressionClosure (cs )) {
836
909
for (auto node : braceStmt->getElements ()) {
837
910
if (auto expr = node.dyn_cast <Expr *>()) {
838
911
auto generatedExpr = cs.generateConstraints (
@@ -889,7 +962,7 @@ class SyntacticElementConstraintGenerator
889
962
// so let's give them a special locator as to indicate that.
890
963
// Return statements might not have a result if we have a closure whose
891
964
// implicit returned value is coerced to Void.
892
- if (isInSingleExpressionClosure ( ) && returnStmt->hasResult ()) {
965
+ if (context. isSingleExpressionClosure (cs ) && returnStmt->hasResult ()) {
893
966
auto *expr = returnStmt->getResult ();
894
967
assert (expr && " single expression closure without expression?" );
895
968
@@ -904,7 +977,7 @@ class SyntacticElementConstraintGenerator
904
977
cs.addConstraint (ConstraintKind::Conversion, cs.getType (expr),
905
978
contextualResultInfo.getType (),
906
979
cs.getConstraintLocator (
907
- context.getAbstractClosureExpr (),
980
+ context.getAsAbstractClosureExpr (). get (),
908
981
LocatorPathElt::ClosureBody (
909
982
/* hasReturn=*/ !returnStmt->isImplicit ())));
910
983
return ;
@@ -940,26 +1013,19 @@ class SyntacticElementConstraintGenerator
940
1013
cs.setSolutionApplicationTarget (returnStmt, target);
941
1014
}
942
1015
943
- bool isInSingleExpressionClosure () {
944
- if (!isExpr<ClosureExpr>(context.getAbstractClosureExpr ()))
945
- return false ;
946
-
947
- // Result builder transformed bodies are never single-expression.
948
- if (cs.getAppliedResultBuilderTransform (context))
949
- return false ;
950
-
951
- return context.hasSingleExpressionBody ();
952
- }
953
-
954
1016
ContextualTypeInfo getContextualResultInfo () const {
955
- if (auto transform = cs.getAppliedResultBuilderTransform (context))
1017
+ auto funcRef = context.getAsAnyFunctionRef ();
1018
+ if (!funcRef)
1019
+ return {Type (), CTP_Unused};
1020
+
1021
+ if (auto transform = cs.getAppliedResultBuilderTransform (*funcRef))
956
1022
return {transform->bodyResultType , CTP_ReturnStmt};
957
1023
958
1024
if (auto *closure =
959
- getAsExpr<ClosureExpr>(context. getAbstractClosureExpr ()))
1025
+ getAsExpr<ClosureExpr>(funcRef-> getAbstractClosureExpr ()))
960
1026
return {cs.getClosureType (closure)->getResult (), CTP_ClosureResult};
961
1027
962
- return {context. getBodyResultType (), CTP_ReturnStmt};
1028
+ return {funcRef-> getBodyResultType (), CTP_ReturnStmt};
963
1029
}
964
1030
965
1031
#define UNSUPPORTED_STMT (STMT ) void visit##STMT##Stmt(STMT##Stmt *) { \
@@ -1061,7 +1127,8 @@ bool ConstraintSystem::generateConstraints(ClosureExpr *closure) {
1061
1127
1062
1128
if (participatesInInference (closure)) {
1063
1129
SyntacticElementConstraintGenerator generator (
1064
- *this , closure, getConstraintLocator (closure));
1130
+ *this , SyntacticElementContext::forClosure (closure),
1131
+ getConstraintLocator (closure));
1065
1132
1066
1133
generator.visit (closure->getBody ());
1067
1134
@@ -1097,7 +1164,8 @@ bool ConstraintSystem::generateConstraints(AnyFunctionRef fn, BraceStmt *body) {
1097
1164
locator = getConstraintLocator (fn.getAbstractClosureExpr ());
1098
1165
}
1099
1166
1100
- SyntacticElementConstraintGenerator generator (*this , fn, locator.get ());
1167
+ SyntacticElementConstraintGenerator generator (
1168
+ *this , SyntacticElementContext::forFunctionRef (fn), locator.get ());
1101
1169
1102
1170
generator.visit (body);
1103
1171
@@ -1148,23 +1216,22 @@ ConstraintSystem::simplifySyntacticElementConstraint(
1148
1216
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
1149
1217
auto anchor = locator.getAnchor ();
1150
1218
1151
- DeclContext * context;
1219
+ Optional<SyntacticElementContext> context;
1152
1220
if (auto *closure = getAsExpr<ClosureExpr>(anchor)) {
1153
- context = closure;
1221
+ context = SyntacticElementContext::forClosure ( closure) ;
1154
1222
} else if (auto *fn = getAsDecl<AbstractFunctionDecl>(anchor)) {
1155
- context = fn ;
1223
+ context = SyntacticElementContext::forFunction (fn) ;
1156
1224
} else {
1157
1225
return SolutionKind::Error;
1158
1226
}
1159
1227
1160
- AnyFunctionRef fn = AnyFunctionRef::fromFunctionDeclContext (context);
1161
-
1162
- SyntacticElementConstraintGenerator generator (*this , fn,
1228
+ SyntacticElementConstraintGenerator generator (*this , *context,
1163
1229
getConstraintLocator (locator));
1164
1230
1165
1231
if (auto *expr = element.dyn_cast <Expr *>()) {
1166
- SolutionApplicationTarget target (expr, context, contextInfo.purpose ,
1167
- contextInfo.getType (), isDiscarded);
1232
+ SolutionApplicationTarget target (expr, context->getAsDeclContext (),
1233
+ contextInfo.purpose , contextInfo.getType (),
1234
+ isDiscarded);
1168
1235
1169
1236
if (generateConstraints (target, FreeTypeVariableBinding::Disallow))
1170
1237
return SolutionKind::Error;
@@ -1174,7 +1241,7 @@ ConstraintSystem::simplifySyntacticElementConstraint(
1174
1241
} else if (auto *stmt = element.dyn_cast <Stmt *>()) {
1175
1242
generator.visit (stmt);
1176
1243
} else if (auto *cond = element.dyn_cast <StmtConditionElement *>()) {
1177
- if (generateConstraints ({*cond}, context))
1244
+ if (generateConstraints ({*cond}, context-> getAsDeclContext () ))
1178
1245
return SolutionKind::Error;
1179
1246
} else if (auto *pattern = element.dyn_cast <Pattern *>()) {
1180
1247
generator.visitPattern (pattern, contextInfo);
@@ -1199,10 +1266,9 @@ class SyntacticElementSolutionApplication
1199
1266
1200
1267
protected:
1201
1268
Solution &solution;
1202
- AnyFunctionRef context;
1269
+ SyntacticElementContext context;
1203
1270
Type resultType;
1204
1271
RewriteTargetFn rewriteTarget;
1205
- bool isSingleExpression;
1206
1272
1207
1273
// / All `func`s declared in the body of the closure.
1208
1274
SmallVector<FuncDecl *, 4 > LocalFuncs;
@@ -1212,11 +1278,11 @@ class SyntacticElementSolutionApplication
1212
1278
bool hadError = false ;
1213
1279
1214
1280
SyntacticElementSolutionApplication (Solution &solution,
1215
- AnyFunctionRef context, Type resultType,
1281
+ SyntacticElementContext context,
1282
+ Type resultType,
1216
1283
RewriteTargetFn rewriteTarget)
1217
1284
: solution(solution), context(context), resultType(resultType),
1218
- rewriteTarget (rewriteTarget),
1219
- isSingleExpression(context.hasSingleExpressionBody()) {}
1285
+ rewriteTarget (rewriteTarget) {}
1220
1286
1221
1287
virtual ~SyntacticElementSolutionApplication () {}
1222
1288
@@ -1566,9 +1632,9 @@ class SyntacticElementSolutionApplication
1566
1632
// of the body if there is none. This wasn't needed before SE-0326
1567
1633
// because result type was (incorrectly) inferred as `Void` due to
1568
1634
// the body being skipped.
1569
- auto * closure = context.getAbstractClosureExpr ();
1570
- if (closure && !closure->hasSingleExpressionBody () &&
1571
- closure->getBody () == braceStmt) {
1635
+ auto closure = context.getAsAbstractClosureExpr ();
1636
+ if (closure && !closure. get () ->hasSingleExpressionBody () &&
1637
+ closure. get () ->getBody () == braceStmt) {
1572
1638
if (resultType->getOptionalObjectType () &&
1573
1639
resultType->lookThroughAllOptionalTypes ()->isVoid () &&
1574
1640
!braceStmt->getLastElement ().isStmt (StmtKind::Return)) {
@@ -1650,7 +1716,8 @@ class SyntacticElementSolutionApplication
1650
1716
1651
1717
// A single-expression closure with a Never expression type
1652
1718
// coerces to any other function type.
1653
- } else if (isSingleExpression && resultExprType->isUninhabited ()) {
1719
+ } else if (context.isSingleExpressionClosure (cs) &&
1720
+ resultExprType->isUninhabited ()) {
1654
1721
mode = coerceFromNever;
1655
1722
1656
1723
// Normal rule is to coerce to the return expression to the closure type.
@@ -1665,7 +1732,7 @@ class SyntacticElementSolutionApplication
1665
1732
// Single-expression closures have to handle returns in a special
1666
1733
// way so the target has to be created for them during solution
1667
1734
// application based on the resolved type.
1668
- assert (isSingleExpression );
1735
+ assert (context. isSingleExpressionClosure (cs) );
1669
1736
resultTarget = SolutionApplicationTarget (
1670
1737
resultExpr, context.getAsDeclContext (),
1671
1738
mode == convertToResult ? CTP_ClosureResult : CTP_Unused,
@@ -1737,7 +1804,8 @@ class ResultBuilderRewriter : public SyntacticElementSolutionApplication {
1737
1804
const AppliedBuilderTransform &transform,
1738
1805
RewriteTargetFn rewriteTarget)
1739
1806
: SyntacticElementSolutionApplication(
1740
- solution, context, transform.bodyResultType, rewriteTarget),
1807
+ solution, SyntacticElementContext::forFunctionRef(context),
1808
+ transform.bodyResultType, rewriteTarget),
1741
1809
Transform (transform) {}
1742
1810
1743
1811
bool apply () {
@@ -1746,11 +1814,14 @@ class ResultBuilderRewriter : public SyntacticElementSolutionApplication {
1746
1814
if (!body || hadError)
1747
1815
return true ;
1748
1816
1749
- context.setTypecheckedBody (castToStmt<BraceStmt>(body),
1750
- /* hasSingleExpression=*/ false );
1817
+ auto funcRef = context.getAsAnyFunctionRef ();
1818
+ assert (funcRef);
1819
+
1820
+ funcRef->setTypecheckedBody (castToStmt<BraceStmt>(body),
1821
+ /* hasSingleExpression=*/ false );
1751
1822
1752
1823
if (auto *closure =
1753
- getAsExpr<ClosureExpr>(context. getAbstractClosureExpr ()))
1824
+ getAsExpr<ClosureExpr>(funcRef-> getAbstractClosureExpr ()))
1754
1825
solution.setExprTypes (closure);
1755
1826
1756
1827
return false ;
@@ -2015,7 +2086,7 @@ SolutionApplicationToFunctionResult ConstraintSystem::applySolution(
2015
2086
DeclContext *¤tDC,
2016
2087
RewriteTargetFn rewriteTarget) {
2017
2088
auto &cs = solution.getConstraintSystem ();
2018
- auto closure = dyn_cast_or_null <ClosureExpr>(fn.getAbstractClosureExpr ());
2089
+ auto * closure = getAsExpr <ClosureExpr>(fn.getAbstractClosureExpr ());
2019
2090
FunctionType *closureFnType = nullptr ;
2020
2091
if (closure) {
2021
2092
// Update the closure's type.
@@ -2128,8 +2199,9 @@ bool ConstraintSystem::applySolutionToBody(Solution &solution,
2128
2199
resultTy = fn.getBodyResultType ();
2129
2200
}
2130
2201
2131
- SyntacticElementSolutionApplication application (solution, fn, resultTy,
2132
- rewriteTarget);
2202
+ SyntacticElementSolutionApplication application (
2203
+ solution, SyntacticElementContext::forFunctionRef (fn), resultTy,
2204
+ rewriteTarget);
2133
2205
2134
2206
auto body = application.apply ();
2135
2207
0 commit comments