Skip to content

Requirement lowering cleanup and accept-invalid fix for @objc protocols #69866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3718,35 +3718,31 @@ class PrintConformance : public PrintBase {
return;
}

printFieldQuotedRaw([&](raw_ostream &out) { genericSig->print(out); },
"generic_signature");

auto printSubstitution = [&](GenericTypeParamType * genericParam,
Type replacementType) {
printFieldQuotedRaw([&](raw_ostream &out) {
genericParam->print(out);
out << " -> ";
if (replacementType) {
PrintOptions opts;
opts.PrintForSIL = true;
opts.PrintTypesForDebugging = true;
replacementType->print(out, opts);
}
else
out << "<unresolved concrete type>";
}, "");
};
printFieldRaw([&](raw_ostream &out) { genericSig->print(out); },
"generic_signature");

auto genericParams = genericSig.getGenericParams();
auto replacementTypes =
static_cast<const SubstitutionMap &>(map).getReplacementTypesBuffer();
for (unsigned i : indices(genericParams)) {
if (style == SubstitutionMap::DumpStyle::Minimal) {
printSubstitution(genericParams[i], replacementTypes[i]);
printFieldRaw([&](raw_ostream &out) {
genericParams[i]->print(out);
out << " -> ";
if (replacementTypes[i])
out << replacementTypes[i];
else
out << "<unresolved concrete type>";
}, "");
} else {
printRecArbitrary([&](StringRef label) {
printHead("substitution", ASTNodeColor, label);
printSubstitution(genericParams[i], replacementTypes[i]);
printFieldRaw([&](raw_ostream &out) {
genericParams[i]->print(out);
out << " -> ";
}, "");
if (replacementTypes[i])
printRec(replacementTypes[i]);
printFoot();
});
}
Expand Down
110 changes: 58 additions & 52 deletions lib/AST/RequirementMachine/RequirementLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,25 @@ swift::rewriting::desugarRequirement(Requirement req, SourceLoc loc,
}
}

void swift::rewriting::desugarRequirements(SmallVector<StructuralRequirement, 2> &reqs,
SmallVectorImpl<RequirementError> &errors) {
SmallVector<StructuralRequirement, 2> result;
for (auto req : reqs) {
SmallVector<Requirement, 2> desugaredReqs;
SmallVector<RequirementError, 2> ignoredErrors;

if (req.inferred)
desugarRequirement(req.req, SourceLoc(), desugaredReqs, ignoredErrors);
else
desugarRequirement(req.req, req.loc, desugaredReqs, errors);

for (auto desugaredReq : desugaredReqs)
result.push_back({desugaredReq, req.loc, req.inferred});
}

std::swap(reqs, result);
}

//
// Requirement realization and inference.
//
Expand All @@ -467,8 +486,6 @@ static void realizeTypeRequirement(DeclContext *dc,
SourceLoc loc,
SmallVectorImpl<StructuralRequirement> &result,
SmallVectorImpl<RequirementError> &errors) {
SmallVector<Requirement, 2> reqs;

// The GenericSignatureBuilder allowed the right hand side of a
// conformance or superclass requirement to reference a protocol
// typealias whose underlying type was a protocol or class.
Expand Down Expand Up @@ -497,22 +514,19 @@ static void realizeTypeRequirement(DeclContext *dc,
}

if (constraintType->isConstraintType()) {
Requirement req(RequirementKind::Conformance, subjectType, constraintType);
desugarRequirement(req, loc, reqs, errors);
result.push_back({Requirement(RequirementKind::Conformance,
subjectType, constraintType),
loc, /*wasInferred=*/false});
} else if (constraintType->getClassOrBoundGenericClass()) {
Requirement req(RequirementKind::Superclass, subjectType, constraintType);
desugarRequirement(req, loc, reqs, errors);
result.push_back({Requirement(RequirementKind::Superclass,
subjectType, constraintType),
loc, /*wasInferred=*/false});
} else {
errors.push_back(
RequirementError::forInvalidTypeRequirement(subjectType,
constraintType,
loc));
return;
}

// Add source location information.
for (auto req : reqs)
result.push_back({req, loc, /*wasInferred=*/false});
}

