16
16
# https://github.com/sktime/sktime/blob/main/LICENSE
17
17
import importlib
18
18
import inspect
19
- import io
20
19
import os
21
20
import pathlib
22
21
import pkgutil
23
22
import re
24
- import sys
25
23
import warnings
26
24
from collections .abc import Iterable
27
25
from copy import deepcopy
31
29
from typing import Any , List , Mapping , MutableMapping , Optional , Sequence , Tuple , Union
32
30
33
31
from skbase .base import BaseObject
32
+ from skbase .utils .stdout_mute import StdoutMute
34
33
from skbase .validate import check_sequence
35
34
36
35
__all__ : List [str ] = ["all_objects" , "get_package_metadata" ]
@@ -335,7 +334,7 @@ def _import_module(
335
334
336
335
# if suppress_import_stdout:
337
336
# setup text trap, import
338
- with StdoutMute (active = suppress_import_stdout ):
337
+ with StdoutMuteNCatchMNF (active = suppress_import_stdout ):
339
338
if isinstance (module , str ):
340
339
imported_mod = importlib .import_module (module )
341
340
elif isinstance (module , importlib .machinery .SourceFileLoader ):
@@ -865,7 +864,7 @@ class name if ``return_names=False`` and ``return_tags is not None``.
865
864
obj_types = _check_object_types (object_types , class_lookup )
866
865
867
866
# Ignore deprecation warnings triggered at import time and from walking packages
868
- with warnings .catch_warnings (), StdoutMute (active = suppress_import_stdout ):
867
+ with warnings .catch_warnings (), StdoutMuteNCatchMNF (active = suppress_import_stdout ):
869
868
warnings .simplefilter ("ignore" , category = FutureWarning )
870
869
warnings .simplefilter ("module" , category = ImportWarning )
871
870
warnings .filterwarnings (
@@ -1025,7 +1024,7 @@ def _make_dataframe(all_objects, columns):
1025
1024
return pd .DataFrame (all_objects , columns = columns )
1026
1025
1027
1026
1028
- class StdoutMute :
1027
+ class StdoutMuteNCatchMNF ( StdoutMute ) :
1029
1028
"""A context manager to suppress stdout.
1030
1029
1031
1030
This class is used to suppress stdout when importing modules.
@@ -1042,39 +1041,26 @@ class StdoutMute:
1042
1041
except catch and suppress ModuleNotFoundError.
1043
1042
"""
1044
1043
1045
- def __init__ (self , active = True ):
1046
- self .active = active
1047
-
1048
- def __enter__ (self ):
1049
- """Context manager entry point."""
1050
- # capture stdout if active
1051
- # store the original stdout so it can be restored in __exit__
1052
- if self .active :
1053
- self ._stdout = sys .stdout
1054
- sys .stdout = io .StringIO ()
1055
-
1056
- def __exit__ (self , type , value , traceback ): # noqa: A002
1057
- """Context manager exit point."""
1058
- # restore stdout if active
1059
- # if not active, nothing needs to be done, since stdout was not replaced
1060
- if self .active :
1061
- sys .stdout = self ._stdout
1062
-
1063
- if type is not None :
1064
- # if a ModuleNotFoundError is raised,
1065
- # we suppress to a warning if "soft dependency" is in the error message
1066
- # otherwise, raise
1067
- if type is ModuleNotFoundError :
1068
- if "soft dependency" not in str (value ):
1069
- return False
1070
- warnings .warn (str (value ), ImportWarning , stacklevel = 2 )
1071
- return True
1072
-
1073
- # all other exceptions are raised
1074
- return False
1075
- # if no exception was raised, return True to indicate successful exit
1076
- # return statement not needed as type was None, but included for clarity
1077
- return True
1044
+ def _handle_exit_exceptions (self , type , value , traceback ): # noqa: A002
1045
+ """Handle exceptions raised during __exit__.
1046
+
1047
+ Parameters
1048
+ ----------
1049
+ type : type
1050
+ The type of the exception raised.
1051
+ Known to be not-None and Exception subtype when this method is called.
1052
+ """
1053
+ # if a ModuleNotFoundError is raised,
1054
+ # we suppress to a warning if "soft dependency" is in the error message
1055
+ # otherwise, raise
1056
+ if type is ModuleNotFoundError :
1057
+ if "soft dependency" not in str (value ):
1058
+ return False
1059
+ warnings .warn (str (value ), ImportWarning , stacklevel = 2 )
1060
+ return True
1061
+
1062
+ # all other exceptions are raised
1063
+ return False
1078
1064
1079
1065
1080
1066
def _coerce_to_tuple (x ):
0 commit comments