Skip to content

Remove some unnecessary generalization in exportability checking #29587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 39 additions & 77 deletions lib/Sema/TypeCheckAccess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1445,15 +1445,11 @@ class UsableFromInlineChecker : public AccessControlCheckerBase,
};

class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
using CheckExportabilityTypeCallback =
llvm::function_ref<void(const TypeDecl *, const TypeRepr *)>;
using CheckExportabilityConformanceCallback =
llvm::function_ref<void(const ProtocolConformance *)>;
class Diagnoser;

void checkTypeImpl(
Type type, const TypeRepr *typeRepr, const SourceFile &SF,
CheckExportabilityTypeCallback diagnoseType,
CheckExportabilityConformanceCallback diagnoseConformance) {
const Diagnoser &diagnoser) {
// Don't bother checking errors.
if (type && type->hasError())
return;
Expand All @@ -1469,7 +1465,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
if (!SF.isImportedImplementationOnly(M))
return true;

diagnoseType(component->getBoundDecl(), component);
diagnoser.diagnoseType(component->getBoundDecl(), component);
foundAnyIssues = true;
// We still continue even in the diagnostic case to report multiple
// violations.
Expand All @@ -1488,22 +1484,17 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {

class ProblematicTypeFinder : public TypeDeclFinder {
const SourceFile &SF;
CheckExportabilityTypeCallback diagnoseType;
CheckExportabilityConformanceCallback diagnoseConformance;
const Diagnoser &diagnoser;
public:
ProblematicTypeFinder(
const SourceFile &SF,
CheckExportabilityTypeCallback diagnoseType,
CheckExportabilityConformanceCallback diagnoseConformance)
: SF(SF), diagnoseType(diagnoseType),
diagnoseConformance(diagnoseConformance) {}
ProblematicTypeFinder(const SourceFile &SF, const Diagnoser &diagnoser)
: SF(SF), diagnoser(diagnoser) {}

void visitTypeDecl(const TypeDecl *typeDecl) {
ModuleDecl *M = typeDecl->getModuleContext();
if (!SF.isImportedImplementationOnly(M))
return;

diagnoseType(typeDecl, /*typeRepr*/nullptr);
diagnoser.diagnoseType(typeDecl, /*typeRepr*/nullptr);
}

void visitSubstitutionMap(SubstitutionMap subs) {
Expand All @@ -1521,7 +1512,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
ModuleDecl *M = rootConf->getDeclContext()->getParentModule();
if (!SF.isImportedImplementationOnly(M))
continue;
diagnoseConformance(rootConf);
diagnoser.diagnoseConformance(rootConf);
}
}

Expand All @@ -1545,25 +1536,20 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
}
};

type.walk(ProblematicTypeFinder(SF, diagnoseType, diagnoseConformance));
type.walk(ProblematicTypeFinder(SF, diagnoser));
}

void checkType(
Type type, const TypeRepr *typeRepr, const Decl *context,
CheckExportabilityTypeCallback diagnoseType,
CheckExportabilityConformanceCallback diagnoseConformance) {
const Diagnoser &diagnoser) {
auto *SF = context->getDeclContext()->getParentSourceFile();
assert(SF && "checking a non-source declaration?");
return checkTypeImpl(type, typeRepr, *SF, diagnoseType,
diagnoseConformance);
return checkTypeImpl(type, typeRepr, *SF, diagnoser);
}

void checkType(
const TypeLoc &TL, const Decl *context,
CheckExportabilityTypeCallback diagnoseType,
CheckExportabilityConformanceCallback diagnoseConformance) {
checkType(TL.getType(), TL.getTypeRepr(), context, diagnoseType,
diagnoseConformance);
const TypeLoc &TL, const Decl *context, const Diagnoser &diagnoser) {
checkType(TL.getType(), TL.getTypeRepr(), context, diagnoser);
}

void checkGenericParams(const GenericContext *ownerCtx,
Expand All @@ -1577,15 +1563,13 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
continue;
assert(param->getInherited().size() == 1);
checkType(param->getInherited().front(), ownerDecl,
getDiagnoseCallback(ownerDecl),
getDiagnoseCallback(ownerDecl));
getDiagnoser(ownerDecl));
}

