Skip to content

Commit 115abcb

Browse files
authored
Merge pull request #79573 from DougGregor/unsafe-for-in-loop
[SE-0458] Implement "unsafe" effect for the for-in loop
2 parents e0cf5a5 + 50801f9 commit 115abcb

File tree

13 files changed

+308
-109
lines changed

13 files changed

+308
-109
lines changed

include/swift/AST/ASTBridging.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,11 +2290,12 @@ BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
22902290
BridgedDeclContext cDC);
22912291

22922292
SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:"
2293-
"pattern:inLoc:sequence:whereLoc:whereExpr:body:)")
2293+
"unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:)")
22942294
BridgedForEachStmt BridgedForEachStmt_createParsed(
22952295
BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo,
22962296
BridgedSourceLoc cForLoc, BridgedSourceLoc cTryLoc,
2297-
BridgedSourceLoc cAwaitLoc, BridgedPattern cPat, BridgedSourceLoc cInLoc,
2297+
BridgedSourceLoc cAwaitLoc, BridgedSourceLoc cUnsafeLoc,
2298+
BridgedPattern cPat, BridgedSourceLoc cInLoc,
22982299
BridgedExpr cSequence, BridgedSourceLoc cWhereLoc,
22992300
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody);
23002301

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8218,8 +8218,12 @@ GROUPED_WARNING(preconcurrency_import_unsafe,Unsafe,none,
82188218
"introduce data races", ())
82198219
GROUPED_WARNING(unsafe_without_unsafe,Unsafe,none,
82208220
"expression uses unsafe constructs but is not marked with 'unsafe'", ())
8221+
GROUPED_WARNING(for_unsafe_without_unsafe,Unsafe,none,
8222+
"for-in loop uses unsafe constructs but is not marked with 'unsafe'", ())
82218223
WARNING(no_unsafe_in_unsafe,none,
82228224
"no unsafe operations occur within 'unsafe' expression", ())
8225+
WARNING(no_unsafe_in_unsafe_for,none,
8226+
"no unsafe operations occur within 'unsafe' for-in loop", ())
82238227
NOTE(make_subclass_unsafe,none,
82248228
"make class %0 @unsafe to allow unsafe overrides of safe superclass methods",
82258229
(DeclName))