namespace {
Expand All @@ -521,11 +535,11 @@ namespace {
struct InferRequirementsWalker : public TypeWalker {
ModuleDecl *module;
DeclContext *dc;
SmallVector<Requirement, 2> reqs;
SmallVector<RequirementError, 2> errors;
SmallVectorImpl<StructuralRequirement> &reqs;

explicit InferRequirementsWalker(ModuleDecl *module, DeclContext *dc)
: module(module), dc(dc) {}
explicit InferRequirementsWalker(ModuleDecl *module, DeclContext *dc,
SmallVectorImpl<StructuralRequirement> &reqs)
: module(module), dc(dc), reqs(reqs) {}

Action walkToTypePre(Type ty) override {
// Unbound generic types are the result of recovered-but-invalid code, and
Expand Down Expand Up @@ -555,8 +569,7 @@ struct InferRequirementsWalker : public TypeWalker {
return false;

return (req.getKind() == RequirementKind::Conformance &&
req.getSecondType()->castTo<ProtocolType>()->getDecl()
->isSpecificProtocol(KnownProtocolKind::Sendable));
req.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Sendable));
};

// Infer from generic typealiases.
Expand All @@ -567,7 +580,7 @@ struct InferRequirementsWalker : public TypeWalker {
if (skipRequirement(rawReq, decl))
continue;

desugarRequirement(rawReq.subst(subMap), SourceLoc(), reqs, errors);
reqs.push_back({rawReq.subst(subMap), SourceLoc(), /*inferred=*/true});
}

return Action::Continue;
Expand All @@ -581,10 +594,9 @@ struct InferRequirementsWalker : public TypeWalker {
packExpansion->getPatternType()->getTypeParameterPacks(packReferences);

auto countType = packExpansion->getCountType();
for (auto pack : packReferences) {
Requirement req(RequirementKind::SameShape, countType, pack);
desugarRequirement(req, SourceLoc(), reqs, errors);
}
for (auto pack : packReferences)
reqs.push_back({Requirement(RequirementKind::SameShape, countType, pack),
SourceLoc(), /*inferred=*/true});
}

// Infer requirements from `@differentiable` function types.
Expand All @@ -596,9 +608,9 @@ struct InferRequirementsWalker : public TypeWalker {
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
// Add a new conformance constraint for a fixed protocol.
auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
Requirement req(RequirementKind::Conformance, type,
protocol->getDeclaredInterfaceType());
desugarRequirement(req, SourceLoc(), reqs, errors);
reqs.push_back({Requirement(RequirementKind::Conformance, type,
protocol->getDeclaredInterfaceType()),
SourceLoc(), /*inferred=*/true});
};

auto &ctx = module->getASTContext();
Expand All @@ -610,8 +622,9 @@ struct InferRequirementsWalker : public TypeWalker {
auto secondType = assocType->getDeclaredInterfaceType()
->castTo<DependentMemberType>()
->substBaseType(module, firstType);
Requirement req(RequirementKind::SameType, firstType, secondType);
desugarRequirement(req, SourceLoc(), reqs, errors);
reqs.push_back({Requirement(RequirementKind::SameType,
firstType, secondType),
SourceLoc(), /*inferred=*/true});
};
auto *tangentVectorAssocType =
differentiableProtocol->getAssociatedType(ctx.Id_TangentVector);
Expand Down Expand Up @@ -659,8 +672,7 @@ struct InferRequirementsWalker : public TypeWalker {
if (skipRequirement(rawReq, decl))
continue;

auto req = rawReq.subst(subMap);
desugarRequirement(req, SourceLoc(), reqs, errors);
reqs.push_back({rawReq.subst(subMap), SourceLoc(), /*inferred=*/true});
}

return Action::Continue;
Expand All @@ -683,15 +695,12 @@ void swift::rewriting::inferRequirements(
if (!type)
return;

InferRequirementsWalker walker(module, dc);
InferRequirementsWalker walker(module, dc, result);
type.walk(walker);

for (const auto &req : walker.reqs)
result.push_back({req, loc, /*wasInferred=*/true});
}

/// Desugar a requirement and perform requirement inference if requested
/// to obtain zero or more structural requirements.
/// Perform requirement inference from the type representations in the
/// requirement itself (eg, `T == Set<U>` infers `U: Hashable`).
void swift::rewriting::realizeRequirement(
DeclContext *dc,
Requirement req, RequirementRepr *reqRepr,
Expand Down Expand Up @@ -732,12 +741,7 @@ void swift::rewriting::realizeRequirement(
inferRequirements(firstType, firstLoc, moduleForInference, dc, result);
}

SmallVector<Requirement, 2> reqs;
desugarRequirement(req, loc, reqs, errors);

for (auto req : reqs)
result.push_back({req, loc, /*wasInferred=*/false});

result.push_back({req, loc, /*wasInferred=*/false});
break;
}

Expand All @@ -754,11 +758,7 @@ void swift::rewriting::realizeRequirement(
inferRequirements(secondType, secondLoc, moduleForInference, dc, result);
}

SmallVector<Requirement, 2> reqs;
desugarRequirement(req, loc, reqs, errors);

for (auto req : reqs)
result.push_back({req, loc, /*wasInferred=*/false});
result.push_back({req, loc, /*wasInferred=*/false});
break;
}
}
Expand Down Expand Up @@ -903,13 +903,13 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
ProtocolDecl *proto) const {
assert(!proto->hasLazyRequirementSignature());

SmallVector<StructuralRequirement, 4> result;
SmallVector<RequirementError, 4> errors;
SmallVector<StructuralRequirement, 2> result;
SmallVector<RequirementError, 2> errors;

auto &ctx = proto->getASTContext();
auto selfTy = proto->getSelfInterfaceType();

SmallVector<Type, 4> needsDefaultReqirements({selfTy});
SmallVector<Type, 4> needsDefaultRequirements({selfTy});

unsigned errorCount = errors.size();
realizeInheritedRequirements(proto, selfTy,
Expand Down Expand Up @@ -950,7 +950,12 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
result.push_back({Requirement(RequirementKind::Layout, selfTy, layout),
proto->getLoc(), /*inferred=*/true});

expandDefaultRequirements(ctx, needsDefaultReqirements, result, errors);
desugarRequirements(result, errors);
expandDefaultRequirements(ctx, needsDefaultRequirements, result, errors);

diagnoseRequirementErrors(ctx, errors,
AllowConcreteTypePolicy::NestedAssocTypes);

return ctx.AllocateCopy(result);
}

Expand All @@ -976,7 +981,7 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
return false;
});

