From 064e0c6264e2019f68d5658cfbefe2c6872e5379 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Mon, 3 Mar 2025 17:41:04 -0800 Subject: [PATCH 1/2] Ensure that isolated conformances originate in the same isolation domain This is the missing check for "rule #1" in the isolated conformances proposal, which states that an isolated conformance can only be referenced within the same isolation domain as the conformance. For example, a main-actor-isolated conformance can only be used within main-actor code. --- include/swift/AST/DiagnosticsSema.def | 3 + lib/Sema/TypeCheckConcurrency.cpp | 112 ++++++++++++++++++++ lib/Sema/TypeCheckConcurrency.h | 31 ++++++ lib/Sema/TypeCheckProtocol.cpp | 104 ++++++++++++++++++ lib/Sema/TypeCheckProtocol.h | 32 ++++++ test/Concurrency/isolated_conformance.swift | 6 ++ 6 files changed, 288 insertions(+) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 500a5b51ff544..835d38359f09a 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -8320,6 +8320,9 @@ ERROR(isolated_conformance_with_sendable_simple,none, "isolated conformance of %0 to %1 cannot be used to satisfy conformance " "requirement for a `Sendable` type parameter ", (Type, DeclName)) +ERROR(isolated_conformance_wrong_domain,none, + "%0 isolated conformance of %1 to %2 cannot be used in %3 context", + (ActorIsolation, Type, DeclName, ActorIsolation)) //===----------------------------------------------------------------------===// // MARK: @execution Attribute diff --git a/lib/Sema/TypeCheckConcurrency.cpp b/lib/Sema/TypeCheckConcurrency.cpp index fcee92580b536..2a76c83a0ac3b 100644 --- a/lib/Sema/TypeCheckConcurrency.cpp +++ b/lib/Sema/TypeCheckConcurrency.cpp @@ -18,6 +18,7 @@ #include "MiscDiagnostics.h" #include "TypeCheckDistributed.h" #include "TypeCheckInvertible.h" +#include "TypeCheckProtocol.h" #include "TypeCheckType.h" #include "TypeChecker.h" #include "swift/AST/ASTWalker.h" @@ -3175,6 +3176,18 @@ namespace { checkDefaultArgument(defaultArg); } + if (auto erasureExpr = dyn_cast(expr)) { + checkIsolatedConformancesInContext( + erasureExpr->getConformances(), erasureExpr->getLoc(), + getDeclContext()); + } + + if (auto *underlyingToOpaque = dyn_cast(expr)) { + checkIsolatedConformancesInContext( + underlyingToOpaque->substitutions, underlyingToOpaque->getLoc(), + getDeclContext()); + } + return Action::Continue(expr); } @@ -4282,6 +4295,9 @@ namespace { if (!declRef) return false; + // Make sure isolated conformances are formed in the right context. + checkIsolatedConformancesInContext(declRef, loc, getDeclContext()); + auto decl = declRef.getDecl(); // If this declaration is a callee from the enclosing application, @@ -7684,3 +7700,99 @@ ActorIsolation swift::getConformanceIsolation(ProtocolConformance *conformance) return getActorIsolation(nominal); } + +namespace { + /// Identifies isolated conformances whose isolation differs from the + /// context's isolation. + class MismatchedIsolatedConformances { + llvm::TinyPtrVector badIsolatedConformances; + DeclContext *fromDC; + mutable std::optional fromIsolation; + + public: + MismatchedIsolatedConformances(const DeclContext *fromDC) + : fromDC(const_cast(fromDC)) { } + + ActorIsolation getContextIsolation() const { + if (!fromIsolation) + fromIsolation = getActorIsolationOfContext(fromDC); + + return *fromIsolation; + } + + ArrayRef getBadIsolatedConformances() const { + return badIsolatedConformances; + } + + explicit operator bool() const { return !badIsolatedConformances.empty(); } + + bool operator()(ProtocolConformanceRef conformance) { + if (conformance.isAbstract() || conformance.isPack()) + return false; + + auto concrete = conformance.getConcrete(); + auto normal = dyn_cast( + concrete->getRootConformance()); + if (!normal) + return false; + + if (!normal->isIsolated()) + return false; + + auto conformanceIsolation = getConformanceIsolation(concrete); + if (conformanceIsolation == getContextIsolation()) + return true; + + badIsolatedConformances.push_back(concrete); + return false; + } + + /// If there were any bad isolated conformances, diagnose them and return + /// true. Otherwise, returns false. + bool diagnose(SourceLoc loc) const { + if (badIsolatedConformances.empty()) + return false; + + ASTContext &ctx = fromDC->getASTContext(); + auto firstConformance = badIsolatedConformances.front(); + ctx.Diags.diagnose( + loc, diag::isolated_conformance_wrong_domain, + getConformanceIsolation(firstConformance), + firstConformance->getType(), + firstConformance->getProtocol()->getName(), + getContextIsolation()); + return true; + } + }; + +} + +bool swift::checkIsolatedConformancesInContext( + ConcreteDeclRef declRef, SourceLoc loc, const DeclContext *dc) { + MismatchedIsolatedConformances mismatched(dc); + forEachConformance(declRef, mismatched); + return mismatched.diagnose(loc); +} + +bool swift::checkIsolatedConformancesInContext( + ArrayRef conformances, SourceLoc loc, + const DeclContext *dc) { + MismatchedIsolatedConformances mismatched(dc); + for (auto conformance: conformances) + forEachConformance(conformance, mismatched); + return mismatched.diagnose(loc); +} + +bool swift::checkIsolatedConformancesInContext( + SubstitutionMap subs, SourceLoc loc, const DeclContext *dc) { + MismatchedIsolatedConformances mismatched(dc); + forEachConformance(subs, mismatched); + return mismatched.diagnose(loc); +} + +bool swift::checkIsolatedConformancesInContext( + Type type, SourceLoc loc, const DeclContext *dc) { + MismatchedIsolatedConformances mismatched(dc); + forEachConformance(type, mismatched); + return mismatched.diagnose(loc); +} diff --git a/lib/Sema/TypeCheckConcurrency.h b/lib/Sema/TypeCheckConcurrency.h index d00eb95870b24..0eea9b1ebdae4 100644 --- a/lib/Sema/TypeCheckConcurrency.h +++ b/lib/Sema/TypeCheckConcurrency.h @@ -703,6 +703,37 @@ void introduceUnsafeInheritExecutorReplacements( /// the immediate conformance, not any conformances on which it depends. ActorIsolation getConformanceIsolation(ProtocolConformance *conformance); +/// Check for correct use of isolated conformances in the given reference. +/// +/// This checks that any isolated conformances that occur in the given +/// declaration reference match the isolated of the context. +bool checkIsolatedConformancesInContext( + ConcreteDeclRef declRef, SourceLoc loc, const DeclContext *dc); + +/// Check for correct use of isolated conformances in the set given set of +/// protocol conformances. +/// +/// This checks that any isolated conformances that occur in the given +/// declaration reference match the isolated of the context. +bool checkIsolatedConformancesInContext( + ArrayRef conformances, SourceLoc loc, + const DeclContext *dc); + +/// Check for correct use of isolated conformances in the given substitution +/// map. +/// +/// This checks that any isolated conformances that occur in the given +/// substitution map match the isolated of the context. +bool checkIsolatedConformancesInContext( + SubstitutionMap subs, SourceLoc loc, const DeclContext *dc); + +/// Check for correct use of isolated conformances in the given type. +/// +/// This checks that any isolated conformances that occur in the given +/// type match the isolated of the context. +bool checkIsolatedConformancesInContext( + Type type, SourceLoc loc, const DeclContext *dc); + } // end namespace swift namespace llvm { diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 0e17cc2d634a0..cfd06a93d4275 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -43,6 +43,7 @@ #include "swift/AST/GenericSignature.h" #include "swift/AST/NameLookup.h" #include "swift/AST/NameLookupRequests.h" +#include "swift/AST/PackConformance.h" #include "swift/AST/ParameterList.h" #include "swift/AST/PotentialMacroExpansions.h" #include "swift/AST/PrettyStackTrace.h" @@ -7159,3 +7160,106 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) { req.getFirstType()->getCanonicalType(), requirementProto, conformance); } } + +bool swift::forEachConformance( + SubstitutionMap subs, + llvm::function_ref body) { + if (!subs) + return false; + + for (auto type: subs.getReplacementTypes()) { + if (forEachConformance(type, body)) + return true; + } + + for (auto conformance: subs.getConformances()) { + if (forEachConformance(conformance, body)) + return true; + } + + return false; +} + +bool swift::forEachConformance( + ProtocolConformanceRef conformance, + llvm::function_ref body) { + // Visit this conformance. + if (body(conformance)) + return true; + + if (conformance.isInvalid() || conformance.isAbstract()) + return false; + + if (conformance.isPack()) { + auto pack = conformance.getPack()->getPatternConformances(); + for (auto conformance : pack) { + if (forEachConformance(conformance, body)) + return true; + } + + return false; + } + + // Check the substitution make within this conformance. + auto concrete = conformance.getConcrete(); + if (forEachConformance(concrete->getSubstitutionMap(), body)) + return true; + + + return false; +} + +bool swift::forEachConformance( + Type type, llvm::function_ref body) { + return type.findIf([&](Type type) { + if (auto typeAlias = dyn_cast(type.getPointer())) { + if (forEachConformance(typeAlias->getSubstitutionMap(), body)) + return true; + + return false; + } + + if (auto opaqueArchetype = + dyn_cast(type.getPointer())) { + if (forEachConformance(opaqueArchetype->getSubstitutions(), body)) + return true; + + return false; + } + + // Look through type sugar. + if (auto sugarType = dyn_cast(type.getPointer())) { + type = sugarType->getImplementationType(); + } + + if (auto boundGeneric = dyn_cast(type.getPointer())) { + auto subs = boundGeneric->getContextSubstitutionMap(); + if (forEachConformance(subs, body)) + return true; + + return false; + } + + return false; + }); +} + +bool swift::forEachConformance( + ConcreteDeclRef declRef, + llvm::function_ref body) { + if (!declRef) + return false; + + Type type = declRef.getDecl()->getInterfaceType(); + if (auto subs = declRef.getSubstitutions()) { + if (forEachConformance(subs, body)) + return true; + + type = type.subst(subs); + } + + if (forEachConformance(type, body)) + return true; + + return false; +} diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 0e9360b2d3cdf..2291fa25fa076 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -240,6 +240,38 @@ bool witnessHasImplementsAttrForRequiredName(ValueDecl *witness, bool witnessHasImplementsAttrForExactRequirement(ValueDecl *witness, ValueDecl *requirement); +/// Visit each conformance within the given type. +/// +/// If `body` returns true for any conformance, this function stops the +/// traversal and returns true. +bool forEachConformance( + Type type, llvm::function_ref body); + +/// Visit each conformance within the given conformance (including the given +/// one). +/// +/// If `body` returns true for any conformance, this function stops the +/// traversal and returns true. +bool forEachConformance( + ProtocolConformanceRef conformance, + llvm::function_ref body); + +/// Visit each conformance within the given substitution map. +/// +/// If `body` returns true for any conformance, this function stops the +/// traversal and returns true. +bool forEachConformance( + SubstitutionMap subs, + llvm::function_ref body); + +/// Visit each conformance within the given declaration reference. +/// +/// If `body` returns true for any conformance, this function stops the +/// traversal and returns true. +bool forEachConformance( + ConcreteDeclRef declRef, + llvm::function_ref body); + } #endif // SWIFT_SEMA_PROTOCOL_H diff --git a/test/Concurrency/isolated_conformance.swift b/test/Concurrency/isolated_conformance.swift index f8aeea1b4ede6..872fb4c02a5e7 100644 --- a/test/Concurrency/isolated_conformance.swift +++ b/test/Concurrency/isolated_conformance.swift @@ -119,3 +119,9 @@ func testIsolationConformancesInCall(c: C) { acceptSendableP(c) // expected-error{{isolated conformance of 'C' to 'P' cannot be used to satisfy conformance requirement for a `Sendable` type parameter}} acceptSendableMetaP(c) // expected-error{{isolated conformance of 'C' to 'P' cannot be used to satisfy conformance requirement for a `Sendable` type parameter}} } + +func testIsolationConformancesFromOutside(c: C) { + acceptP(c) // expected-error{{main actor-isolated isolated conformance of 'C' to 'P' cannot be used in nonisolated context}} + let _: any P = c // expected-error{{main actor-isolated isolated conformance of 'C' to 'P' cannot be used in nonisolated context}} + let _ = PWrapper() // expected-error{{main actor-isolated isolated conformance of 'C' to 'P' cannot be used in nonisolated context}} +} From 5c67cffbc0e171ce6fc911950fcce84b45a92e6d Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Mon, 3 Mar 2025 22:14:54 -0800 Subject: [PATCH 2/2] Prevent infinite recursion with conformance enumeration. --- lib/Sema/TypeCheckProtocol.cpp | 69 +++++++++++++++++++++++++--------- lib/Sema/TypeCheckProtocol.h | 14 +++++-- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index cfd06a93d4275..4137133059fa6 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -7163,17 +7163,22 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) { bool swift::forEachConformance( SubstitutionMap subs, - llvm::function_ref body) { + llvm::function_ref body, + VisitedConformances *visitedConformances) { if (!subs) return false; + VisitedConformances visited; + if (!visitedConformances) + visitedConformances = &visited; + for (auto type: subs.getReplacementTypes()) { - if (forEachConformance(type, body)) + if (forEachConformance(type, body, visitedConformances)) return true; } for (auto conformance: subs.getConformances()) { - if (forEachConformance(conformance, body)) + if (forEachConformance(conformance, body, visitedConformances)) return true; } @@ -7182,10 +7187,12 @@ bool swift::forEachConformance( bool swift::forEachConformance( ProtocolConformanceRef conformance, - llvm::function_ref body) { - // Visit this conformance. - if (body(conformance)) - return true; + llvm::function_ref body, + VisitedConformances *visitedConformances) { + // Make sure we can store visited conformances. + VisitedConformances visited; + if (!visitedConformances) + visitedConformances = &visited; if (conformance.isInvalid() || conformance.isAbstract()) return false; @@ -7193,27 +7200,48 @@ bool swift::forEachConformance( if (conformance.isPack()) { auto pack = conformance.getPack()->getPatternConformances(); for (auto conformance : pack) { - if (forEachConformance(conformance, body)) + if (forEachConformance(conformance, body, visitedConformances)) return true; } return false; } - // Check the substitution make within this conformance. + // Extract the concrete conformance. auto concrete = conformance.getConcrete(); - if (forEachConformance(concrete->getSubstitutionMap(), body)) + + // Prevent recursion. + if (!visitedConformances->insert(concrete).second) + return false; + + // Visit this conformance. + if (body(conformance)) return true; + // Check the substitution map within this conformance. + if (forEachConformance(concrete->getSubstitutionMap(), body, + visitedConformances)) + return true; return false; } bool swift::forEachConformance( - Type type, llvm::function_ref body) { + Type type, llvm::function_ref body, + VisitedConformances *visitedConformances) { + // Make sure we can store visited conformances. + VisitedConformances visited; + if (!visitedConformances) + visitedConformances = &visited; + + // Prevent recursion. + if (!visitedConformances->insert(type.getPointer()).second) + return false; + return type.findIf([&](Type type) { if (auto typeAlias = dyn_cast(type.getPointer())) { - if (forEachConformance(typeAlias->getSubstitutionMap(), body)) + if (forEachConformance(typeAlias->getSubstitutionMap(), body, + visitedConformances)) return true; return false; @@ -7221,7 +7249,8 @@ bool swift::forEachConformance( if (auto opaqueArchetype = dyn_cast(type.getPointer())) { - if (forEachConformance(opaqueArchetype->getSubstitutions(), body)) + if (forEachConformance(opaqueArchetype->getSubstitutions(), body, + visitedConformances)) return true; return false; @@ -7234,7 +7263,7 @@ bool swift::forEachConformance( if (auto boundGeneric = dyn_cast(type.getPointer())) { auto subs = boundGeneric->getContextSubstitutionMap(); - if (forEachConformance(subs, body)) + if (forEachConformance(subs, body, visitedConformances)) return true; return false; @@ -7246,19 +7275,25 @@ bool swift::forEachConformance( bool swift::forEachConformance( ConcreteDeclRef declRef, - llvm::function_ref body) { + llvm::function_ref body, + VisitedConformances *visitedConformances) { if (!declRef) return false; + // Make sure we can store visited conformances. + VisitedConformances visited; + if (!visitedConformances) + visitedConformances = &visited; + Type type = declRef.getDecl()->getInterfaceType(); if (auto subs = declRef.getSubstitutions()) { - if (forEachConformance(subs, body)) + if (forEachConformance(subs, body, visitedConformances)) return true; type = type.subst(subs); } - if (forEachConformance(type, body)) + if (forEachConformance(type, body, visitedConformances)) return true; return false; diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 2291fa25fa076..f71caa8ded30e 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -240,12 +240,15 @@ bool witnessHasImplementsAttrForRequiredName(ValueDecl *witness, bool witnessHasImplementsAttrForExactRequirement(ValueDecl *witness, ValueDecl *requirement); +using VisitedConformances = llvm::SmallPtrSet; + /// Visit each conformance within the given type. /// /// If `body` returns true for any conformance, this function stops the /// traversal and returns true. bool forEachConformance( - Type type, llvm::function_ref body); + Type type, llvm::function_ref body, + VisitedConformances *visitedConformances = nullptr); /// Visit each conformance within the given conformance (including the given /// one). @@ -254,7 +257,8 @@ bool forEachConformance( /// traversal and returns true. bool forEachConformance( ProtocolConformanceRef conformance, - llvm::function_ref body); + llvm::function_ref body, + VisitedConformances *visitedConformances = nullptr); /// Visit each conformance within the given substitution map. /// @@ -262,7 +266,8 @@ bool forEachConformance( /// traversal and returns true. bool forEachConformance( SubstitutionMap subs, - llvm::function_ref body); + llvm::function_ref body, + VisitedConformances *visitedConformances = nullptr); /// Visit each conformance within the given declaration reference. /// @@ -270,7 +275,8 @@ bool forEachConformance( /// traversal and returns true. bool forEachConformance( ConcreteDeclRef declRef, - llvm::function_ref body); + llvm::function_ref body, + VisitedConformances *visitedConformances = nullptr); }