forAllRequirementTypes(WhereClauseOwner(
const_cast<GenericContext *>(ownerCtx)),
[&](Type type, TypeRepr *typeRepr) {
checkType(type, typeRepr, ownerDecl, getDiagnoseCallback(ownerDecl),
getDiagnoseCallback(ownerDecl));
checkType(type, typeRepr, ownerDecl, getDiagnoser(ownerDecl));
});
}

Expand All @@ -1598,14 +1582,14 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
ExtensionWithConditionalConformances
};

class DiagnoseGenerically {
class Diagnoser {
const Decl *D;
Reason reason;
public:
DiagnoseGenerically(const Decl *D, Reason reason) : D(D), reason(reason) {}
Diagnoser(const Decl *D, Reason reason) : D(D), reason(reason) {}

void operator()(const TypeDecl *offendingType,
const TypeRepr *complainRepr) {
void diagnoseType(const TypeDecl *offendingType,
const TypeRepr *complainRepr) const {
ModuleDecl *M = offendingType->getModuleContext();
auto diag = D->diagnose(diag::decl_from_implementation_only_module,
offendingType->getDescriptiveKind(),
Expand All @@ -1614,7 +1598,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
highlightOffendingType(diag, complainRepr);
}

void operator()(const ProtocolConformance *offendingConformance) {
void diagnoseConformance(const ProtocolConformance *offendingConformance) const {
ModuleDecl *M = offendingConformance->getDeclContext()->getParentModule();
D->diagnose(diag::conformance_from_implementation_only_module,
offendingConformance->getType(),
Expand All @@ -1623,18 +1607,8 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
}
};

static_assert(
std::is_convertible<DiagnoseGenerically,
CheckExportabilityTypeCallback>::value,
"DiagnoseGenerically has wrong call signature");
static_assert(
std::is_convertible<DiagnoseGenerically,
CheckExportabilityConformanceCallback>::value,
"DiagnoseGenerically has wrong call signature for conformance diags");

DiagnoseGenerically getDiagnoseCallback(const Decl *D,
Reason reason = Reason::General) {
return DiagnoseGenerically(D, reason);
Diagnoser getDiagnoser(const Decl *D, Reason reason = Reason::General) {
return Diagnoser(D, reason);
}

public:
Expand Down Expand Up @@ -1768,7 +1742,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
return;

checkType(theVar->getInterfaceType(), /*typeRepr*/nullptr, theVar,
getDiagnoseCallback(theVar), getDiagnoseCallback(theVar));
getDiagnoser(theVar));
}

/// \see visitPatternBindingDecl
Expand All @@ -1787,8 +1761,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
if (shouldSkipChecking(anyVar))
return;

checkType(TP->getTypeLoc(), anyVar, getDiagnoseCallback(anyVar),
getDiagnoseCallback(anyVar));
checkType(TP->getTypeLoc(), anyVar, getDiagnoser(anyVar));
}

void visitPatternBindingDecl(PatternBindingDecl *PBD) {
Expand All @@ -1812,25 +1785,22 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
void visitTypeAliasDecl(TypeAliasDecl *TAD) {
checkGenericParams(TAD, TAD);
checkType(TAD->getUnderlyingType(),
TAD->getUnderlyingTypeRepr(), TAD, getDiagnoseCallback(TAD),
getDiagnoseCallback(TAD));
TAD->getUnderlyingTypeRepr(), TAD, getDiagnoser(TAD));
}

void visitAssociatedTypeDecl(AssociatedTypeDecl *assocType) {
llvm::for_each(assocType->getInherited(),
[&](TypeLoc requirement) {
checkType(requirement, assocType, getDiagnoseCallback(assocType),
getDiagnoseCallback(assocType));
checkType(requirement, assocType, getDiagnoser(assocType));
});
checkType(assocType->getDefaultDefinitionType(),
assocType->getDefaultDefinitionTypeRepr(), assocType,
getDiagnoseCallback(assocType), getDiagnoseCallback(assocType));
getDiagnoser(assocType));

if (assocType->getTrailingWhereClause()) {
forAllRequirementTypes(assocType,
[&](Type type, TypeRepr *typeRepr) {
checkType(type, typeRepr, assocType, getDiagnoseCallback(assocType),
getDiagnoseCallback(assocType));
checkType(type, typeRepr, assocType, getDiagnoser(assocType));
});
}
}
Expand All @@ -1840,22 +1810,19 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {

llvm::for_each(nominal->getInherited(),
[&](TypeLoc nextInherited) {
checkType(nextInherited, nominal, getDiagnoseCallback(nominal),
getDiagnoseCallback(nominal));
checkType(nextInherited, nominal, getDiagnoser(nominal));
});
}

void visitProtocolDecl(ProtocolDecl *proto) {
llvm::for_each(proto->getInherited(),
[&](TypeLoc requirement) {
checkType(requirement, proto, getDiagnoseCallback(proto),
getDiagnoseCallback(proto));
checkType(requirement, proto, getDiagnoser(proto));
});

if (proto->getTrailingWhereClause()) {
forAllRequirementTypes(proto, [&](Type type, TypeRepr *typeRepr) {
checkType(type, typeRepr, proto, getDiagnoseCallback(proto),
getDiagnoseCallback(proto));
checkType(type, typeRepr, proto, getDiagnoser(proto));
});
}
}
Expand All @@ -1865,41 +1832,38 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {

for (auto &P : *SD->getIndices()) {
checkType(P->getInterfaceType(), P->getTypeRepr(), SD,
getDiagnoseCallback(SD), getDiagnoseCallback(SD));
getDiagnoser(SD));
}
checkType(SD->getElementTypeLoc(), SD, getDiagnoseCallback(SD),
getDiagnoseCallback(SD));
checkType(SD->getElementTypeLoc(), SD, getDiagnoser(SD));
}

void visitAbstractFunctionDecl(AbstractFunctionDecl *fn) {
checkGenericParams(fn, fn);

for (auto *P : *fn->getParameters())
checkType(P->getInterfaceType(), P->getTypeRepr(), fn,
getDiagnoseCallback(fn), getDiagnoseCallback(fn));
getDiagnoser(fn));
}

