Skip to content

Commit 09b8bb2

Browse files
[MNT] Improve multithreading testing (#2317)
* multithreading testing * set multithreading tag * tags * fix * tags * param * test params * test * Update _base.py * Update _base.py
1 parent 108948d commit 09b8bb2

30 files changed

+160
-27
lines changed

aeon/anomaly_detection/_copod.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class COPOD(PyODAdapter):
4343
"capability:multivariate": True,
4444
"capability:univariate": True,
4545
"capability:missing_values": False,
46+
"capability:multithreading": True,
4647
"fit_is_empty": False,
4748
"python_dependencies": ["pyod"],
4849
}

aeon/anomaly_detection/_iforest.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class IsolationForest(PyODAdapter):
8888
"capability:multivariate": True,
8989
"capability:univariate": True,
9090
"capability:missing_values": False,
91+
"capability:multithreading": True,
9192
"fit_is_empty": False,
9293
"python_dependencies": ["pyod"],
9394
}

aeon/anomaly_detection/tests/test_left_stampi.py

-12
Original file line numberDiff line numberDiff line change
@@ -309,15 +309,3 @@ def test_the_number_of_distances_k_defaults_to_1_and_can_be_changed(
309309
],
310310
any_order=True,
311311
)
312-
313-
def test_it_checks_soft_dependencies(self, mocker):
314-
"""Unit testing the dependency check."""
315-
# given
316-
deps_checker_stub = mocker.patch(
317-
"aeon.base._base_series._check_estimator_deps", return_value=True
318-
)
319-
# deps_checker_stub.return_value = True
320-
ad = LeftSTAMPi(window_size=5, n_init_train=10)
321-
322-
# then
323-
deps_checker_stub.assert_called_once_with(ad)

aeon/base/_base_collection.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222

2323

2424
class BaseCollectionEstimator(BaseAeonEstimator):
25-
"""Base class for estimators that use collections of time series for method fit.
25+
"""Base class for estimators that use collections of time series for ``fit``.
2626
27-
Provides functions that are common to BaseClassifier, BaseRegressor,
28-
BaseClusterer and BaseCollectionTransformer for the checking and
29-
conversion of input to fit, predict and predict_proba, where relevant.
27+
Provides functions that are common to estimators which use colections such as
28+
``BaseClassifier``, ``BaseRegressor``, ``BaseClusterer``, ``BaseSimilaritySearch``
29+
and ``BaseCollectionTransformer``. Functionality includes checking and
30+
conversion of input to ``fit, predict and predict_proba, where relevant.
3031
3132
It also stores the common default tags used by all the subclasses and meta data
3233
describing the characteristics of time series passed to ``fit``.

aeon/classification/distance_based/_time_series_neighbors.py

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class KNeighborsTimeSeriesClassifier(BaseClassifier):
6767
_tags = {
6868
"capability:multivariate": True,
6969
"capability:unequal_length": True,
70+
"capability:multithreading": True,
7071
"X_inner_type": ["np-list", "numpy3D"],
7172
"algorithm_type": "distance",
7273
}

aeon/classification/shapelet_based/_rsast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
nb_inst_per_class=10,
7171
seed=None,
7272
classifier=None,
73-
n_jobs=-1,
73+
n_jobs=1,
7474
):
7575
super().__init__()
7676
self.n_random_points = n_random_points

aeon/classification/shapelet_based/_sast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
nb_inst_per_class: int = 1,
7474
seed: Optional[int] = None,
7575
classifier=None,
76-
n_jobs: int = -1,
76+
n_jobs: int = 1,
7777
) -> None:
7878
super().__init__()
7979
self.length_list = length_list

aeon/clustering/_kernel_k_means.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class TimeSeriesKernelKMeans(BaseClusterer):
8585

8686
_tags = {
8787
"capability:multivariate": True,
88+
"capability:multithreading": True,
8889
"python_dependencies": "tslearn",
8990
}
9091

@@ -97,7 +98,7 @@ def __init__(
9798
tol: float = 1e-4,
9899
kernel_params: Union[dict, None] = None,
99100
verbose: bool = False,
100-
n_jobs: Union[int, None] = None,
101+
n_jobs: Union[int, None] = 1,
101102
random_state: Optional[Union[int, RandomState]] = None,
102103
):
103104
self.kernel = kernel
@@ -200,8 +201,4 @@ def _get_test_params(cls, parameter_set="default") -> dict:
200201
"n_init": 1,
201202
"max_iter": 1,
202203
"tol": 0.0001,
203-
"kernel_params": None,
204-
"verbose": False,
205-
"n_jobs": 1,
206-
"random_state": 1,
207204
}

