Skip to content

Commit 9b166c1

Browse files
committed
Check ConcurrentValue on @Concurrent functions and function types
1 parent 866a8d8 commit 9b166c1

File tree

7 files changed

+108
-1
lines changed

7 files changed

+108
-1
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4315,6 +4315,10 @@ WARNING(non_concurrent_property_type,none,
43154315
"cannot use %0 %1 with a non-concurrent-value type %2 "
43164316
"%select{across actors|from concurrently-executed code}3",
43174317
(DescriptiveDeclKind, DeclName, Type, bool))
4318+
WARNING(non_concurrent_function_type,none,
4319+
"`@concurrent` %select{function type|closure}0 has "
4320+
"non-concurrent-value %select{parameter|result}1 type %2",
4321+
(bool, bool, Type))
43184322

43194323
ERROR(actorindependent_let,none,
43204324
"'@actorIndependent' is meaningless on 'let' declarations because "

lib/Sema/TypeCheckAttr.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
128128
IGNORED_ATTR(OriginallyDefinedIn)
129129
IGNORED_ATTR(NoDerivative)
130130
IGNORED_ATTR(SpecializeExtension)
131-
IGNORED_ATTR(Concurrent)
132131
#undef IGNORED_ATTR
133132

134133
void visitAlignmentAttr(AlignmentAttr *attr) {
@@ -420,6 +419,22 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
420419
}
421420
}
422421
}
422+
423+
void visitConcurrentAttr(ConcurrentAttr *attr) {
424+
auto VD = dyn_cast<ValueDecl>(D);
425+
if (!VD)
426+
return;
427+
428+
auto innermostDC = VD->getInnermostDeclContext();
429+
SubstitutionMap subs;
430+
if (auto genericEnv = innermostDC->getGenericEnvironmentOfContext()) {
431+
subs = genericEnv->getForwardingSubstitutionMap();
432+
}
433+
434+
(void)diagnoseNonConcurrentTypesInReference(
435+
ConcreteDeclRef(VD, subs), innermostDC, VD->getLoc(),
436+
ConcurrentReferenceKind::ConcurrentFunction);
437+
}
423438
};
424439
} // end anonymous namespace
425440

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,35 @@ bool swift::diagnoseNonConcurrentTypesInReference(
938938
return false;
939939
}
940940

