Skip to content

Commit af60553

Browse files
committed
Initial commit to pass an estimator check. WIP
1 parent dbab1f2 commit af60553

19 files changed

+355
-249
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ __pycache__
44
*.egg-info
55
*.swp
66
*.swo
7+
*DS_Store
78

89
.tox/
910
build/

src/skmatter/_selection.py

Lines changed: 84 additions & 84 deletions
Large diffs are not rendered by default.

src/skmatter/decomposition/_kernel_pcovr.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def __init__(
220220
self.kernel_params = kernel_params
221221

222222
self.n_jobs = n_jobs
223-
self.n_samples_ = None
224223

225224
self.fit_inverse_transform = fit_inverse_transform
226225

@@ -308,17 +307,19 @@ def fit(self, X, Y, W=None):
308307

309308
if self.n_components is None:
310309
if self.svd_solver != "arpack":
311-
self.n_components = X.shape[0]
310+
self.n_components_ = X.shape[0]
312311
else:
313-
self.n_components = X.shape[0] - 1
312+
self.n_components_ = X.shape[0] - 1
313+
else:
314+
self.n_components_ = self.n_components
314315

315316
K = self._get_kernel(X)
316317

317318
if self.center:
318319
self.centerer_ = KernelNormalizer()
319320
K = self.centerer_.fit_transform(K)
320321

321-
self.n_samples_ = X.shape[0]
322+
self.n_samples_in_, self.n_features_in_ = X.shape
322323

323324
if self.regressor != "precomputed":
324325
if self.regressor is None:
@@ -362,7 +363,7 @@ def fit(self, X, Y, W=None):
362363
# to avoid needing to compute the kernel a second time
363364
self.regressor_ = check_krr_fit(regressor, K, X, Y)
364365

365-
W = self.regressor_.dual_coef_.reshape(X.shape[0], -1)
366+
W = self.regressor_.dual_coef_.reshape(self.n_samples_in_, -1)
366367

367368
# Use this instead of `self.regressor_.predict(K)`
368369
# so that we can handle the case of the pre-fitted regressor
@@ -387,12 +388,17 @@ def fit(self, X, Y, W=None):
387388
# Handle svd_solver
388389
self._fit_svd_solver = self.svd_solver
389390
if self._fit_svd_solver == "auto":
390-
# Small problem or self.n_components == 'mle', just call full PCA
391-
if max(X.shape) <= 500 or self.n_components == "mle":
391+
# Small problem or self.n_components_ == 'mle', just call full PCA
392+
if (
393+
max(self.n_samples_in_, self.n_features_in_) <= 500
394+
or self.n_components_ == "mle"
395+
):
392396
self._fit_svd_solver = "full"
393-
elif self.n_components >= 1 and self.n_components < 0.8 * min(X.shape):
397+
elif self.n_components_ >= 1 and self.n_components_ < 0.8 * max(
398+
self.n_samples_in_, self.n_features_in_
399+
):
394400
self._fit_svd_solver = "randomized"
395-
# This is also the case of self.n_components in (0,1)
401+
# This is also the case of self.n_components_ in (0,1)
396402
else:
397403
self._fit_svd_solver = "full"
398404

@@ -536,31 +542,31 @@ def score(self, X, Y):
536542
return -sum([Lkpca, Lkrr])
537543

538544
def _decompose_truncated(self, mat):
539-
if not 1 <= self.n_components <= self.n_samples_:
545+
if not 1 <= self.n_components_ <= self.n_samples_in_:
540546
raise ValueError(
541547
"n_components=%r must be between 1 and "
542548
"n_samples=%r with "
543549
"svd_solver='%s'"
544550
% (
545-
self.n_components,
546-
self.n_samples_,
551+
self.n_components_,
552+
self.n_samples_in_,
547553
self.svd_solver,
548554
)
549555
)
550-
elif not isinstance(self.n_components, numbers.Integral):
556+
elif not isinstance(self.n_components_, numbers.Integral):
551557
raise ValueError(
552558
"n_components=%r must be of type int "
553559
"when greater than or equal to 1, was of type=%r"
554-
% (self.n_components, type(self.n_components))
560+
% (self.n_components_, type(self.n_components_))
555561
)
556-
elif self.svd_solver == "arpack" and self.n_components == self.n_samples_:
562+
elif self.svd_solver == "arpack" and self.n_components_ == self.n_samples_in_:
557563
raise ValueError(
558564
"n_components=%r must be strictly less than "
559565
"n_samples=%r with "
560566
"svd_solver='%s'"
561567
% (
562-
self.n_components,
563-
self.n_samples_,
568+
self.n_components_,
569+
self.n_samples_in_,
564570
self.svd_solver,
565571
)
566572
)
@@ -569,7 +575,7 @@ def _decompose_truncated(self, mat):
569575

570576
if self._fit_svd_solver == "arpack":
571577
v0 = _init_arpack_v0(min(mat.shape), random_state)
572-
U, S, Vt = svds(mat, k=self.n_components, tol=self.tol, v0=v0)
578+
U, S, Vt = svds(mat, k=self.n_components_, tol=self.tol, v0=v0)
573579
# svds doesn't abide by scipy.linalg.svd/randomized_svd
574580
# conventions, so reverse its outputs.
575581
S = S[::-1]
@@ -581,7 +587,7 @@ def _decompose_truncated(self, mat):
581587
# sign flipping is done inside
582588
U, S, Vt = randomized_svd(
583589
mat,
584-
n_components=self.n_components,
590+
n_components=self.n_components_,
585591
n_iter=self.iterated_power,
586592
flip_sign=True,
587593
random_state=random_state,
@@ -594,24 +600,25 @@ def _decompose_truncated(self, mat):
594600
return U, S, Vt
595601

596602
def _decompose_full(self, mat):
597-
if self.n_components != "mle":
598-
if not (0 <= self.n_components <= self.n_samples_):
603+
if self.n_components_ != "mle":
604+
if not (0 <= self.n_components_ <= self.n_samples_in_):
599605
raise ValueError(
600606
"n_components=%r must be between 1 and "
601607
"n_samples=%r with "
602608
"svd_solver='%s'"
603609
% (
604-
self.n_components,
605-
self.n_samples_,
610+
self.n_components_,
611+
self.n_samples_in_,
606612
self.svd_solver,
607613
)
608614
)
609-
elif self.n_components >= 1:
610-
if not isinstance(self.n_components, numbers.Integral):
615+
elif self.n_components_ >= 1:
616+
if not isinstance(self.n_components_, numbers.Integral):
611617
raise ValueError(
612618
"n_components=%r must be of type int "
613619
"when greater than or equal to 1, "
614-
"was of type=%r" % (self.n_components, type(self.n_components))
620+
"was of type=%r"
621+
% (self.n_components_, type(self.n_components_))
615622
)
616623

617624
U, S, Vt = linalg.svd(mat, full_matrices=False)
@@ -623,26 +630,28 @@ def _decompose_full(self, mat):
623630
U, Vt = svd_flip(U, Vt)
624631

625632
# Get variance explained by singular values
626-
explained_variance_ = (S**2) / (self.n_samples_ - 1)
633+
explained_variance_ = (S**2) / (self.n_samples_in_ - 1)
627634
total_var = explained_variance_.sum()
628635
explained_variance_ratio_ = explained_variance_ / total_var
629636

630637
# Postprocess the number of components required
631-
if self.n_components == "mle":
632-
self.n_components = _infer_dimension(explained_variance_, self.n_samples_)
633-
elif 0 < self.n_components < 1.0:
638+
if self.n_components_ == "mle":
639+
self.n_components_ = _infer_dimension(
640+
explained_variance_, self.n_samples_in_
641+
)
642+
elif 0 < self.n_components_ < 1.0:
634643
# number of components for which the cumulated explained
635644
# variance percentage is superior to the desired threshold
636645
# side='right' ensures that number of features selected
637-
# their variance is always greater than self.n_components float
646+
# their variance is always greater than self.n_components_ float
638647
# passed. More discussion in issue: #15669
639648
ratio_cumsum = stable_cumsum(explained_variance_ratio_)
640-
self.n_components = (
641-
np.searchsorted(ratio_cumsum, self.n_components, side="right") + 1
649+
self.n_components_ = (
650+
np.searchsorted(ratio_cumsum, self.n_components_, side="right") + 1
642651
)
643-
self.n_components = self.n_components
652+
644653
return (
645-
U[:, : self.n_components],
646-
S[: self.n_components],
647-
Vt[: self.n_components],
654+
U[:, : self.n_components_],
655+
S[: self.n_components_],
656+
Vt[: self.n_components_],
648657
)

0 commit comments

Comments
 (0)