aeon/regression/distance_based/_time_series_neighbors.py

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class KNeighborsTimeSeriesRegressor(BaseRegressor):
6767
_tags = {
6868
"capability:multivariate": True,
6969
"capability:unequal_length": True,
70+
"capability:multithreading": True,
7071
"X_inner_type": ["np-list", "numpy3D"],
7172
"algorithm_type": "distance",
7273
}

aeon/segmentation/_clasp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class ClaSPSegmenter(BaseSegmenter):
206206
>>> scores = clasp.scores
207207
"""
208208

209-
_tags = {"fit_is_empty": True} # for unit test cases
209+
_tags = {"capability:multithreading": True, "fit_is_empty": True}
210210

211211
def __init__(self, period_length=10, n_cps=1, exclusion_radius=0.05, n_jobs=1):
212212
self.period_length = int(period_length)

aeon/testing/estimator_checking/_yield_classification_checks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _yield_classification_checks(estimator_class, estimator_instances, datatypes
7474
# test class instances
7575
for i, estimator in enumerate(estimator_instances):
7676
# data type irrelevant
77-
if _get_tag(estimator_class, "capability:train_estimate", raise_error=True):
77+
if _get_tag(estimator, "capability:train_estimate", raise_error=True):
7878
yield partial(
7979
check_classifier_train_estimate,
8080
estimator=estimator,

aeon/testing/estimator_checking/_yield_estimator_checks.py

+7
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
from aeon.testing.estimator_checking._yield_early_classification_checks import (
4141
_yield_early_classification_checks,
4242
)
43+
from aeon.testing.estimator_checking._yield_multithreading_checks import (
44+
_yield_multithreading_checks,
45+
)
4346
from aeon.testing.estimator_checking._yield_regression_checks import (
4447
_yield_regression_checks,
4548
)
@@ -116,6 +119,10 @@ def _yield_all_aeon_checks(
116119
estimator_class, estimator_instances, datatypes
117120
)
118121

122+
yield from _yield_multithreading_checks(
123+
estimator_class, estimator_instances, datatypes
124+
)
125+
119126
if issubclass(estimator_class, BaseClassifier):
120127
yield from _yield_classification_checks(
121128
estimator_class, estimator_instances, datatypes
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import inspect
2+
from functools import partial
3+
4+
from numpy.testing import assert_array_almost_equal
5+
6+
from aeon.base._base import _clone_estimator
7+
from aeon.testing.testing_config import (
8+
MULTITHREAD_TESTING,
9+
NON_STATE_CHANGING_METHODS_ARRAYLIKE,
10+
)
11+
from aeon.testing.utils.estimator_checks import _get_tag, _run_estimator_method
12+
from aeon.utils.validation import check_n_jobs
13+
14+
15+
def _yield_multithreading_checks(estimator_class, estimator_instances, datatypes):
16+
"""Yield all multithreading checks for an aeon estimator."""
17+
can_thread = _get_tag(estimator_class, "capability:multithreading")
18+
19+
# only class required
20+
if can_thread:
21+
yield partial(check_multithreading_param, estimator_class=estimator_class)
22+
else:
23+
yield partial(check_no_multithreading_param, estimator_class=estimator_class)
24+
25+
if can_thread and MULTITHREAD_TESTING:
26+
# test class instances
27+
for i, estimator in enumerate(estimator_instances):
28+
# test all data types
29+
for datatype in datatypes[i]:
30+
yield partial(
31+
check_estimator_multithreading,
32+
estimator=estimator,
33+
datatype=datatype,
34+
)
35+
36+
37+
def check_multithreading_param(estimator_class):
38+
"""Test that estimators that can multithread have a n_jobs parameter."""
39+
default_params = inspect.signature(estimator_class.__init__).parameters
40+
n_jobs = default_params.get("n_jobs", None)
41+
42+
# check that the estimator has a n_jobs parameter
43+
if n_jobs is None:
44+
raise ValueError(
45+
f"{estimator_class} which sets "
46+
"capability:multithreading=True must have a n_jobs parameter."
47+
)
48+
49+
# check that the default value is to use 1 thread
50+
if n_jobs.default != 1:
51+
raise ValueError(
52+
"n_jobs parameter must have a default value of 1, "
53+
"disabling multithreading by default."
54+
)
55+
56+
# test parameters should not change the default value
57+
params = estimator_class._get_test_params()
58+
if not isinstance(params, list):
59+
params = [params]
60+
for param_set in params:
61+
assert "n_jobs" not in param_set
62+
63+
64+
def check_no_multithreading_param(estimator_class):
65+
"""Test that estimators that cant multithread have no n_jobs parameter."""
66+
default_params = inspect.signature(estimator_class.__init__).parameters
67+
68+
# check that the estimator does not have a n_jobs parameter
69+
if default_params.get("n_jobs", None) is not None:
70+
raise ValueError(
71+
f"{estimator_class} has a n_jobs parameter, but does not set "
72+
"capability:multithreading=True in its tags."
73+
)
74+
75+
76+
def check_estimator_multithreading(estimator, datatype):
77+
"""Test that multithreaded estimators store n_jobs_ and produce same results."""
78+
st_estimator = _clone_estimator(estimator, random_state=42)
79+
mt_estimator = _clone_estimator(estimator, random_state=42)
80+
n_jobs = max(2, check_n_jobs(-2))
81+
mt_estimator.set_params(n_jobs=n_jobs)
82+
83+
# fit and get results for single thread estimator
84+
_run_estimator_method(st_estimator, "fit", datatype, "train")
85+
86+
results = []
87+
for method in NON_STATE_CHANGING_METHODS_ARRAYLIKE:
88+
if hasattr(st_estimator, method) and callable(getattr(estimator, method)):
89+
output = _run_estimator_method(st_estimator, method, datatype, "test")
90+
results.append(output)
91+
92+
# fit multithreaded estimator
93+
_run_estimator_method(mt_estimator, "fit", datatype, "train")
94+
95+
# check n_jobs_ attribute is set
96+
assert mt_estimator.n_jobs_ == n_jobs, (
97+
f"Multithreaded estimator {mt_estimator} does not store n_jobs_ "
98+
f"attribute correctly. Expected {n_jobs}, got {mt_estimator.n_jobs_}."
99+
f"It is recommended to use the check_n_jobs function to set n_jobs_ and use"
100+
f"this for any multithreading."
101+
)
102+
103+
# compare results from single and multithreaded estimators
104+
i = 0
105+
for method in NON_STATE_CHANGING_METHODS_ARRAYLIKE:
106+
if hasattr(estimator, method) and callable(getattr(estimator, method)):
107+
output = _run_estimator_method(estimator, method, datatype, "test")
108+
109+
assert_array_almost_equal(
110+
output,
111+
results[i],
112+
err_msg=f"Running {method} after fit twice with test "
113+
f"parameters gives different results.",
114+
)
115+
i += 1

aeon/testing/testing_config.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
# whether to use smaller parameter matrices for test generation and subsample estimators
99
# per os/version default is False, can be set to True by pytest --prtesting True flag
1010
PR_TESTING = False
11+
# whether to use multithreading in tests, can be set to True by pytest
12+
# --enablethreading True flag
13+
MULTITHREAD_TESTING = False
1114

1215
# exclude estimators here for short term fixes
1316
EXCLUDE_ESTIMATORS = [

aeon/transformations/collection/convolution_based/_minirocket.py

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class MiniRocket(BaseCollectionTransformer):
7474
"output_data_type": "Tabular",
7575
"algorithm_type": "convolution",
7676
"capability:multivariate": True,
77+
"capability:multithreading": True,
7778
}
7879
# indices for the 84 kernels used by MiniRocket
7980
_indices = np.array([_ for _ in combinations(np.arange(9), 3)], dtype=np.int32)

aeon/transformations/collection/convolution_based/_minirocket_mv.py

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class MiniRocketMultivariateVariable(BaseCollectionTransformer):
8787
"output_data_type": "Tabular",
8888
"capability:multivariate": True,
8989
"capability:unequal_length": True,
90+
"capability:multithreading": True,
9091
"X_inner_type": "np-list",
9192
"algorithm_type": "convolution",
9293
}

aeon/transformations/collection/convolution_based/_multirocket.py

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class MultiRocket(BaseCollectionTransformer):
7474
"output_data_type": "Tabular",
7575
"algorithm_type": "convolution",
7676
"capability:multivariate": True,
77+
"capability:multithreading": True,
7778
}
7879
# indices for the 84 kernels used by MiniRocket
7980
_indices = np.array([_ for _ in combinations(np.arange(9), 3)], dtype=np.int32)

aeon/transformations/collection/convolution_based/_rocket.py

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class Rocket(BaseCollectionTransformer):
6565
_tags = {
6666
"output_data_type": "Tabular",
6767
"capability:multivariate": True,
68+
"capability:multithreading": True,
6869
"algorithm_type": "convolution",
6970
}
7071

aeon/transformations/collection/dictionary_based/_sfa.py

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class SFA(BaseCollectionTransformer):
110110

111111
_tags = {
112112
"requires_y": False, # SFA is unsupervised for equi-depth and equi-width bins
113+
"capability:multithreading": True,
113114
"algorithm_type": "dictionary",
114115
}
115116

aeon/transformations/collection/dictionary_based/_sfa_fast.py

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class SFAFast(BaseCollectionTransformer):
130130

131131
_tags = {
132132
"requires_y": False, # SFA is unsupervised for equi-depth and equi-width bins
133+
"capability:multithreading": True,
133134
"algorithm_type": "dictionary",
134135
}
135136

aeon/transformations/collection/feature_based/_catch22.py

+1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ class Catch22(BaseCollectionTransformer):
173173
"X_inner_type": ["np-list", "numpy3D"],
174174
"capability:unequal_length": True,
175175
"capability:multivariate": True,
176+
"capability:multithreading": True,
176177
"fit_is_empty": True,
177178
}
178179

aeon/transformations/collection/feature_based/_tsfresh.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class _TSFresh(BaseCollectionTransformer):
3535
_tags = {
3636
"output_data_type": "Tabular",
3737
"capability:multivariate": True,
38+
"capability:multithreading": True,
3839
"fit_is_empty": True,
3940
"python_dependencies": "tsfresh",
4041
}

aeon/transformations/collection/interval_based/_random_intervals.py

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class RandomIntervals(BaseCollectionTransformer):
110110
_tags = {
111111
"output_data_type": "Tabular",
112112
"capability:multivariate": True,
113+
"capability:multithreading": True,
113114
"fit_is_empty": False,
114115
"algorithm_type": "interval",
115116
}

aeon/transformations/collection/interval_based/_supervised_intervals.py

+1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class SupervisedIntervals(BaseCollectionTransformer):
136136
_tags = {
137137
"output_data_type": "Tabular",
138138
"capability:multivariate": True,
139+
"capability:multithreading": True,
139140
"requires_y": True,
140141
"algorithm_type": "interval",
141142
}

aeon/transformations/collection/shapelet_based/_dilated_shapelet_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ class RandomDilatedShapeletTransform(BaseCollectionTransformer):
145145
"output_data_type": "Tabular",
146146
"capability:multivariate": True,
147147
"capability:unequal_length": True,
148+
"capability:multithreading": True,
148149
"X_inner_type": ["np-list", "numpy3D"],
149150
"algorithm_type": "shapelet",
150151
}

aeon/transformations/collection/shapelet_based/_rsast.py

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class RSAST(BaseCollectionTransformer):
9393
_tags = {
9494
"output_data_type": "Tabular",
9595
"capability:multivariate": False,
96+
"capability:multithreading": True,
9697
"algorithm_type": "shapelet",
9798
"python_dependencies": "statsmodels",
9899
}

aeon/transformations/collection/shapelet_based/_sast.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class SAST(BaseCollectionTransformer):
8888
_tags = {
8989
"output_data_type": "Tabular",
9090
"capability:multivariate": False,
91+
"capability:multithreading": True,
9192
"algorithm_type": "shapelet",
9293
}
9394

aeon/transformations/collection/shapelet_based/_shapelet_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class RandomShapeletTransform(BaseCollectionTransformer):
144144
"output_data_type": "Tabular",
145145
"capability:multivariate": True,
146146
"capability:unequal_length": True,
147+
"capability:multithreading": True,
147148
"X_inner_type": ["np-list", "numpy3D"],
148149
"requires_y": True,
149150
"algorithm_type": "shapelet",

aeon/transformations/series/_clasp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ class ClaSPTransformer(BaseSeriesTransformer):
438438
"X_inner_type": "np.ndarray",
439439
"fit_is_empty": True,
440440
"requires_y": False,
441-
"capability:inverse_transform": False,
441+
"capability:multithreading": True,
442442
}
443443

444444
def __init__(

0 commit comments

Comments
 (0)