Skip to content

Fix memoryleak in filters. #1690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/prompt_toolkit/filters/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
]


@memoized()
# NOTE: `has_focus` below should *not* be `memoized`. It can reference any user
# control. For instance, if we would contiously create new
# `PromptSession` instances, then previous instances won't be released,
# because this memoize (which caches results in the global scope) will
# still refer to each instance.
def has_focus(value: "FocusableElement") -> Condition:
"""
Enable when this buffer has the focus.
Expand Down
111 changes: 51 additions & 60 deletions src/prompt_toolkit/filters/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import weakref
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, Iterable, List, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

__all__ = ["Filter", "Never", "Always", "Condition", "FilterOrBool"]

Expand All @@ -12,6 +13,15 @@ class Filter(metaclass=ABCMeta):
The return value of ``__call__`` will tell if the feature should be active.
"""

def __init__(self) -> None:
self._and_cache: "weakref.WeakValueDictionary[Filter, _AndList]" = (
weakref.WeakValueDictionary()
)
self._or_cache: "weakref.WeakValueDictionary[Filter, _OrList]" = (
weakref.WeakValueDictionary()
)
self._invert_result: Optional[Filter] = None

@abstractmethod
def __call__(self) -> bool:
"""
Expand All @@ -23,19 +33,46 @@ def __and__(self, other: "Filter") -> "Filter":
"""
Chaining of filters using the & operator.
"""
return _and_cache[self, other]
assert isinstance(other, Filter), "Expecting filter, got %r" % other

if isinstance(other, Always):
return self
if isinstance(other, Never):
return other

if other in self._and_cache:
return self._and_cache[other]

result = _AndList([self, other])
self._and_cache[other] = result
return result

def __or__(self, other: "Filter") -> "Filter":
"""
Chaining of filters using the | operator.
"""
return _or_cache[self, other]
assert isinstance(other, Filter), "Expecting filter, got %r" % other

if isinstance(other, Always):
return other
if isinstance(other, Never):
return self

if other in self._or_cache:
return self._or_cache[other]

result = _OrList([self, other])
self._or_cache[other] = result
return result

def __invert__(self) -> "Filter":
"""
Inverting of filters using the ~ operator.
"""
return _invert_cache[self]
if self._invert_result is None:
self._invert_result = _Invert(self)

return self._invert_result

def __bool__(self) -> None:
"""
Expand All @@ -52,68 +89,13 @@ def __bool__(self) -> None:
)


class _AndCache(Dict[Tuple[Filter, Filter], "_AndList"]):
"""
Cache for And operation between filters.
(Filter classes are stateless, so we can reuse them.)
Note: This could be a memory leak if we keep creating filters at runtime.
If that is True, the filters should be weakreffed (not the tuple of
filters), and tuples should be removed when one of these filters is
removed. In practise however, there is a finite amount of filters.
"""

def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter:
a, b = filters
assert isinstance(b, Filter), "Expecting filter, got %r" % b

if isinstance(b, Always) or isinstance(a, Never):
return a
elif isinstance(b, Never) or isinstance(a, Always):
return b

result = _AndList(filters)
self[filters] = result
return result


class _OrCache(Dict[Tuple[Filter, Filter], "_OrList"]):
"""Cache for Or operation between filters."""

def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter:
a, b = filters
assert isinstance(b, Filter), "Expecting filter, got %r" % b

if isinstance(b, Always) or isinstance(a, Never):
return b
elif isinstance(b, Never) or isinstance(a, Always):
return a

result = _OrList(filters)
self[filters] = result
return result


class _InvertCache(Dict[Filter, "_Invert"]):
"""Cache for inversion operator."""

def __missing__(self, filter: Filter) -> Filter:
result = _Invert(filter)
self[filter] = result
return result


_and_cache = _AndCache()
_or_cache = _OrCache()
_invert_cache = _InvertCache()


class _AndList(Filter):
"""
Result of &-operation between several filters.
"""

def __init__(self, filters: Iterable[Filter]) -> None:
super().__init__()
self.filters: List[Filter] = []

for f in filters:
Expand All @@ -135,6 +117,7 @@ class _OrList(Filter):
"""

def __init__(self, filters: Iterable[Filter]) -> None:
super().__init__()
self.filters: List[Filter] = []

for f in filters:
Expand All @@ -156,6 +139,7 @@ class _Invert(Filter):
"""

def __init__(self, filter: Filter) -> None:
super().__init__()
self.filter = filter

def __call__(self) -> bool:
Expand All @@ -173,6 +157,9 @@ class Always(Filter):
def __call__(self) -> bool:
return True

def __or__(self, other: "Filter") -> "Filter":
return self

def __invert__(self) -> "Never":
return Never()

Expand All @@ -185,6 +172,9 @@ class Never(Filter):
def __call__(self) -> bool:
return False

def __and__(self, other: "Filter") -> "Filter":
return self

def __invert__(self) -> Always:
return Always()

Expand All @@ -204,6 +194,7 @@ def feature_is_active(): # `feature_is_active` becomes a Filter.
"""

def __init__(self, func: Callable[[], bool]) -> None:
super().__init__()
self.func = func

def __call__(self) -> bool:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_memory_leaks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import gc

import pytest

from prompt_toolkit.shortcuts.prompt import PromptSession


def _count_prompt_session_instances() -> int:
# Run full GC collection first.
gc.collect()

# Count number of remaining referenced `PromptSession` instances.
objects = gc.get_objects()
return len([obj for obj in objects if isinstance(obj, PromptSession)])


# Fails in GitHub CI, probably due to GC differences.
@pytest.mark.xfail(reason="Memory leak testing fails in GitHub CI.")
def test_prompt_session_memory_leak() -> None:
before_count = _count_prompt_session_instances()

# Somehow in CI/CD, the before_count is > 0
assert before_count == 0

p = PromptSession()

after_count = _count_prompt_session_instances()
assert after_count == before_count + 1

del p

after_delete_count = _count_prompt_session_instances()
assert after_delete_count == before_count