include/swift/AST/Stmt.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,7 @@ class ForEachStmt : public LabeledStmt {
10031003
SourceLoc ForLoc;
10041004
SourceLoc TryLoc;
10051005
SourceLoc AwaitLoc;
1006+
SourceLoc UnsafeLoc;
10061007
Pattern *Pat;
10071008
SourceLoc InLoc;
10081009
Expr *Sequence;
@@ -1020,13 +1021,14 @@ class ForEachStmt : public LabeledStmt {
10201021

10211022
public:
10221023
ForEachStmt(LabeledStmtInfo LabelInfo, SourceLoc ForLoc, SourceLoc TryLoc,
1023-
SourceLoc AwaitLoc, Pattern *Pat, SourceLoc InLoc, Expr *Sequence,
1024+
SourceLoc AwaitLoc, SourceLoc UnsafeLoc, Pattern *Pat,
1025+
SourceLoc InLoc, Expr *Sequence,
10241026
SourceLoc WhereLoc, Expr *WhereExpr, BraceStmt *Body,
10251027
std::optional<bool> implicit = std::nullopt)
10261028
: LabeledStmt(StmtKind::ForEach, getDefaultImplicitFlag(implicit, ForLoc),
10271029
LabelInfo),
1028-
ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), Pat(nullptr),
1029-
InLoc(InLoc), Sequence(Sequence), WhereLoc(WhereLoc),
1030+
ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), UnsafeLoc(UnsafeLoc),
1031+
Pat(nullptr), InLoc(InLoc), Sequence(Sequence), WhereLoc(WhereLoc),
10301032
WhereExpr(WhereExpr), Body(Body) {
10311033
setPattern(Pat);
10321034
}
@@ -1064,6 +1066,7 @@ class ForEachStmt : public LabeledStmt {
10641066

10651067
SourceLoc getAwaitLoc() const { return AwaitLoc; }
10661068
SourceLoc getTryLoc() const { return TryLoc; }
1069+
SourceLoc getUnsafeLoc() const { return UnsafeLoc; }
10671070

10681071
/// getPattern - Retrieve the pattern describing the iteration variables.
10691072
/// These variables will only be visible within the body of the loop.

include/swift/Parse/IDEInspectionCallbacks.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ class CodeCompletionCallbacks {
287287
virtual void completeStmtLabel(StmtKind ParentKind) {};
288288

289289
virtual
290-
void completeForEachPatternBeginning(bool hasTry, bool hasAwait) {};
290+
void completeForEachPatternBeginning(
291+
bool hasTry, bool hasAwait, bool hasUnsafe) {};
291292

292293
virtual void completeTypeAttrBeginning() {};
293294

lib/AST/Bridging/StmtBridging.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,15 @@ BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
191191
BridgedForEachStmt BridgedForEachStmt_createParsed(
192192
BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo,
193193
BridgedSourceLoc cForLoc, BridgedSourceLoc cTryLoc,
194-
BridgedSourceLoc cAwaitLoc, BridgedPattern cPat, BridgedSourceLoc cInLoc,
194+
BridgedSourceLoc cAwaitLoc, BridgedSourceLoc cUnsafeLoc,
195+
BridgedPattern cPat, BridgedSourceLoc cInLoc,
195196
BridgedExpr cSequence, BridgedSourceLoc cWhereLoc,
196197
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody) {
197198
return new (cContext.unbridged()) ForEachStmt(
198199
cLabelInfo.unbridged(), cForLoc.unbridged(), cTryLoc.unbridged(),
199-
cAwaitLoc.unbridged(), cPat.unbridged(), cInLoc.unbridged(),
200-
cSequence.unbridged(), cWhereLoc.unbridged(), cWhereExpr.unbridged(),
201-
cBody.unbridged());
200+
cAwaitLoc.unbridged(), cUnsafeLoc.unbridged(), cPat.unbridged(),
201+
cInLoc.unbridged(), cSequence.unbridged(), cWhereLoc.unbridged(),
202+
cWhereExpr.unbridged(), cBody.unbridged());
202203
}
203204

204205
BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext,

lib/ASTGen/Sources/ASTGen/Stmts.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ extension ASTGenVisitor {
330330
forLoc: self.generateSourceLoc(node.forKeyword),
331331
tryLoc: self.generateSourceLoc(node.tryKeyword),
332332
awaitLoc: self.generateSourceLoc(node.awaitKeyword),
333+
unsafeLoc: self.generateSourceLoc(node.unsafeKeyword),
333334
// NOTE: The pattern can be either a refutable pattern after `case` or a
334335
// normal binding pattern. ASTGen doesn't care because it should be handled
335336
// by the parser.

lib/IDE/CodeCompletion.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ class CodeCompletionCallbacksImpl : public CodeCompletionCallbacks,
302302
void completeGenericRequirement() override;
303303
void completeAfterIfStmtElse() override;
304304
void completeStmtLabel(StmtKind ParentKind) override;
305-
void completeForEachPatternBeginning(bool hasTry, bool hasAwait) override;
305+
void completeForEachPatternBeginning(
306+
bool hasTry, bool hasAwait, bool hasUnsafe) override;
306307
void completeTypeAttrBeginning() override;
307308
void completeTypeAttrInheritanceBeginning() override;
308309
void completeOptionalBinding() override;
@@ -636,14 +637,16 @@ void CodeCompletionCallbacksImpl::completeStmtLabel(StmtKind ParentKind) {
636637
}
637638

638639
void CodeCompletionCallbacksImpl::completeForEachPatternBeginning(
639-
bool hasTry, bool hasAwait) {
640+
bool hasTry, bool hasAwait, bool hasUnsafe) {
640641
CurDeclContext = P.CurDeclContext;
641642
Kind = CompletionKind::ForEachPatternBeginning;
642643
ParsedKeywords.clear();
643644
if (hasTry)
644645
ParsedKeywords.emplace_back("try");
645646
if (hasAwait)
646647
ParsedKeywords.emplace_back("await");
648+
if (hasUnsafe)
649+
ParsedKeywords.emplace_back("unsafe");
647650
}
648651

649652
void CodeCompletionCallbacksImpl::completeOptionalBinding() {

lib/Parse/ParseStmt.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,6 +2368,7 @@ ParserResult<Stmt> Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) {
23682368
auto StartOfControl = Tok.getLoc();
23692369
SourceLoc AwaitLoc;
23702370
SourceLoc TryLoc;
2371+
SourceLoc UnsafeLoc;
23712372

23722373
if (Tok.isContextualKeyword("await")) {
23732374
AwaitLoc = consumeToken();
@@ -2378,10 +2379,15 @@ ParserResult<Stmt> Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) {
23782379
}
23792380
}
23802381

2382+
if (Context.LangOpts.hasFeature(Feature::WarnUnsafe) &&
2383+
Tok.isContextualKeyword("unsafe")) {
2384+
UnsafeLoc = consumeToken();
2385+
}
2386+
23812387
if (Tok.is(tok::code_complete)) {
23822388
if (CodeCompletionCallbacks) {
23832389
CodeCompletionCallbacks->completeForEachPatternBeginning(
2384-
TryLoc.isValid(), AwaitLoc.isValid());
2390+
TryLoc.isValid(), AwaitLoc.isValid(), UnsafeLoc.isValid());
23852391
}
23862392
consumeToken(tok::code_complete);
23872393
// Since 'completeForeachPatternBeginning' is a keyword only completion,
@@ -2495,7 +2501,8 @@ ParserResult<Stmt> Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) {
24952501

24962502
return makeParserResult(
24972503
Status,
2498-
new (Context) ForEachStmt(LabelInfo, ForLoc, TryLoc, AwaitLoc, pattern.get(), InLoc,
2504+
new (Context) ForEachStmt(LabelInfo, ForLoc, TryLoc, AwaitLoc, UnsafeLoc,
2505+
pattern.get(), InLoc,
24992506
Container.get(), WhereLoc, Where.getPtrOrNull(),
25002507
Body.get()));
25012508
}

lib/Sema/BuilderTransform.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ class ResultBuilderTransform
726726
auto *newForEach = new (ctx)
727727
ForEachStmt(forEachStmt->getLabelInfo(), forEachStmt->getForLoc(),
728728
forEachStmt->getTryLoc(), forEachStmt->getAwaitLoc(),
729+
forEachStmt->getUnsafeLoc(),
729730
forEachStmt->getPattern(), forEachStmt->getInLoc(),
730731
forEachStmt->getParsedSequence(),
731732
forEachStmt->getWhereLoc(), forEachStmt->getWhere(),

lib/Sema/CSGen.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4632,10 +4632,11 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
46324632
AwaitExpr::createImplicit(ctx, nextCall->getLoc(), nextCall);
46334633
}
46344634

4635-
// Wrap the 'next' call in 'unsafe', if there is one.
4636-
if (unsafeExpr) {
4637-
nextCall = new (ctx) UnsafeExpr(unsafeExpr->getLoc(), nextCall, Type(),
4638-
/*implicit=*/true);
4635+
// Wrap the 'next' call in 'unsafe', if the for..in loop has that
4636+
// effect.
4637+
if (stmt->getUnsafeLoc().isValid()) {
4638+
nextCall = new (ctx) UnsafeExpr(
4639+
stmt->getUnsafeLoc(), nextCall, Type(), /*implicit=*/true);
46394640
}
46404641

46414642
// The iterator type must conform to IteratorProtocol.

0 commit comments

Comments
 (0)