Skip to content

Commit ec64a39

Browse files
committed
[ConstraintSystem] Allow injecting callAsFunction after defaulted arguments
- Adjust argument matching to be less aggressive while matching misplaced trailing closures against defaulted parameters that do not accept trailing closures. For example: ```swift func test(a: Int? = 42, b: String? = nil) {} test { ... } ``` Trailing closure in this case should be marked as "extraneous" instead of matched against parameter `a:`. - If trailing closure is extraneous in a call to `.init` on callable type, let's not attempt to to fix it and allow solver to inject `.callAsFunction` first. Resolves: rdar://problem/94959816 (cherry picked from commit ea4ac2c) (cherry picked from commit 2750e69)
1 parent ad06e7d commit ec64a39

File tree

2 files changed

+131
-8
lines changed

2 files changed

+131
-8
lines changed

lib/Sema/CSSimplify.cpp

+59-7
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,35 @@ static bool matchCallArgumentsImpl(
467467
return;
468468
}
469469

470+
// Let's consider current closure to be extraneous if:
471+
//
472+
// - current parameter has a default value and doesn't accept a trailing
473+
// closure; and
474+
// - no other free parameter after this one accepts a trailing closure via
475+
// forward or backward scan. This check makes sure that it's safe to
476+
// reject and push it to the next parameter without affecting backward
477+
// scan logic.
478+
//
479+
// In other words - let's push the closure argument through defaulted
480+
// parameters until it can be considered extraneous if no parameters
481+
// could possibly match it.
482+
if (!paramInfo.acceptsUnlabeledTrailingClosureArgument(paramIdx) &&
483+
!parameterRequiresArgument(params, paramInfo, paramIdx)) {
484+
if (llvm::none_of(
485+
range(paramIdx + 1, params.size()), [&](unsigned idx) {
486+
return parameterBindings[idx].empty() &&
487+
(paramInfo.acceptsUnlabeledTrailingClosureArgument(
488+
idx) ||
489+
backwardScanAcceptsTrailingClosure(params[idx]));
490+
})) {
491+
haveUnfulfilledParams = true;
492+
return;
493+
}
494+
495+
// If one or more parameters can match the closure, let's check
496+
// whether backward scan is applicable here.
497+
}
498+
470499
// If this parameter does not require an argument, consider applying a
471500
// backward-match rule that skips this parameter if doing so is the only
472501
// way to successfully match arguments to parameters.
@@ -1076,8 +1105,10 @@ constraints::getCompletionArgInfo(ASTNode anchor, ConstraintSystem &CS) {
10761105
class ArgumentFailureTracker : public MatchCallArgumentListener {
10771106
protected:
10781107
ConstraintSystem &CS;
1108+
NullablePtr<ValueDecl> Callee;
10791109
SmallVectorImpl<AnyFunctionType::Param> &Arguments;
10801110
ArrayRef<AnyFunctionType::Param> Parameters;
1111+
Optional<unsigned> UnlabeledTrailingClosureArgIndex;
10811112
ConstraintLocatorBuilder Locator;
10821113

10831114
private:
@@ -1109,11 +1140,14 @@ class ArgumentFailureTracker : public MatchCallArgumentListener {
11091140
}
11101141

11111142
public:
1112-
ArgumentFailureTracker(ConstraintSystem &cs,
1143+
ArgumentFailureTracker(ConstraintSystem &cs, ValueDecl *callee,
11131144
SmallVectorImpl<AnyFunctionType::Param> &args,
11141145
ArrayRef<AnyFunctionType::Param> params,
1146+
Optional<unsigned> unlabeledTrailingClosureArgIndex,
11151147
ConstraintLocatorBuilder locator)
1116-
: CS(cs), Arguments(args), Parameters(params), Locator(locator) {}
1148+
: CS(cs), Callee(callee), Arguments(args), Parameters(params),
1149+
UnlabeledTrailingClosureArgIndex(unlabeledTrailingClosureArgIndex),
1150+
Locator(locator) {}
11171151

11181152
~ArgumentFailureTracker() override {
11191153
if (!MissingArguments.empty()) {
@@ -1143,6 +1177,19 @@ class ArgumentFailureTracker : public MatchCallArgumentListener {
11431177
if (!CS.shouldAttemptFixes())
11441178
return true;
11451179

1180+
// If this is a trailing closure, let's check if the call is
1181+
// to an init of a callable type. If so, let's not record it
1182+
// as extraneous since it would be matched against implicitly
1183+
// injected `.callAsFunction` call.
1184+
if (UnlabeledTrailingClosureArgIndex &&
1185+
argIdx == *UnlabeledTrailingClosureArgIndex && Callee) {
1186+
if (auto *ctor = dyn_cast<ConstructorDecl>(Callee.get())) {
1187+
auto resultTy = ctor->getResultInterfaceType();
1188+
if (resultTy->isCallableNominalType(CS.DC))
1189+
return true;
1190+
}
1191+
}
1192+
11461193
ExtraArguments.push_back(std::make_pair(argIdx, Arguments[argIdx]));
11471194
return false;
11481195
}
@@ -1251,12 +1298,15 @@ class CompletionArgumentTracker : public ArgumentFailureTracker {
12511298
struct CompletionArgInfo ArgInfo;
12521299

12531300
public:
1254-
CompletionArgumentTracker(ConstraintSystem &cs,
1301+
CompletionArgumentTracker(ConstraintSystem &cs, ValueDecl *callee,
12551302
SmallVectorImpl<AnyFunctionType::Param> &args,
12561303
ArrayRef<AnyFunctionType::Param> params,
1304+
Optional<unsigned> unlabeledTrailingClosureArgIndex,
12571305
ConstraintLocatorBuilder locator,
12581306
struct CompletionArgInfo ArgInfo)
1259-
: ArgumentFailureTracker(cs, args, params, locator), ArgInfo(ArgInfo) {}
1307+
: ArgumentFailureTracker(cs, callee, args, params,
1308+
unlabeledTrailingClosureArgIndex, locator),
1309+
ArgInfo(ArgInfo) {}
12601310

12611311
Optional<unsigned> missingArgument(unsigned paramIdx,
12621312
unsigned argInsertIdx) override {
@@ -1666,14 +1716,16 @@ static ConstraintSystem::TypeMatchResult matchCallArguments(
16661716
if (cs.isForCodeCompletion()) {
16671717
if (auto completionInfo = getCompletionArgInfo(locator.getAnchor(), cs)) {
16681718
listener = std::make_unique<CompletionArgumentTracker>(
1669-
cs, argsWithLabels, params, locator, *completionInfo);
1719+
cs, callee, argsWithLabels, params,
1720+
argList->getFirstTrailingClosureIndex(), locator, *completionInfo);
16701721
}
16711722
}
16721723
if (!listener) {
16731724
// We didn't create an argument tracker for code completion. Create a
16741725
// normal one.
1675-
listener = std::make_unique<ArgumentFailureTracker>(cs, argsWithLabels,
1676-
params, locator);
1726+
listener = std::make_unique<ArgumentFailureTracker>(
1727+
cs, callee, argsWithLabels, params,
1728+
argList->getFirstTrailingClosureIndex(), locator);
16771729
}
16781730
auto callArgumentMatch = constraints::matchCallArguments(
16791731
argsWithLabels, params, paramInfo,

test/Constraints/callAsFunction.swift

+72-1
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,23 @@
88
protocol View {}
99
struct EmptyView: View {}
1010

11+
enum Align {
12+
case top, center, bottom
13+
}
14+
1115
struct MyLayout {
16+
init(alignment: Align? = .center, spacing: Double? = 0.0) {}
17+
1218
func callAsFunction<V: View>(content: () -> V) -> MyLayout { .init() }
19+
// expected-note@-1 {{where 'V' = 'Int'}}
1320
func callAsFunction<V: View>(answer: () -> Int,
1421
content: () -> V) -> MyLayout { .init() }
22+
// expected-note@-2 {{where 'V' = 'Int'}}
1523
}
1624

1725
struct Test {
1826
var body1: MyLayout {
19-
MyLayout() {
27+
MyLayout(spacing: 1.0) {
2028
EmptyView() // Ok
2129
}
2230
}
@@ -28,6 +36,58 @@ struct Test {
2836
EmptyView() // Ok
2937
}
3038
}
39+
40+
var body3: MyLayout {
41+
MyLayout(alignment: .top) {
42+
let x = 42
43+
return x
44+
} content: {
45+
EmptyView() // Ok
46+
}
47+
}
48+
49+
var body4: MyLayout {
50+
MyLayout(spacing: 1.0) {
51+
let x = 42
52+
return x
53+
} content: {
54+
_ = 42
55+
return EmptyView() // Ok
56+
}
57+
}
58+
59+
var body5: MyLayout {
60+
MyLayout(alignment: .bottom, spacing: 1.0) {
61+
42
62+
} content: {
63+
EmptyView() // Ok
64+
}
65+
}
66+
67+
var body6: MyLayout {
68+
MyLayout(spacing: 1.0) {
69+
_ = EmptyView()
70+
return 42
71+
} // expected-error {{instance method 'callAsFunction(content:)' requires that 'Int' conform to 'View'}}
72+
}
73+
74+
var body7: MyLayout {
75+
MyLayout(alignment: .center) {
76+
42
77+
} content: {
78+
_ = EmptyView()
79+
return 42
80+
} // expected-error {{instance method 'callAsFunction(answer:content:)' requires that 'Int' conform to 'View'}}
81+
}
82+
83+
var body8: MyLayout {
84+
MyLayout {
85+
let x = ""
86+
return x // expected-error {{cannot convert return expression of type 'String' to return type 'Int'}}
87+
} content: {
88+
EmptyView()
89+
}
90+
}
3191
}
3292

3393
// rdar://92912878 - filtering prevents disambiguation of `.callAsFunction`
@@ -51,3 +111,14 @@ func test_no_filtering_of_overloads() {
51111
}
52112
}
53113
}
114+
115+
func test_default_arguments_do_not_interfere() {
116+
struct S {
117+
init(a: Int? = 42, b: String = "") {}
118+
func callAsFunction(_: () -> Void) -> S { S() }
119+
}
120+
121+
_ = S { _ = 42 }
122+
_ = S(a: 42) { _ = 42 }
123+
_ = S(b: "") { _ = 42 }
124+
}

0 commit comments

Comments
 (0)