Skip to content

Commit 44d0cae

Browse files
author
Franz Király
authored
[ENH] refactor - move StdoutMute context manager to utils (#338)
Minor refactor in anticipation of external and further internal usage - moves `StdoutMute` context manager to a submodule of `utils`. Also refactors manual `stdout` handling in `_check_soft_dependencies` to use the utility.
1 parent 7dfd670 commit 44d0cae

File tree

4 files changed

+95
-46
lines changed

4 files changed

+95
-46
lines changed

skbase/lookup/_lookup.py

+24-38
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616
# https://github.com/sktime/sktime/blob/main/LICENSE
1717
import importlib
1818
import inspect
19-
import io
2019
import os
2120
import pathlib
2221
import pkgutil
2322
import re
24-
import sys
2523
import warnings
2624
from collections.abc import Iterable
2725
from copy import deepcopy
@@ -31,6 +29,7 @@
3129
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
3230

3331
from skbase.base import BaseObject
32+
from skbase.utils.stdout_mute import StdoutMute
3433
from skbase.validate import check_sequence
3534

3635
__all__: List[str] = ["all_objects", "get_package_metadata"]
@@ -335,7 +334,7 @@ def _import_module(
335334

336335
# if suppress_import_stdout:
337336
# setup text trap, import
338-
with StdoutMute(active=suppress_import_stdout):
337+
with StdoutMuteNCatchMNF(active=suppress_import_stdout):
339338
if isinstance(module, str):
340339
imported_mod = importlib.import_module(module)
341340
elif isinstance(module, importlib.machinery.SourceFileLoader):
@@ -865,7 +864,7 @@ class name if ``return_names=False`` and ``return_tags is not None``.
865864
obj_types = _check_object_types(object_types, class_lookup)
866865

867866
# Ignore deprecation warnings triggered at import time and from walking packages
868-
with warnings.catch_warnings(), StdoutMute(active=suppress_import_stdout):
867+
with warnings.catch_warnings(), StdoutMuteNCatchMNF(active=suppress_import_stdout):
869868
warnings.simplefilter("ignore", category=FutureWarning)
870869
warnings.simplefilter("module", category=ImportWarning)
871870
warnings.filterwarnings(
@@ -1025,7 +1024,7 @@ def _make_dataframe(all_objects, columns):
10251024
return pd.DataFrame(all_objects, columns=columns)
10261025

10271026

1028-
class StdoutMute:
1027+
class StdoutMuteNCatchMNF(StdoutMute):
10291028
"""A context manager to suppress stdout.
10301029
10311030
This class is used to suppress stdout when importing modules.
@@ -1042,39 +1041,26 @@ class StdoutMute:
10421041
except catch and suppress ModuleNotFoundError.
10431042
"""
10441043

1045-
def __init__(self, active=True):
1046-
self.active = active
1047-
1048-
def __enter__(self):
1049-
"""Context manager entry point."""
1050-
# capture stdout if active
1051-
# store the original stdout so it can be restored in __exit__
1052-
if self.active:
1053-
self._stdout = sys.stdout
1054-
sys.stdout = io.StringIO()
1055-
1056-
def __exit__(self, type, value, traceback): # noqa: A002
1057-
"""Context manager exit point."""
1058-
# restore stdout if active
1059-
# if not active, nothing needs to be done, since stdout was not replaced
1060-
if self.active:
1061-
sys.stdout = self._stdout
1062-
1063-
if type is not None:
1064-
# if a ModuleNotFoundError is raised,
1065-
# we suppress to a warning if "soft dependency" is in the error message
1066-
# otherwise, raise
1067-
if type is ModuleNotFoundError:
1068-
if "soft dependency" not in str(value):
1069-
return False
1070-
warnings.warn(str(value), ImportWarning, stacklevel=2)
1071-
return True
1072-
1073-
# all other exceptions are raised
1074-
return False
1075-
# if no exception was raised, return True to indicate successful exit
1076-
# return statement not needed as type was None, but included for clarity
1077-
return True
1044+
def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
1045+
"""Handle exceptions raised during __exit__.
1046+
1047+
Parameters
1048+
----------
1049+
type : type
1050+
The type of the exception raised.
1051+
Known to be not-None and Exception subtype when this method is called.
1052+
"""
1053+
# if a ModuleNotFoundError is raised,
1054+
# we suppress to a warning if "soft dependency" is in the error message
1055+
# otherwise, raise
1056+
if type is ModuleNotFoundError:
1057+
if "soft dependency" not in str(value):
1058+
return False
1059+
warnings.warn(str(value), ImportWarning, stacklevel=2)
1060+
return True
1061+
1062+
# all other exceptions are raised
1063+
return False
10781064

10791065

10801066
def _coerce_to_tuple(x):

skbase/tests/conftest.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"skbase.utils.dependencies",
5555
"skbase.utils.dependencies._dependencies",
5656
"skbase.utils.random_state",
57+
"skbase.utils.stdout_mute",
5758
"skbase.validate",
5859
"skbase.validate._named_objects",
5960
"skbase.validate._types",
@@ -79,6 +80,7 @@
7980
"skbase.utils.deep_equals",
8081
"skbase.utils.dependencies",
8182
"skbase.utils.random_state",
83+
"skbase.utils.stdout_mute",
8284
"skbase.validate",
8385
)
8486
SKBASE_PUBLIC_CLASSES_BY_MODULE = {
@@ -99,13 +101,14 @@
99101
"BaseMetaEstimatorMixin",
100102
),
101103
"skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
102-
"skbase.lookup._lookup": ("StdoutMute",),
104+
"skbase.lookup._lookup": ("StdoutMuteNCatchMNF",),
103105
"skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
104106
"skbase.testing.test_all_objects": (
105107
"BaseFixtureGenerator",
106108
"QuickTester",
107109
"TestAllObjects",
108110
),
111+
"skbase.utils.stdout_mute": ("StdoutMute",),
109112
}
110113
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
111114
SKBASE_CLASSES_BY_MODULE.update(

skbase/utils/dependencies/_dependencies.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
22
"""Utility to check soft dependency imports, and raise warnings or errors."""
3-
import io
43
import sys
54
import warnings
65
from importlib import import_module
@@ -10,6 +9,8 @@
109
from packaging.requirements import InvalidRequirement, Requirement
1110
from packaging.specifiers import InvalidSpecifier, SpecifierSet
1211

12+
from skbase.utils.stdout_mute import StdoutMute
13+
1314
__author__: List[str] = ["fkiraly", "mloning"]
1415

1516

@@ -130,12 +131,7 @@ def _check_soft_dependencies(
130131
package_import_name = package_name
131132
# attempt import - if not possible, we know we need to raise warning/exception
132133
try:
133-
if suppress_import_stdout:
134-
# setup text trap, import, then restore
135-
sys.stdout = io.StringIO()
136-
pkg_ref = import_module(package_import_name)
137-
sys.stdout = sys.__stdout__
138-
else:
134+
with StdoutMute(active=suppress_import_stdout):
139135
pkg_ref = import_module(package_import_name)
140136
# if package cannot be imported, make the user aware of installation requirement
141137
except ModuleNotFoundError as e:

skbase/utils/stdout_mute.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# -*- coding: utf-8 -*-
2+
"""Context manager to suppress stdout."""
3+
4+
__author__ = ["fkiraly"]
5+
6+
import io
7+
import sys
8+
9+
10+
class StdoutMute:
11+
"""A context manager to suppress stdout.
12+
13+
Exception handling on exit can be customized by overriding
14+
the ``_handle_exit_exceptions`` method.
15+
16+
Parameters
17+
----------
18+
active : bool, default=True
19+
Whether to suppress stdout or not.
20+
If True, stdout is suppressed.
21+
If False, stdout is not suppressed, and the context manager does nothing
22+
except catch and suppress ModuleNotFoundError.
23+
"""
24+
25+
def __init__(self, active=True):
26+
self.active = active
27+
28+
def __enter__(self):
29+
"""Context manager entry point."""
30+
# capture stdout if active
31+
# store the original stdout so it can be restored in __exit__
32+
if self.active:
33+
self._stdout = sys.stdout
34+
sys.stdout = io.StringIO()
35+
36+
def __exit__(self, type, value, traceback): # noqa: A002
37+
"""Context manager exit point."""
38+
# restore stdout if active
39+
# if not active, nothing needs to be done, since stdout was not replaced
40+
if self.active:
41+
sys.stdout = self._stdout
42+
43+
if type is not None:
44+
return self._handle_exit_exceptions(type, value, traceback)
45+
46+
# if no exception was raised, return True to indicate successful exit
47+
# return statement not needed as type was None, but included for clarity
48+
return True
49+
50+
def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
51+
"""Handle exceptions raised during __exit__.
52+
53+
Parameters
54+
----------
55+
type : type
56+
The type of the exception raised.
57+
Known to be not-None and Exception subtype when this method is called.
58+
value : Exception
59+
The exception instance raised.
60+
traceback : traceback
61+
The traceback object associated with the exception.
62+
"""
63+
# by default, all exceptions are raised
64+
return False

0 commit comments

Comments
 (0)