941+
bool swift::diagnoseNonConcurrentTypesInFunctionType(
942+
const AnyFunctionType *fnType, const DeclContext *dc, SourceLoc loc,
943+
bool isClosure) {
944+
ASTContext &ctx = dc->getASTContext();
945+
// Bail out immediately if we aren't supposed to do this checking.
946+
if (!dc->getASTContext().LangOpts.EnableExperimentalConcurrentValueChecking)
947+
return false;
948+
949+
// Check parameter types.
950+
for (const auto &param : fnType->getParams()) {
951+
Type paramType = param.getParameterType();
952+
if (!isConcurrentValueType(dc, paramType)) {
953+
ctx.Diags.diagnose(
954+
loc, diag::non_concurrent_function_type, isClosure, false, paramType);
955+
return true;
956+
}
957+
}
958+
959+
// Check result type.
960+
if (!isConcurrentValueType(dc, fnType->getResult())) {
961+
ctx.Diags.diagnose(
962+
loc, diag::non_concurrent_function_type, isClosure, true,
963+
fnType->getResult());
964+
return true;
965+
}
966+
967+
return false;
968+
}
969+
941970
namespace {
942971
/// Check whether a particular context may execute concurrently within
943972
/// another context.
@@ -1079,6 +1108,18 @@ namespace {
10791108
if (auto *closure = dyn_cast<AbstractClosureExpr>(expr)) {
10801109
closure->setActorIsolation(determineClosureIsolation(closure));
10811110
contextStack.push_back(closure);
1111+
1112+
1113+
// Concurrent closures must be composed of concurrent-safe parameter
1114+
// and result types.
1115+
if (isConcurrentClosure(closure)) {
1116+
if (auto fnType = closure->getType()->getAs<FunctionType>()) {
1117+
(void)diagnoseNonConcurrentTypesInFunctionType(
1118+
fnType, getDeclContext(), closure->getLoc(),
1119+
/*isClosure=*/true);
1120+
}
1121+
}
1122+
10821123
return { true, expr };
10831124
}
10841125

lib/Sema/TypeCheckConcurrency.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace swift {
2525

2626
class AbstractFunctionDecl;
2727
class ActorIsolation;
28+
class AnyFunctionType;
2829
class ASTContext;
2930
class ClassDecl;
3031
class ConcreteDeclRef;
@@ -60,6 +61,8 @@ enum class ConcurrentReferenceKind {
6061
CrossActor,
6162
/// A local capture referenced from concurrent code.
6263
LocalCapture,
64+
/// Concurrent function
65+
ConcurrentFunction,
6366
};
6467

6568
/// Describes why or where a particular entity has a non-concurrent-value type.
@@ -233,6 +236,12 @@ bool diagnoseNonConcurrentTypesInReference(
233236
ConcreteDeclRef declRef, const DeclContext *dc, SourceLoc loc,
234237
ConcurrentReferenceKind refKind);
235238

239+
/// Diagnose the presence of any non-concurrent types within the given
240+
/// function type.
241+
bool diagnoseNonConcurrentTypesInFunctionType(
242+
const AnyFunctionType *fnType, const DeclContext *dc, SourceLoc loc,
243+
bool isClosure);
244+
236245
} // end namespace swift
237246

238247
#endif /* SWIFT_SEMA_TYPECHECKCONCURRENCY_H */

lib/Sema/TypeCheckDeclPrimary.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,8 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
19061906
(void) TAD->getGenericSignature();
19071907
(void) TAD->getUnderlyingType();
19081908

1909+
// Make sure to check the underlying type.
1910+
19091911
TypeChecker::checkDeclAttributes(TAD);
19101912
checkAccessControl(TAD);
19111913
checkGenericParams(TAD);

lib/Sema/TypeCheckType.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "TypeChecker.h"
1919
#include "TypeCheckAvailability.h"
20+
#include "TypeCheckConcurrency.h"
2021
#include "TypeCheckProtocol.h"
2122
#include "TypeCheckType.h"
2223
#include "TypoCorrection.h"
@@ -2824,6 +2825,14 @@ NeverNullType TypeResolver::resolveASTFunctionType(
28242825
if (fnTy->hasError())
28252826
return fnTy;
28262827

2828+
// Concurrent function types must be composed of concurrent-safe parameter
2829+
// and result types.
2830+
if (concurrent && resolution.getStage() > TypeResolutionStage::Structural) {
2831+
(void)diagnoseNonConcurrentTypesInFunctionType(
2832+
fnTy, resolution.getDeclContext(), repr->getLoc(),
2833+
/*isClosure=*/false);
2834+
}
2835+
28272836
// If the type is a block or C function pointer, it must be representable in
28282837
// ObjC.
28292838
switch (representation) {

test/Concurrency/concurrent_value_checking.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,30 @@ class SomeClass: MainActorProto {
118118
@SomeGlobalActor
119119
func asyncMainMethod(_: NotConcurrent) async { } // expected-warning{{cannot pass argument of non-concurrent-value type 'NotConcurrent' across actors}}
120120
}
121+
122+
// ----------------------------------------------------------------------
123+
// ConcurrentValue restriction on concurrent functions.
124+
// ----------------------------------------------------------------------
125+
126+
// FIXME: poor diagnostic
127+
@concurrent func concurrentFunc() -> NotConcurrent? { nil } // expected-warning{{cannot call function returning non-concurrent-value type 'NotConcurrent?' across actors}}
128+
129+
// ----------------------------------------------------------------------
130+
// ConcurrentValue restriction on @concurrent types.
131+
// ----------------------------------------------------------------------
132+
typealias CF = @concurrent () -> NotConcurrent? // expected-warning{{`@concurrent` function type has non-concurrent-value result type 'NotConcurrent?'}}
133+
typealias BadGenericCF<T> = @concurrent () -> T? // expected-warning{{`@concurrent` function type has non-concurrent-value result type 'T?'}}
134+
typealias GoodGenericCF<T: ConcurrentValue> = @concurrent () -> T? // okay
135+
136+
var concurrentFuncVar: (@concurrent (NotConcurrent) -> Void)? = nil // expected-warning{{`@concurrent` function type has non-concurrent-value parameter type 'NotConcurrent'}}
137+
138+
// ----------------------------------------------------------------------
139+
// ConcurrentValue restriction on @concurrent closures.
140+
// ----------------------------------------------------------------------
141+
func acceptConcurrentUnary<T>(_: @concurrent (T) -> T) { }
142+
143+
func concurrentClosures<T>(_: T) {
144+
acceptConcurrentUnary { (x: T) in // expected-warning{{`@concurrent` closure has non-concurrent-value parameter type 'T'}}
145+
x
146+
}
147+
}

0 commit comments

Comments
 (0)