
Commit 929fb7a

rosecers and agoscinski authored
Fix/check estimator (#196)
* add sklearn estimator_checks tests and fix emerging test errors
* consistently validate and check input data in fit functions
* add a whitening option in PCovR
* fix KernelFlexibleCenterer, which was not consistently using the validated kernel
* add tests in tests/test_standard_flexible_scaler.py for taking the average
* create a new test file, tests/test_check_estimators.py, with sklearn estimator_checks tests

Co-authored-by: Alexander Goscinski <[email protected]>
1 parent: ec67c46

16 files changed: +153 −73 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ __pycache__
 *.egg-info
 *.swp
 *.swo
+*DS_Store

 .tox/
 build/

src/skmatter/_selection.py

Lines changed: 30 additions & 27 deletions
@@ -12,9 +12,8 @@
 from scipy.sparse.linalg import eigsh
 from sklearn.base import BaseEstimator, MetaEstimatorMixin
 from sklearn.feature_selection._base import SelectorMixin
-from sklearn.utils import check_array, check_random_state, safe_mask
-from sklearn.utils._tags import _safe_tags
-from sklearn.utils.validation import check_is_fitted
+from sklearn.utils import check_array, check_random_state, check_X_y, safe_mask
+from sklearn.utils.validation import FLOAT_DTYPES, as_float_array, check_is_fitted

 from .utils import (
     X_orthogonalizer,
@@ -125,7 +124,6 @@ def fit(self, X, y=None, warm_start=False):
         -------
         self : object
         """
-        tags = self._get_tags()

         if self.selection_type == "feature":
             self._axis = 1
@@ -144,28 +142,28 @@ def fit(self, X, y=None, warm_start=False):
         elif self.progress_bar is False:
             self.report_progress_ = no_progress_bar

-        params = dict(
-            accept_sparse="csc",
-            force_all_finite=not tags.get("allow_nan", True),
-        )
-        if self._axis == 1:
-            params["ensure_min_features"] = 2
-        else:
-            params["ensure_min_samples"] = 2
+        params = dict(ensure_min_samples=2, ensure_min_features=2, dtype=FLOAT_DTYPES)

-        if y is not None:
-            params["multi_output"] = True
+        if hasattr(self, "mixing") or y is not None:
             X, y = self._validate_data(X, y, **params)
+            X, y = check_X_y(X, y, multi_output=True)

             if len(y.shape) == 1:
                 # force y to have multi_output 2D format even when it's 1D, since
                 # many functions, most notably PCov routines, assume an array storage
                 # format, most notably to compute (y @ y.T)
                 y = y.reshape((len(y), 1))
+
         else:
             X = check_array(X, **params)

+        if self.full and self.score_threshold is not None:
+            raise ValueError(
+                "You cannot specify both `score_threshold` and `full=True`."
+            )
+
         n_to_select_from = X.shape[self._axis]
+        self.n_samples_in_, self.n_features_in_ = X.shape

         self.n_samples_in_, self.n_features_in_ = X.shape

@@ -243,22 +241,27 @@ def transform(self, X, y=None):
             The selected subset of the input.
         """

-        if len(X.shape) == 1:
-            X = X.reshape(-1, 1)
+        check_is_fitted(self, ["_axis", "selected_idx_", "n_selected_"])
+
+        if self._axis == 0:
+            raise ValueError(
+                "Transform is not currently supported for sample selection."
+            )

         mask = self.get_support()

-        # note: we use _safe_tags instead of _get_tags because this is a
-        # public Mixin.
-        X = self._validate_data(
-            X,
-            dtype=None,
-            accept_sparse="csr",
-            force_all_finite=not _safe_tags(self, key="allow_nan"),
-            reset=False,
-            ensure_2d=self._axis,
-        )
+        X = check_array(X)

+        if len(X.shape) == 1:
+            if self._axis == 0:
+                X = X.reshape(-1, 1)
+            else:
+                X = X.reshape(1, -1)
+
+        if len(mask) != X.shape[self._axis]:
+            raise ValueError(
+                "X has a different shape than during fitting. Reshape your data."
+            )
         if self._axis == 1:
             return X[:, safe_mask(X, mask)]
         else:
@@ -517,7 +520,7 @@ def _init_greedy_search(self, X, y, n_to_select):
         features and computes their initial importance score.
         """

-        self.X_current_ = X.copy()
+        self.X_current_ = as_float_array(X.copy())
         self.pi_ = self._compute_pi(self.X_current_)

         super()._init_greedy_search(X, y, n_to_select)
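For orientation, here is a minimal sketch of the revised validation path for feature selection — the data and parameter values are made up for illustration, assuming the public FPS API exercised in the new test file:

    import numpy as np
    from skmatter.feature_selection import FPS

    X = np.random.default_rng(0).random((10, 4))

    selector = FPS(n_to_select=2)       # greedy farthest-point feature selection
    selector.fit(X)                     # fit() now validates X and casts it to float
    X_selected = selector.transform(X)  # transform() now checks fitted state and shape
    print(X_selected.shape)             # (10, 2)

Per the diff, calling transform on a sample selector (self._axis == 0) now raises a ValueError instead of silently reshaping the input.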

src/skmatter/decomposition/_pcovr.py

Lines changed: 4 additions & 1 deletion
@@ -130,6 +130,8 @@ class PCovR(_BasePCA, LinearModel):
         Used when the 'arpack' or 'randomized' solvers are used. Pass an int
         for reproducible results across multiple function calls.

+    whiten : boolean, deprecated
+
     Attributes
     ----------
@@ -202,12 +204,13 @@ def __init__(
         regressor=None,
         iterated_power="auto",
         random_state=None,
+        whiten=False,
     ):
         self.mixing = mixing
         self.n_components = n_components
         self.space = space

-        self.whiten = False
+        self.whiten = whiten
         self.svd_solver = svd_solver
         self.tol = tol
         self.iterated_power = iterated_power
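A short sketch of what this change means in practice (hypothetical usage, not taken from the diff): sklearn's estimator checks require that __init__ only store its arguments, so whiten becomes an explicit, stored constructor parameter rather than a hard-coded attribute, while remaining deprecated and defaulting to False:

    from skmatter.decomposition import PCovR

    # whiten now appears in get_params() because it is a real __init__ argument
    model = PCovR(mixing=0.5, n_components=2, whiten=False)
    print(model.get_params()["whiten"])  # False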

src/skmatter/linear_model/_base.py

Lines changed: 22 additions & 5 deletions
@@ -2,6 +2,8 @@
 from scipy.linalg import orthogonal_procrustes
 from sklearn.base import MultiOutputMixin, RegressorMixin
 from sklearn.linear_model import LinearRegression
+from sklearn.utils import check_array, check_X_y
+from sklearn.utils.validation import check_is_fitted


 class OrthogonalRegression(MultiOutputMixin, RegressorMixin):
@@ -61,6 +63,15 @@ def fit(self, X, y):
             and n_targets is the number of target properties.
         """

+        X, y = check_X_y(
+            X,
+            y,
+            y_numeric=True,
+            ensure_min_features=1,
+            ensure_min_samples=1,
+            multi_output=True,
+        )
+
         self.n_samples_in_, self.n_features_in_ = X.shape
         if self.use_orthogonal_projector:
             # check estimator
@@ -71,12 +82,15 @@
             )
             # compute orthogonal projectors
             linear_estimator.fit(X, y)
-            U, _, Vt = np.linalg.svd(linear_estimator.coef_.T, full_matrices=False)
-            # project X and y to same dimension
-            X = X @ U
-            y = y @ Vt.T
+            coef = np.reshape(linear_estimator.coef_.T, (X.shape[1], -1))
+            U, _, Vt = np.linalg.svd(coef, full_matrices=False)

             # compute weights by solving the Procrustes problem
-            self.coef_ = (U @ orthogonal_procrustes(X, y)[0] @ Vt).T
+            self.coef_ = (
+                U
+                @ orthogonal_procrustes(X @ U, y.reshape(X.shape[0], -1) @ Vt.T)[0]
+                @ Vt
+            ).T
         else:
             self.max_components_ = max(X.shape[1], y.shape[1])
             X = np.pad(X, [(0, 0), (0, self.max_components_ - X.shape[1])])
@@ -93,6 +107,9 @@ def predict(self, X):
             Training data, where n_samples is the number of samples
             and n_features is the number of features.
         """
+        X = check_array(X, ensure_min_features=1, ensure_min_samples=1)
+        check_is_fitted(self, ["coef_"])
+
         if not (self.use_orthogonal_projector):
             X = np.pad(X, [(0, 0), (0, self.max_components_ - X.shape[1])])
         return X @ self.coef_.T
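The reworked fit() keeps the same geometry as before: the SVD of the coefficient matrix gives bases U and Vt, X and y are projected into that shared subspace, and scipy's orthogonal_procrustes(A, B) returns the orthogonal matrix R minimizing ||A @ R - B|| in the Frobenius norm. A standalone sketch of the same composition, with made-up data:

    import numpy as np
    from scipy.linalg import orthogonal_procrustes

    rng = np.random.default_rng(0)
    X = rng.random((20, 5))
    y = rng.random((20, 3))

    # least-squares coefficients, shaped (n_features, n_targets) as in fit()
    coef = np.linalg.lstsq(X, y, rcond=None)[0]
    U, _, Vt = np.linalg.svd(coef, full_matrices=False)

    # solve the Procrustes problem in the projected space, then map back
    R = orthogonal_procrustes(X @ U, y @ Vt.T)[0]
    coef_ = (U @ R @ Vt).T

The key difference from the old code is that X and y themselves are no longer overwritten by their projections, so later shape-dependent logic sees the original arrays.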

src/skmatter/linear_model/_ridge.py

Lines changed: 13 additions & 2 deletions
@@ -1,11 +1,13 @@
 import numpy as np
 from joblib import Parallel, delayed
-from sklearn.base import MultiOutputMixin, RegressorMixin
+from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
 from sklearn.metrics import check_scoring
 from sklearn.model_selection import KFold
+from sklearn.utils import check_array
+from sklearn.utils.validation import check_is_fitted


-class RidgeRegression2FoldCV(MultiOutputMixin, RegressorMixin):
+class RidgeRegression2FoldCV(BaseEstimator, MultiOutputMixin, RegressorMixin):
     r"""Ridge regression with an efficient 2-fold cross-validation method using the SVD
     solver.
@@ -110,6 +112,9 @@ def __init__(
         self.shuffle = shuffle
         self.n_jobs = n_jobs

+    def _more_tags(self):
+        return {"multioutput_only": True}
+
     def fit(self, X, y):
         """
         Parameters
@@ -138,6 +143,7 @@ def fit(self, X, y):
                 "[0,1)"
             )

+        X, y = self._validate_data(X, y, y_numeric=True, multi_output=True)
         self.n_samples_in_, self.n_features_in_ = X.shape

         # check_scoring uses estimators scoring function if the scorer is None, this is
@@ -164,6 +170,11 @@ def predict(self, X):
             Training data, where n_samples is the number of samples
             and n_features is the number of features.
         """
+
+        X = check_array(X)
+
+        check_is_fitted(self, ["coef_"])
+
         return X @ self.coef_.T

     def _2fold_cv(self, X, y, fold1_idx, fold2_idx, scorer):
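Since RidgeRegression2FoldCV now advertises the multioutput_only tag, sklearn's common checks will always hand it a 2-D y. A minimal usage sketch with random data (illustrative only):

    import numpy as np
    from skmatter.linear_model import RidgeRegression2FoldCV

    rng = np.random.default_rng(0)
    X = rng.random((50, 4))
    y = rng.random((50, 2))    # 2-D y, per the multioutput_only tag

    model = RidgeRegression2FoldCV()
    model.fit(X, y)            # fit() now validates X and y
    y_pred = model.predict(X)  # predict() now checks the fitted state first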

src/skmatter/metrics/_reconstruction_measures.py

Lines changed: 1 addition & 1 deletion
@@ -445,7 +445,7 @@ def pointwise_local_reconstruction_error(

     scaler.fit(X_train)
     X_train = scaler.transform(X_train)
-    X_test = scaler.transform(X_test)
+    X_test = scaler.transform(X_test).astype(X_train.dtype)
     scaler.fit(Y_train)
     Y_train = scaler.transform(Y_train)
     Y_test = scaler.transform(Y_test)

src/skmatter/preprocessing/_data.py

Lines changed: 18 additions & 13 deletions
@@ -135,6 +135,13 @@ def fit(self, X, y=None, sample_weight=None):
             Fitted scaler.
         """

+        X = self._validate_data(
+            X,
+            copy=self.copy,
+            estimator=self,
+            dtype=FLOAT_DTYPES,
+            ensure_min_samples=2,
+        )
         self.n_samples_in_, self.n_features_in_ = X.shape

         if sample_weight is not None:
@@ -157,7 +164,7 @@ def fit(self, X, y=None, sample_weight=None):
             self.scale_ = np.sqrt(var)
         else:
             var_sum = var.sum()
-            if var_sum < abs(np.mean(X_mean)) * self.rtol + self.atol:
+            if var_sum < abs(np.average(X_mean)) * self.rtol + self.atol:
                 raise ValueError("Cannot normalize a matrix with zero variance")
             self.scale_ = np.sqrt(var_sum)
@@ -187,11 +194,9 @@ def transform(self, X, y=None, copy=None):
         X = self._validate_data(
             X,
             reset=False,
-            accept_sparse="csr",
             copy=copy,
             estimator=self,
             dtype=FLOAT_DTYPES,
-            force_all_finite="allow-nan",
         )
         check_is_fitted(
             self, attributes=["n_samples_in_", "n_features_in_", "scale_", "mean_"]
@@ -288,7 +293,7 @@ def __init__(self, with_center=True, with_trace=True):
         self.with_trace = with_trace
         super().__init__()

-    def fit(self, K=None, y=None, sample_weight=None):
+    def fit(self, K, y=None, sample_weight=None):
         """Fit KernelFlexibleCenterer

         Parameters
@@ -310,7 +315,7 @@ def fit(self, K, y=None, sample_weight=None):
             Fitted transformer.
         """

-        Kc = self._validate_data(K, copy=True, dtype=FLOAT_DTYPES, reset=False)
+        K = self._validate_data(K, copy=True, dtype=FLOAT_DTYPES, reset=False)

         if sample_weight is not None:
             self.sample_weight_ = _check_sample_weight(sample_weight, K, dtype=K.dtype)
@@ -327,20 +332,20 @@
             else:
                 super().fit(K, y)

-            K_pred_cols = np.average(Kc, weights=self.sample_weight_, axis=1)[
+            K_pred_cols = np.average(K, weights=self.sample_weight_, axis=1)[
                 :, np.newaxis
             ]
         else:
-            self.K_fit_rows_ = np.zeros(Kc.shape[1])
+            self.K_fit_rows_ = np.zeros(K.shape[1])
             self.K_fit_all_ = 0.0
-            K_pred_cols = np.zeros((Kc.shape[0], 1))
+            K_pred_cols = np.zeros((K.shape[0], 1))

         if self.with_trace:
-            Kc -= self.K_fit_rows_
-            Kc -= K_pred_cols
-            Kc += self.K_fit_all_
+            K -= self.K_fit_rows_
+            K -= K_pred_cols
+            K += self.K_fit_all_

-            self.scale_ = np.trace(Kc) / Kc.shape[0]
+            self.scale_ = np.trace(K) / K.shape[0]
         else:
             self.scale_ = 1.0
@@ -408,7 +413,7 @@ def fit_transform(self, K, y=None, sample_weight=None, copy=True, **fit_params):
         return self.transform(K, copy)


-class SparseKernelCenterer(TransformerMixin, BaseEstimator):
+class SparseKernelCenterer(TransformerMixin):
     r"""Kernel centering method for sparse kernels, similar to
     KernelFlexibleCenterer.
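The KernelFlexibleCenterer fix is subtle: fit() previously centered a stale pre-validation copy (Kc) while fitting on the validated kernel; both paths now operate on the validated K. A toy round trip using the standard double-centering identity — the all-close check is the expected outcome under uniform weights, not something asserted by the diff:

    import numpy as np
    from skmatter.preprocessing import KernelFlexibleCenterer

    X = np.random.default_rng(0).random((8, 3))
    K = X @ X.T  # a symmetric positive semi-definite Gram matrix

    Kc = KernelFlexibleCenterer().fit_transform(K)
    print(np.allclose(Kc.sum(axis=0), 0))  # double-centered columns sum to ~0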

src/skmatter/utils/_orthogonalizers.py

Lines changed: 2 additions & 2 deletions
@@ -56,9 +56,9 @@ def X_orthogonalizer(x1, c=None, x2=None, tol=1e-12, copy=False):
         if np.linalg.norm(col) < tol:
             warnings.warn("Column vector contains only zeros.", stacklevel=1)
         else:
-            col /= np.linalg.norm(col, axis=0)
+            col = np.divide(col, np.linalg.norm(col, axis=0))

-        xnew -= col @ (col.T @ xnew)
+        xnew -= (col @ (col.T @ xnew)).astype(xnew.dtype)

     return xnew
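Both replacements are dtype fixes. NumPy's in-place operators must cast the result back into the operand's dtype under "same_kind" rules, so col /= norm raises outright on integer input (float64 cannot be cast to int in place), whereas np.divide without an out argument simply returns a new float array; the explicit .astype(xnew.dtype) likewise keeps the subtraction consistent with xnew's dtype. A tiny demonstration:

    import numpy as np

    col = np.array([3, 4])  # integer input, as sklearn's estimator checks may supply
    try:
        col /= np.linalg.norm(col)  # in-place true division cannot cast float -> int
    except TypeError as err:
        print("in-place division failed:", err)

    col = np.divide(col, np.linalg.norm(col))  # allocates a new float array instead
    print(col.dtype)  # float64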

src/skmatter/utils/_pcovr_utils.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def pcovr_covariance(
         C_Y = C_Y.reshape((C.shape[0], -1))
         C_Y = np.real(C_Y)

-        C += (1 - mixing) * C_Y @ C_Y.T
+        C += (1 - mixing) * np.array(C_Y @ C_Y.T, dtype=np.float64)

     if mixing > 0:
         C += (mixing) * (X.T @ X)
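Reading the surrounding code, the quantity assembled here is the PCovR-modified covariance,

    C = mixing * (X.T @ X) + (1 - mixing) * (C_Y @ C_Y.T)

with C_Y the regression-based target contribution; the added np.array(..., dtype=np.float64) pins the accumulated product to float64 regardless of the dtype X arrived in.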

tests/test_check_estimators.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from sklearn.utils.estimator_checks import parametrize_with_checks
+
+from skmatter.decomposition import KernelPCovR, PCovR
+from skmatter.feature_selection import CUR as fCUR
+from skmatter.feature_selection import FPS as fFPS
+from skmatter.feature_selection import PCovCUR as fPCovCUR
+from skmatter.feature_selection import PCovFPS as fPCovFPS
+from skmatter.linear_model import RidgeRegression2FoldCV  # OrthogonalRegression,
+from skmatter.preprocessing import KernelNormalizer, StandardFlexibleScaler
+
+
+@parametrize_with_checks(
+    [
+        KernelPCovR(mixing=0.5),
+        PCovR(mixing=0.5),
+        fCUR(),
+        fFPS(),
+        fPCovCUR(),
+        fPCovFPS(),
+        RidgeRegression2FoldCV(),
+        KernelNormalizer(),
+        StandardFlexibleScaler(),
+    ]
+)
+def test_sklearn_compatible_estimator(estimator, check):
+    check(estimator)
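parametrize_with_checks expands into one pytest case per (estimator, check) pair, so the whole suite runs with a plain invocation such as pytest tests/test_check_estimators.py. For a single estimator outside pytest, the imperative form is equivalent — a sketch; check_estimator raises on the first failing check:

    from sklearn.utils.estimator_checks import check_estimator
    from skmatter.preprocessing import StandardFlexibleScaler

    check_estimator(StandardFlexibleScaler())  # runs sklearn's API conformance checks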