needsDefaultReqirements.push_back(assocType);
needsDefaultRequirements.push_back(assocType);
}

// Add requirements for each typealias.
Expand Down Expand Up @@ -1014,7 +1019,8 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
}
}

expandDefaultRequirements(ctx, needsDefaultReqirements, result, errors);
desugarRequirements(result, errors);
expandDefaultRequirements(ctx, needsDefaultRequirements, result, errors);

diagnoseRequirementErrors(ctx, errors,
AllowConcreteTypePolicy::NestedAssocTypes);
Expand Down
3 changes: 3 additions & 0 deletions lib/AST/RequirementMachine/RequirementLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ namespace rewriting {
// documentation
// comments.

void desugarRequirements(SmallVector<StructuralRequirement, 2> &result,
SmallVectorImpl<RequirementError> &errors);

void desugarRequirement(Requirement req, SourceLoc loc,
SmallVectorImpl<Requirement> &result,
SmallVectorImpl<RequirementError> &errors);
Expand Down
32 changes: 13 additions & 19 deletions lib/AST/RequirementMachine/RequirementMachineRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,14 +632,13 @@ AbstractGenericSignatureRequest::evaluate(

// Convert the input Requirements into StructuralRequirements by adding
// empty source locations.
SmallVector<StructuralRequirement, 4> requirements;
SmallVector<StructuralRequirement, 2> requirements;
for (auto req : baseSignature.getRequirements())
requirements.push_back({req, SourceLoc(), /*wasInferred=*/false});

// We need to create this errors vector to pass to
// desugarRequirement, but this request should never
// diagnose errors.
SmallVector<RequirementError, 4> errors;
// Add the new requirements.
for (auto req : addedRequirements)
requirements.push_back({req, SourceLoc(), /*wasInferred=*/false});

// The requirements passed to this request may have been substituted,
// meaning the subject type might be a concrete type and not a type
Expand All @@ -651,12 +650,8 @@ AbstractGenericSignatureRequest::evaluate(
// Desugaring converts these kinds of requirements into "proper"
// requirements where the subject type is always a type parameter,
// which is what the RuleBuilder expects.
for (auto req : addedRequirements) {
SmallVector<Requirement, 2> reqs;
desugarRequirement(req, SourceLoc(), reqs, errors);
for (auto req : reqs)
requirements.push_back({req, SourceLoc(), /*wasInferred=*/false});
}
SmallVector<RequirementError, 2> errors;
desugarRequirements(requirements, errors);

auto &rewriteCtx = ctx.getRewriteContext();

Expand Down Expand Up @@ -747,8 +742,8 @@ InferredGenericSignatureRequest::evaluate(
parentSig.getGenericParams().begin(),
parentSig.getGenericParams().end());

SmallVector<StructuralRequirement, 4> requirements;
SmallVector<RequirementError, 4> errors;
SmallVector<StructuralRequirement, 2> requirements;
SmallVector<RequirementError, 2> errors;

SourceLoc loc = [&]() {
if (genericParamList) {
Expand Down Expand Up @@ -844,9 +839,6 @@ InferredGenericSignatureRequest::evaluate(
for (auto *gtpd : genericParamList->getParams())
localGPs.push_back(gtpd->getDeclaredInterfaceType());

// Expand defaults and eliminate all inverse-conformance requirements.
expandDefaultRequirements(ctx, localGPs, requirements, errors);

// Perform requirement inference from function parameter and result
// types and such.
for (auto sourcePair : inferenceSources) {
Expand All @@ -860,12 +852,14 @@ InferredGenericSignatureRequest::evaluate(
// Finish by adding any remaining requirements. This is used to introduce
// inferred same-type requirements when building the generic signature of
// an extension whose extended type is a generic typealias.
SmallVector<Requirement, 4> rawAddedRequirements;
for (const auto &req : addedRequirements)
desugarRequirement(req, SourceLoc(), rawAddedRequirements, errors);
for (const auto &req : rawAddedRequirements)
requirements.push_back({req, SourceLoc(), /*inferred=*/true});

desugarRequirements(requirements, errors);

// Expand defaults and eliminate all inverse-conformance requirements.
expandDefaultRequirements(ctx, localGPs, requirements, errors);

// Re-order requirements so that inferred requirements appear last. This
// ensures that if an inferred requirement is redundant with some other
// requirement, it is the inferred requirement that becomes redundant,
Expand Down
3 changes: 1 addition & 2 deletions lib/AST/RequirementMachine/RewriteSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,6 @@ void RewriteSystem::dump(llvm::raw_ostream &out) const {
}
if (!WrittenRequirements.empty()) {
out << "Written requirements: {\n";

for (unsigned reqID : indices(WrittenRequirements)) {
out << " - ID: " << reqID << " - ";
const auto &requirement = WrittenRequirements[reqID];
Expand All @@ -725,6 +724,6 @@ void RewriteSystem::dump(llvm::raw_ostream &out) const {
requirement.loc.print(out, Context.getASTContext().SourceMgr);
out << "\n";
}
out << "}\n";
}
out << "}\n";
}
Loading