Skip to content

Commit 5c67cff

Browse files
committed
Prevent infinite recursion with conformance enumeration.
1 parent 064e0c6 commit 5c67cff

File tree

2 files changed

+62
-21
lines changed

2 files changed

+62
-21
lines changed

lib/Sema/TypeCheckProtocol.cpp

+52-17
Original file line numberDiff line numberDiff line change
@@ -7163,17 +7163,22 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) {
71637163

71647164
bool swift::forEachConformance(
71657165
SubstitutionMap subs,
7166-
llvm::function_ref<bool(ProtocolConformanceRef)> body) {
7166+
llvm::function_ref<bool(ProtocolConformanceRef)> body,
7167+
VisitedConformances *visitedConformances) {
71677168
if (!subs)
71687169
return false;
71697170

7171+
VisitedConformances visited;
7172+
if (!visitedConformances)
7173+
visitedConformances = &visited;
7174+
71707175
for (auto type: subs.getReplacementTypes()) {
7171-
if (forEachConformance(type, body))
7176+
if (forEachConformance(type, body, visitedConformances))
71727177
return true;
71737178
}
71747179

71757180
for (auto conformance: subs.getConformances()) {
7176-
if (forEachConformance(conformance, body))
7181+
if (forEachConformance(conformance, body, visitedConformances))
71777182
return true;
71787183
}
71797184

@@ -7182,46 +7187,70 @@ bool swift::forEachConformance(
71827187

71837188
bool swift::forEachConformance(
71847189
ProtocolConformanceRef conformance,
7185-
llvm::function_ref<bool(ProtocolConformanceRef)> body) {
7186-
// Visit this conformance.
7187-
if (body(conformance))
7188-
return true;
7190+
llvm::function_ref<bool(ProtocolConformanceRef)> body,
7191+
VisitedConformances *visitedConformances) {
7192+
// Make sure we can store visited conformances.
7193+
VisitedConformances visited;
7194+
if (!visitedConformances)
7195+
visitedConformances = &visited;
71897196

71907197
if (conformance.isInvalid() || conformance.isAbstract())
71917198
return false;
71927199

71937200
if (conformance.isPack()) {
71947201
auto pack = conformance.getPack()->getPatternConformances();
71957202
for (auto conformance : pack) {
7196-
if (forEachConformance(conformance, body))
7203+
if (forEachConformance(conformance, body, visitedConformances))
71977204
return true;
71987205
}
71997206

72007207
return false;
72017208
}
72027209

7203-
// Check the substitution make within this conformance.
7210+
// Extract the concrete conformance.
72047211
auto concrete = conformance.getConcrete();
7205-
if (forEachConformance(concrete->getSubstitutionMap(), body))
7212+
7213+
// Prevent recursion.
7214+
if (!visitedConformances->insert(concrete).second)
7215+
return false;
7216+
7217+
// Visit this conformance.
7218+
if (body(conformance))
72067219
return true;
72077220

7221+
// Check the substitution map within this conformance.
7222+
if (forEachConformance(concrete->getSubstitutionMap(), body,
7223+
visitedConformances))
7224+
return true;
72087225

72097226
return false;
72107227
}
72117228

72127229
bool swift::forEachConformance(
7213-
Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body) {
7230+
Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body,
7231+
VisitedConformances *visitedConformances) {
7232+
// Make sure we can store visited conformances.
7233+
VisitedConformances visited;
7234+
if (!visitedConformances)
7235+
visitedConformances = &visited;
7236+
7237+
// Prevent recursion.
7238+
if (!visitedConformances->insert(type.getPointer()).second)
7239+
return false;
7240+
72147241
return type.findIf([&](Type type) {
72157242
if (auto typeAlias = dyn_cast<TypeAliasType>(type.getPointer())) {
7216-
if (forEachConformance(typeAlias->getSubstitutionMap(), body))
7243+
if (forEachConformance(typeAlias->getSubstitutionMap(), body,
7244+
visitedConformances))
72177245
return true;
72187246

72197247
return false;
72207248
}
72217249

72227250
if (auto opaqueArchetype =
72237251
dyn_cast<OpaqueTypeArchetypeType>(type.getPointer())) {
7224-
if (forEachConformance(opaqueArchetype->getSubstitutions(), body))
7252+
if (forEachConformance(opaqueArchetype->getSubstitutions(), body,
7253+
visitedConformances))
72257254
return true;
72267255

72277256
return false;
@@ -7234,7 +7263,7 @@ bool swift::forEachConformance(
72347263

72357264
if (auto boundGeneric = dyn_cast<BoundGenericType>(type.getPointer())) {
72367265
auto subs = boundGeneric->getContextSubstitutionMap();
7237-
if (forEachConformance(subs, body))
7266+
if (forEachConformance(subs, body, visitedConformances))
72387267
return true;
72397268

72407269
return false;
@@ -7246,19 +7275,25 @@ bool swift::forEachConformance(
72467275

72477276
bool swift::forEachConformance(
72487277
ConcreteDeclRef declRef,
7249-
llvm::function_ref<bool(ProtocolConformanceRef)> body) {
7278+
llvm::function_ref<bool(ProtocolConformanceRef)> body,
7279+
VisitedConformances *visitedConformances) {
72507280
if (!declRef)
72517281
return false;
72527282

7283+
// Make sure we can store visited conformances.
7284+
VisitedConformances visited;
7285+
if (!visitedConformances)
7286+
visitedConformances = &visited;
7287+
72537288
Type type = declRef.getDecl()->getInterfaceType();
72547289
if (auto subs = declRef.getSubstitutions()) {
7255-
if (forEachConformance(subs, body))
7290+
if (forEachConformance(subs, body, visitedConformances))
72567291
return true;
72577292

72587293
type = type.subst(subs);
72597294
}
72607295

7261-
if (forEachConformance(type, body))
7296+
if (forEachConformance(type, body, visitedConformances))
72627297
return true;
72637298

72647299
return false;

lib/Sema/TypeCheckProtocol.h

+10-4
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,15 @@ bool witnessHasImplementsAttrForRequiredName(ValueDecl *witness,
240240
bool witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
241241
ValueDecl *requirement);
242242

243+
using VisitedConformances = llvm::SmallPtrSet<void *, 16>;
244+
243245
/// Visit each conformance within the given type.
244246
///
245247
/// If `body` returns true for any conformance, this function stops the
246248
/// traversal and returns true.
247249
bool forEachConformance(
248-
Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body);
250+
Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body,
251+
VisitedConformances *visitedConformances = nullptr);
249252

250253
/// Visit each conformance within the given conformance (including the given
251254
/// one).
@@ -254,23 +257,26 @@ bool forEachConformance(
254257
/// traversal and returns true.
255258
bool forEachConformance(
256259
ProtocolConformanceRef conformance,
257-
llvm::function_ref<bool(ProtocolConformanceRef)> body);
260+
llvm::function_ref<bool(ProtocolConformanceRef)> body,
261+
VisitedConformances *visitedConformances = nullptr);
258262

259263
/// Visit each conformance within the given substitution map.
260264
///
261265
/// If `body` returns true for any conformance, this function stops the
262266
/// traversal and returns true.
263267
bool forEachConformance(
264268
SubstitutionMap subs,
265-
llvm::function_ref<bool(ProtocolConformanceRef)> body);
269+
llvm::function_ref<bool(ProtocolConformanceRef)> body,
270+
VisitedConformances *visitedConformances = nullptr);
266271

267272
/// Visit each conformance within the given declaration reference.
268273
///
269274
/// If `body` returns true for any conformance, this function stops the
270275
/// traversal and returns true.
271276
bool forEachConformance(
272277
ConcreteDeclRef declRef,
273-
llvm::function_ref<bool(ProtocolConformanceRef)> body);
278+
llvm::function_ref<bool(ProtocolConformanceRef)> body,
279+
VisitedConformances *visitedConformances = nullptr);
274280

275281
}
276282

0 commit comments

Comments
 (0)