diff --git a/src/prompt_toolkit/filters/app.py b/src/prompt_toolkit/filters/app.py index dcc3fc0c6..a850ec0aa 100644 --- a/src/prompt_toolkit/filters/app.py +++ b/src/prompt_toolkit/filters/app.py @@ -47,7 +47,11 @@ ] -@memoized() +# NOTE: `has_focus` below should *not* be `memoized`. It can reference any user +# control. For instance, if we would contiously create new +# `PromptSession` instances, then previous instances won't be released, +# because this memoize (which caches results in the global scope) will +# still refer to each instance. def has_focus(value: "FocusableElement") -> Condition: """ Enable when this buffer has the focus. diff --git a/src/prompt_toolkit/filters/base.py b/src/prompt_toolkit/filters/base.py index fd57cca6e..db20c0796 100644 --- a/src/prompt_toolkit/filters/base.py +++ b/src/prompt_toolkit/filters/base.py @@ -1,5 +1,6 @@ +import weakref from abc import ABCMeta, abstractmethod -from typing import Callable, Dict, Iterable, List, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union __all__ = ["Filter", "Never", "Always", "Condition", "FilterOrBool"] @@ -12,6 +13,15 @@ class Filter(metaclass=ABCMeta): The return value of ``__call__`` will tell if the feature should be active. """ + def __init__(self) -> None: + self._and_cache: "weakref.WeakValueDictionary[Filter, _AndList]" = ( + weakref.WeakValueDictionary() + ) + self._or_cache: "weakref.WeakValueDictionary[Filter, _OrList]" = ( + weakref.WeakValueDictionary() + ) + self._invert_result: Optional[Filter] = None + @abstractmethod def __call__(self) -> bool: """ @@ -23,19 +33,46 @@ def __and__(self, other: "Filter") -> "Filter": """ Chaining of filters using the & operator. """ - return _and_cache[self, other] + assert isinstance(other, Filter), "Expecting filter, got %r" % other + + if isinstance(other, Always): + return self + if isinstance(other, Never): + return other + + if other in self._and_cache: + return self._and_cache[other] + + result = _AndList([self, other]) + self._and_cache[other] = result + return result def __or__(self, other: "Filter") -> "Filter": """ Chaining of filters using the | operator. """ - return _or_cache[self, other] + assert isinstance(other, Filter), "Expecting filter, got %r" % other + + if isinstance(other, Always): + return other + if isinstance(other, Never): + return self + + if other in self._or_cache: + return self._or_cache[other] + + result = _OrList([self, other]) + self._or_cache[other] = result + return result def __invert__(self) -> "Filter": """ Inverting of filters using the ~ operator. """ - return _invert_cache[self] + if self._invert_result is None: + self._invert_result = _Invert(self) + + return self._invert_result def __bool__(self) -> None: """ @@ -52,68 +89,13 @@ def __bool__(self) -> None: ) -class _AndCache(Dict[Tuple[Filter, Filter], "_AndList"]): - """ - Cache for And operation between filters. - (Filter classes are stateless, so we can reuse them.) - - Note: This could be a memory leak if we keep creating filters at runtime. - If that is True, the filters should be weakreffed (not the tuple of - filters), and tuples should be removed when one of these filters is - removed. In practise however, there is a finite amount of filters. - """ - - def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter: - a, b = filters - assert isinstance(b, Filter), "Expecting filter, got %r" % b - - if isinstance(b, Always) or isinstance(a, Never): - return a - elif isinstance(b, Never) or isinstance(a, Always): - return b - - result = _AndList(filters) - self[filters] = result - return result - - -class _OrCache(Dict[Tuple[Filter, Filter], "_OrList"]): - """Cache for Or operation between filters.""" - - def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter: - a, b = filters - assert isinstance(b, Filter), "Expecting filter, got %r" % b - - if isinstance(b, Always) or isinstance(a, Never): - return b - elif isinstance(b, Never) or isinstance(a, Always): - return a - - result = _OrList(filters) - self[filters] = result - return result - - -class _InvertCache(Dict[Filter, "_Invert"]): - """Cache for inversion operator.""" - - def __missing__(self, filter: Filter) -> Filter: - result = _Invert(filter) - self[filter] = result - return result - - -_and_cache = _AndCache() -_or_cache = _OrCache() -_invert_cache = _InvertCache() - - class _AndList(Filter): """ Result of &-operation between several filters. """ def __init__(self, filters: Iterable[Filter]) -> None: + super().__init__() self.filters: List[Filter] = [] for f in filters: @@ -135,6 +117,7 @@ class _OrList(Filter): """ def __init__(self, filters: Iterable[Filter]) -> None: + super().__init__() self.filters: List[Filter] = [] for f in filters: @@ -156,6 +139,7 @@ class _Invert(Filter): """ def __init__(self, filter: Filter) -> None: + super().__init__() self.filter = filter def __call__(self) -> bool: @@ -173,6 +157,9 @@ class Always(Filter): def __call__(self) -> bool: return True + def __or__(self, other: "Filter") -> "Filter": + return self + def __invert__(self) -> "Never": return Never() @@ -185,6 +172,9 @@ class Never(Filter): def __call__(self) -> bool: return False + def __and__(self, other: "Filter") -> "Filter": + return self + def __invert__(self) -> Always: return Always() @@ -204,6 +194,7 @@ def feature_is_active(): # `feature_is_active` becomes a Filter. """ def __init__(self, func: Callable[[], bool]) -> None: + super().__init__() self.func = func def __call__(self) -> bool: diff --git a/tests/test_memory_leaks.py b/tests/test_memory_leaks.py new file mode 100644 index 000000000..0ad392a4b --- /dev/null +++ b/tests/test_memory_leaks.py @@ -0,0 +1,33 @@ +import gc + +import pytest + +from prompt_toolkit.shortcuts.prompt import PromptSession + + +def _count_prompt_session_instances() -> int: + # Run full GC collection first. + gc.collect() + + # Count number of remaining referenced `PromptSession` instances. + objects = gc.get_objects() + return len([obj for obj in objects if isinstance(obj, PromptSession)]) + + +# Fails in GitHub CI, probably due to GC differences. +@pytest.mark.xfail(reason="Memory leak testing fails in GitHub CI.") +def test_prompt_session_memory_leak() -> None: + before_count = _count_prompt_session_instances() + + # Somehow in CI/CD, the before_count is > 0 + assert before_count == 0 + + p = PromptSession() + + after_count = _count_prompt_session_instances() + assert after_count == before_count + 1 + + del p + + after_delete_count = _count_prompt_session_instances() + assert after_delete_count == before_count