Skip to content

Commit ad0ad9f

Browse files
committed
Check for let vs var bindings of the same variable name in multiple case patterns.
1 parent c8b886f commit ad0ad9f

File tree

3 files changed

+58
-17
lines changed

3 files changed

+58
-17
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,6 +2930,9 @@ ERROR(fallthrough_into_case_with_var_binding,none,
29302930
ERROR(unnecessary_cast_over_optionset,none,
29312931
"unnecessary cast over raw value of %0", (Type))
29322932

2933+
ERROR(mutability_mismatch_multiple_pattern_list,none,
2934+
"'%select{var|let}0' pattern binding must match previous "
2935+
"'%select{var|let}1' pattern binding", (bool, bool))
29332936
ERROR(type_mismatch_multiple_pattern_list,none,
29342937
"pattern variable bound to type %0, expected type %1", (Type, Type))
29352938
ERROR(type_mismatch_fallthrough_pattern_list,none,

lib/Sema/TypeCheckStmt.cpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -871,25 +871,41 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
871871
// For each variable in the pattern, make sure its type is identical to what it
872872
// was in the first label item's pattern.
873873
auto firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
874-
if (pattern != firstPattern) {
875-
SmallVector<VarDecl *, 4> vars;
876-
firstPattern->collectVariables(vars);
877-
pattern->forEachVariable([&](VarDecl *VD) {
878-
if (!VD->hasName())
879-
return;
880-
for (auto *expected : vars) {
881-
if (expected->hasName() && expected->getName() == VD->getName()) {
882-
if (!VD->getType()->isEqual(expected->getType())) {
883-
TC.diagnose(VD->getLoc(), diag::type_mismatch_multiple_pattern_list,
884-
VD->getType(), expected->getType());
885-
VD->markInvalid();
886-
expected->markInvalid();
887-
}
888-
return;
874+
SmallVector<VarDecl *, 4> vars;
875+
firstPattern->collectVariables(vars);
876+
pattern->forEachVariable([&](VarDecl *VD) {
877+
if (!VD->hasName())
878+
return;
879+
for (auto *expected : vars) {
880+
if (expected->hasName() && expected->getName() == VD->getName()) {
881+
if (VD->hasType() && expected->hasType() && !expected->isInvalid() &&
882+
!VD->getType()->isEqual(expected->getType())) {
883+
TC.diagnose(VD->getLoc(), diag::type_mismatch_multiple_pattern_list,
884+
VD->getType(), expected->getType());
885+
VD->markInvalid();
886+
expected->markInvalid();
887+
}
888+
if (expected->isLet() != VD->isLet()) {
889+
auto diag = TC.diagnose(VD->getLoc(),
890+
diag::mutability_mismatch_multiple_pattern_list,
891+
VD->isLet(), expected->isLet());
892+
893+
VarPattern *foundVP = nullptr;
894+
VD->getParentPattern()->forEachNode([&](Pattern *P) {
895+
if (auto *VP = dyn_cast<VarPattern>(P))
896+
if (VP->getSingleVar() == VD)
897+
foundVP = VP;
898+
});
899+
if (foundVP)
900+
diag.fixItReplace(foundVP->getLoc(),
901+
expected->isLet() ? "let" : "var");
902+
VD->markInvalid();
903+
expected->markInvalid();
889904
}
905+
return;
890906
}
891-
});
892-
}
907+
}
908+
});
893909
}
894910
// Check the guard expression, if present.
895911
if (auto *guard = labelItem.getGuardExpr()) {

test/Parse/switch.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,28 @@ func patternVarDiffType(x: Int, y: Double) {
276276
}
277277
}
278278

279+
func patternVarDiffMutability(x: Int, y: Double) {
280+
switch x {
281+
case let a where a < 5, var a where a > 10: // expected-error {{'var' pattern binding must match previous 'let' pattern binding}}{{27-30=let}}
282+
break
283+
default:
284+
break
285+
}
286+
switch (x, y) {
287+
// Would be nice to have a fixit in the following line if we detect that all bindings in the same pattern have the same problem.
288+
case let (a, b) where a < 5, var (a, b) where a > 10: // expected-error 2{{'var' pattern binding must match previous 'let' pattern binding}}{{none}}
289+
break
290+
case (let a, var b) where a < 5, (let a, let b) where a > 10: // expected-error {{'let' pattern binding must match previous 'var' pattern binding}}{{44-47=var}}
291+
break
292+
case (let a, let b) where a < 5, (var a, let b) where a > 10, (let a, var b) where a == 8:
293+
// expected-error@-1 {{'var' pattern binding must match previous 'let' pattern binding}}{{37-40=let}}
294+
// expected-error@-2 {{'var' pattern binding must match previous 'let' pattern binding}}{{73-76=let}}
295+
break
296+
default:
297+
break
298+
}
299+
}
300+
279301
func test_label(x : Int) {
280302
Gronk: // expected-error {{switch must be exhaustive}} expected-note{{do you want to add a default clause?}}
281303
switch x {

0 commit comments

Comments
 (0)