void visitFuncDecl(FuncDecl *FD) {
visitAbstractFunctionDecl(FD);
checkType(FD->getBodyResultTypeLoc(), FD, getDiagnoseCallback(FD),
getDiagnoseCallback(FD));
checkType(FD->getBodyResultTypeLoc(), FD, getDiagnoser(FD));
}

void visitEnumElementDecl(EnumElementDecl *EED) {
if (!EED->hasAssociatedValues())
return;
for (auto &P : *EED->getParameterList())
checkType(P->getInterfaceType(), P->getTypeRepr(), EED,
getDiagnoseCallback(EED), getDiagnoseCallback(EED));
getDiagnoser(EED));
}

void checkConstrainedExtensionRequirements(ExtensionDecl *ED,
Reason reason) {
if (!ED->getTrailingWhereClause())
return;
forAllRequirementTypes(ED, [&](Type type, TypeRepr *typeRepr) {
checkType(type, typeRepr, ED, getDiagnoseCallback(ED, reason),
getDiagnoseCallback(ED, reason));
checkType(type, typeRepr, ED, getDiagnoser(ED, reason));
});
}

Expand All @@ -1913,8 +1877,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
// but just hide that from interfaces.
llvm::for_each(ED->getInherited(),
[&](TypeLoc nextInherited) {
checkType(nextInherited, ED, getDiagnoseCallback(ED),
getDiagnoseCallback(ED));
checkType(nextInherited, ED, getDiagnoser(ED));
});

bool hasPublicMembers = llvm::any_of(ED->getMembers(),
Expand All @@ -1927,8 +1890,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {

if (hasPublicMembers) {
checkType(ED->getExtendedType(), ED->getExtendedTypeRepr(), ED,
getDiagnoseCallback(ED, Reason::ExtensionWithPublicMembers),
getDiagnoseCallback(ED, Reason::ExtensionWithPublicMembers));
getDiagnoser(ED, Reason::ExtensionWithPublicMembers));
}

if (hasPublicMembers || !ED->getInherited().empty()) {
Expand Down