Skip to content

Commit eb2c4e1

Browse files
Fix memoryleak in filters.
1 parent 427f4bc commit eb2c4e1

File tree

3 files changed

+82
-61
lines changed

3 files changed

+82
-61
lines changed

Diff for: src/prompt_toolkit/filters/app.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@
4747
]
4848

4949

50-
@memoized()
50+
# NOTE: `has_focus` below should *not* be `memoized`. It can reference any user
51+
# control. For instance, if we would contiously create new
52+
# `PromptSession` instances, then previous instances won't be released,
53+
# because this memoize (which caches results in the global scope) will
54+
# still refer to each instance.
5155
def has_focus(value: "FocusableElement") -> Condition:
5256
"""
5357
Enable when this buffer has the focus.

Diff for: src/prompt_toolkit/filters/base.py

+51-60
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import weakref
12
from abc import ABCMeta, abstractmethod
2-
from typing import Callable, Dict, Iterable, List, Tuple, Union
3+
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
34

45
__all__ = ["Filter", "Never", "Always", "Condition", "FilterOrBool"]
56

@@ -12,6 +13,15 @@ class Filter(metaclass=ABCMeta):
1213
The return value of ``__call__`` will tell if the feature should be active.
1314
"""
1415

16+
def __init__(self) -> None:
17+
self._and_cache: "weakref.WeakValueDictionary[Filter, _AndList]" = (
18+
weakref.WeakValueDictionary()
19+
)
20+
self._or_cache: "weakref.WeakValueDictionary[Filter, _OrList]" = (
21+
weakref.WeakValueDictionary()
22+
)
23+
self._invert_result: Optional[Filter] = None
24+
1525
@abstractmethod
1626
def __call__(self) -> bool:
1727
"""
@@ -23,19 +33,46 @@ def __and__(self, other: "Filter") -> "Filter":
2333
"""
2434
Chaining of filters using the & operator.
2535
"""
26-
return _and_cache[self, other]
36+
assert isinstance(other, Filter), "Expecting filter, got %r" % other
37+
38+
if isinstance(other, Always):
39+
return self
40+
if isinstance(other, Never):
41+
return other
42+
43+
if other in self._and_cache:
44+
return self._and_cache[other]
45+
46+
result = _AndList([self, other])
47+
self._and_cache[other] = result
48+
return result
2749

2850
def __or__(self, other: "Filter") -> "Filter":
2951
"""
3052
Chaining of filters using the | operator.
3153
"""
32-
return _or_cache[self, other]
54+
assert isinstance(other, Filter), "Expecting filter, got %r" % other
55+
56+
if isinstance(other, Always):
57+
return other
58+
if isinstance(other, Never):
59+
return self
60+
61+
if other in self._or_cache:
62+
return self._or_cache[other]
63+
64+
result = _OrList([self, other])
65+
self._or_cache[other] = result
66+
return result
3367

3468
def __invert__(self) -> "Filter":
3569
"""
3670
Inverting of filters using the ~ operator.
3771
"""
38-
return _invert_cache[self]
72+
if self._invert_result is None:
73+
self._invert_result = _Invert(self)
74+
75+
return self._invert_result
3976

4077
def __bool__(self) -> None:
4178
"""
@@ -52,68 +89,13 @@ def __bool__(self) -> None:
5289
)
5390

5491

55-
class _AndCache(Dict[Tuple[Filter, Filter], "_AndList"]):
56-
"""
57-
Cache for And operation between filters.
58-
(Filter classes are stateless, so we can reuse them.)
59-
60-
Note: This could be a memory leak if we keep creating filters at runtime.
61-
If that is True, the filters should be weakreffed (not the tuple of
62-
filters), and tuples should be removed when one of these filters is
63-
removed. In practise however, there is a finite amount of filters.
64-
"""
65-
66-
def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter:
67-
a, b = filters
68-
assert isinstance(b, Filter), "Expecting filter, got %r" % b
69-
70-
if isinstance(b, Always) or isinstance(a, Never):
71-
return a
72-
elif isinstance(b, Never) or isinstance(a, Always):
73-
return b
74-
75-
result = _AndList(filters)
76-
self[filters] = result
77-
return result
78-
79-
80-
class _OrCache(Dict[Tuple[Filter, Filter], "_OrList"]):
81-
"""Cache for Or operation between filters."""
82-
83-
def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter:
84-
a, b = filters
85-
assert isinstance(b, Filter), "Expecting filter, got %r" % b
86-
87-
if isinstance(b, Always) or isinstance(a, Never):
88-
return b
89-
elif isinstance(b, Never) or isinstance(a, Always):
90-
return a
91-
92-
result = _OrList(filters)
93-
self[filters] = result
94-
return result
95-
96-
97-
class _InvertCache(Dict[Filter, "_Invert"]):
98-
"""Cache for inversion operator."""
99-
100-
def __missing__(self, filter: Filter) -> Filter:
101-
result = _Invert(filter)
102-
self[filter] = result
103-
return result
104-
105-
106-
_and_cache = _AndCache()
107-
_or_cache = _OrCache()
108-
_invert_cache = _InvertCache()
109-
110-
11192
class _AndList(Filter):
11293
"""
11394
Result of &-operation between several filters.
11495
"""
11596

11697
def __init__(self, filters: Iterable[Filter]) -> None:
98+
super().__init__()
11799
self.filters: List[Filter] = []
118100

119101
for f in filters:
@@ -135,6 +117,7 @@ class _OrList(Filter):
135117
"""
136118

137119
def __init__(self, filters: Iterable[Filter]) -> None:
120+
super().__init__()
138121
self.filters: List[Filter] = []
139122

140123
for f in filters:
@@ -156,6 +139,7 @@ class _Invert(Filter):
156139
"""
157140

158141
def __init__(self, filter: Filter) -> None:
142+
super().__init__()
159143
self.filter = filter
160144

161145
def __call__(self) -> bool:
@@ -173,6 +157,9 @@ class Always(Filter):
173157
def __call__(self) -> bool:
174158
return True
175159

160+
def __or__(self, other: "Filter") -> "Filter":
161+
return self
162+
176163
def __invert__(self) -> "Never":
177164
return Never()
178165

@@ -185,6 +172,9 @@ class Never(Filter):
185172
def __call__(self) -> bool:
186173
return False
187174

175+
def __and__(self, other: "Filter") -> "Filter":
176+
return self
177+
188178
def __invert__(self) -> Always:
189179
return Always()
190180

@@ -204,6 +194,7 @@ def feature_is_active(): # `feature_is_active` becomes a Filter.
204194
"""
205195

206196
def __init__(self, func: Callable[[], bool]) -> None:
197+
super().__init__()
207198
self.func = func
208199

209200
def __call__(self) -> bool:

Diff for: tests/test_memory_leaks.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import gc
2+
from prompt_toolkit.shortcuts.prompt import PromptSession
3+
4+
5+
def _count_prompt_session_instances() -> int:
6+
# Run full GC collection first.
7+
gc.collect()
8+
9+
# Count number of remaining referenced `PromptSession` instances.
10+
objects = gc.get_objects()
11+
return len([obj for obj in objects if isinstance(obj, PromptSession)])
12+
13+
14+
def test_prompt_session_memory_leak() -> None:
15+
before_count = _count_prompt_session_instances()
16+
assert before_count == 0
17+
18+
p = PromptSession()
19+
20+
after_count = _count_prompt_session_instances()
21+
assert after_count == before_count + 1
22+
23+
del p
24+
25+
after_delete_count = _count_prompt_session_instances()
26+
assert after_delete_count == before_count

0 commit comments

Comments
 (0)