
Commit bf4d512

mtsokol, ogrisel, betatim, and thomasjpfan authored and committed
ENH Array API support for PCA (scikit-learn#26315)
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Tim Head <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 1a5245c commit bf4d512

File tree: 14 files changed, +594 −184 lines changed


Diff for: build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock

+40-39
Large diffs are not rendered by default.

Diff for: doc/modules/array_api.rst

+3-1
@@ -88,6 +88,8 @@ the tensors directly::
 Estimators with support for `Array API`-compatible inputs
 =========================================================
 
+- :class:`decomposition.PCA` (with `svd_solver="full"`,
+  `svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
 - :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
 
 Coverage for more estimators is expected to grow over time. Please follow the
@@ -107,4 +109,4 @@ To run these checks you need to install
 test environment. To run the full set of checks you need to install both
 `PyTorch <https://pytorch.org/>`_ and `CuPy <https://cupy.dev/>`_ and have
 a GPU. Checks that can not be executed or have missing dependencies will be
-automatically skipped.
+automatically skipped.
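For context, a minimal usage sketch of the support documented above (assumptions: PyTorch and the array-api-compat package are installed and scikit-learn includes this change; the data and parameter values are illustrative):

import torch
import sklearn
from sklearn.decomposition import PCA

# Opt in to Array API dispatch, as described in doc/modules/array_api.rst.
sklearn.set_config(array_api_dispatch=True)

X = torch.randn(100, 10)

# svd_solver="full" is one of the configurations listed in the new bullet point.
pca = PCA(n_components=3, svd_solver="full")
X_trans = pca.fit_transform(X)

# Outputs and fitted attributes stay in the input namespace (torch tensors here).
print(type(X_trans), type(pca.components_))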

Diff for: doc/modules/model_evaluation.rst

+3-3
@@ -451,7 +451,7 @@ where :math:`1(x)` is the `indicator function
   >>> accuracy_score(y_true, y_pred)
   0.5
   >>> accuracy_score(y_true, y_pred, normalize=False)
-  2
+  2.0
 
 In the multilabel case with binary label indicators::
 
@@ -1696,7 +1696,7 @@ loss can also be computed as :math:`zero-one loss = 1 - accuracy`.
   >>> zero_one_loss(y_true, y_pred)
   0.25
   >>> zero_one_loss(y_true, y_pred, normalize=False)
-  1
+  1.0
 
 In the multilabel case with binary label indicators, where the first label
 set [0,1] has an error::
@@ -1705,7 +1705,7 @@ set [0,1] has an error::
   0.5
 
   >>> zero_one_loss(np.array([[0, 1], [1, 1]]), np.ones((2, 2)), normalize=False)
-  1
+  1.0
 
 .. topic:: Example:
 
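The doctest updates above reflect a behaviour change shipped with this commit, presumably because the metric internals now go through the shared Array API averaging helpers: the un-normalized counts are returned as floats rather than Python ints. A quick sketch with the same data as the accuracy_score example in the guide:

from sklearn.metrics import accuracy_score, zero_one_loss

y_true = [0, 1, 2, 3]
y_pred = [0, 2, 1, 3]

# With normalize=False the raw count of correctly classified (or misclassified)
# samples now comes back as a float, e.g. 2.0 instead of the previous int 2.
print(accuracy_score(y_true, y_pred, normalize=False))
print(zero_one_loss(y_true, y_pred, normalize=False))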

Diff for: doc/whats_new/v1.4.rst

+7
@@ -43,6 +43,7 @@ Changelog
     :pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
     where 123456 is the *pull request* number, not the issue number.
 
+
 :mod:`sklearn.base`
 ...................
 
@@ -64,6 +65,12 @@ Changelog
   from `None` to `auto` in version 1.6.
   :pr:`26634` by :user:`Alexandre Landeau <AlexL>` and :user:`Alexandre Vigny <avigny>`.
 
+- |Enhancement| :class:`decomposition.PCA` now supports the Array API for the
+  `full` and `randomized` solvers (with QR power iterations). See
+  :ref:`array_api` for more details.
+  :pr:`26315` by :user:`Mateusz Sokół <mtsokol>` and
+  :user:`Olivier Grisel <ogrisel>`.
+
 :mod:`sklearn.ensemble`
 .......................
 
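A short sketch of the randomized-solver configuration named in the changelog entry (NumPy input shown; with array_api_dispatch enabled the same call is intended to accept other Array API inputs):

import numpy as np
from sklearn.decomposition import PCA

X = np.random.RandomState(0).standard_normal((200, 30))

# Array API support for the randomized solver is limited to QR-normalized
# power iterations, hence the explicit power_iteration_normalizer="QR".
pca = PCA(
    n_components=5,
    svd_solver="randomized",
    power_iteration_normalizer="QR",
    random_state=0,
)
pca.fit(X)
print(pca.explained_variance_ratio_)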

Diff for: sklearn/decomposition/_base.py

+40-23
@@ -14,6 +14,7 @@
 from scipy import linalg
 
 from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
+from ..utils._array_api import _add_to_diagonal, get_namespace
 from ..utils.validation import check_is_fitted
 
 
@@ -38,13 +39,18 @@ def get_covariance(self):
         cov : array of shape=(n_features, n_features)
             Estimated covariance of data.
         """
+        xp, _ = get_namespace(self.components_)
+
         components_ = self.components_
         exp_var = self.explained_variance_
         if self.whiten:
-            components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
-        exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
-        cov = np.dot(components_.T * exp_var_diff, components_)
-        cov.flat[:: len(cov) + 1] += self.noise_variance_  # modify diag inplace
+            components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
+        exp_var_diff = exp_var - self.noise_variance_
+        exp_var_diff = xp.where(
+            exp_var > self.noise_variance_, exp_var_diff, xp.asarray(0.0)
+        )
+        cov = (components_.T * exp_var_diff) @ components_
+        _add_to_diagonal(cov, self.noise_variance_, xp)
         return cov
 
     def get_precision(self):
@@ -58,26 +64,36 @@ def get_precision(self):
         precision : array, shape=(n_features, n_features)
             Estimated precision of data.
         """
+        xp, is_array_api_compliant = get_namespace(self.components_)
+
         n_features = self.components_.shape[1]
 
         # handle corner cases first
         if self.n_components_ == 0:
-            return np.eye(n_features) / self.noise_variance_
+            return xp.eye(n_features) / self.noise_variance_
+
+        if is_array_api_compliant:
+            linalg_inv = xp.linalg.inv
+        else:
+            linalg_inv = linalg.inv
 
-        if np.isclose(self.noise_variance_, 0.0, atol=0.0):
-            return linalg.inv(self.get_covariance())
+        if self.noise_variance_ == 0.0:
+            return linalg_inv(self.get_covariance())
 
         # Get precision using matrix inversion lemma
         components_ = self.components_
         exp_var = self.explained_variance_
         if self.whiten:
-            components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
-        exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
-        precision = np.dot(components_, components_.T) / self.noise_variance_
-        precision.flat[:: len(precision) + 1] += 1.0 / exp_var_diff
-        precision = np.dot(components_.T, np.dot(linalg.inv(precision), components_))
+            components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
+        exp_var_diff = exp_var - self.noise_variance_
+        exp_var_diff = xp.where(
+            exp_var > self.noise_variance_, exp_var_diff, xp.asarray(0.0)
+        )
+        precision = components_ @ components_.T / self.noise_variance_
+        _add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
+        precision = components_.T @ linalg_inv(precision) @ components_
         precision /= -(self.noise_variance_**2)
-        precision.flat[:: len(precision) + 1] += 1.0 / self.noise_variance_
+        _add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
         return precision
 
     @abstractmethod
@@ -116,14 +132,16 @@ def transform(self, X):
             Projection of X in the first principal components, where `n_samples`
             is the number of samples and `n_components` is the number of the components.
         """
+        xp, _ = get_namespace(X)
+
         check_is_fitted(self)
 
-        X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
+        X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
         if self.mean_ is not None:
             X = X - self.mean_
-        X_transformed = np.dot(X, self.components_.T)
+        X_transformed = X @ self.components_.T
         if self.whiten:
-            X_transformed /= np.sqrt(self.explained_variance_)
+            X_transformed /= xp.sqrt(self.explained_variance_)
         return X_transformed
 
     def inverse_transform(self, X):
@@ -148,16 +166,15 @@ def inverse_transform(self, X):
             If whitening is enabled, inverse_transform will compute the
             exact inverse operation, which includes reversing whitening.
         """
+        xp, _ = get_namespace(X)
+
         if self.whiten:
-            return (
-                np.dot(
-                    X,
-                    np.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_,
-                )
-                + self.mean_
+            scaled_components = (
+                xp.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_
             )
+            return X @ scaled_components + self.mean_
         else:
-            return np.dot(X, self.components_) + self.mean_
+            return X @ self.components_ + self.mean_
 
     @property
     def _n_features_out(self):
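One recurring pattern in the hunks above: np.maximum(exp_var - self.noise_variance_, 0.0) becomes a comparison plus xp.where, presumably because a two-argument element-wise maximum is not guaranteed by the targeted Array API namespaces. A standalone sketch of the equivalence, with NumPy standing in for xp (the helper name is illustrative, not scikit-learn API):

import numpy as np

def clipped_variance_diff(exp_var, noise_variance, xp):
    # Same result as np.maximum(exp_var - noise_variance, 0.0), written only
    # with operations available through the Array API namespace.
    diff = exp_var - noise_variance
    return xp.where(exp_var > noise_variance, diff, xp.asarray(0.0, dtype=diff.dtype))

exp_var = np.asarray([2.5, 0.8, 0.1])
assert np.allclose(
    clipped_variance_diff(exp_var, 0.5, np),
    np.maximum(exp_var - 0.5, 0.0),
)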

Diff for: sklearn/decomposition/_pca.py

+44-21
@@ -22,6 +22,7 @@
 from ..base import _fit_context
 from ..utils import check_random_state
 from ..utils._arpack import _init_arpack_v0
+from ..utils._array_api import get_namespace
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
 from ..utils.deprecation import deprecated
 from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip
@@ -108,8 +109,10 @@ def _infer_dimension(spectrum, n_samples):
 
     The returned value will be in [1, n_features - 1].
     """
-    ll = np.empty_like(spectrum)
-    ll[0] = -np.inf  # we don't want to return n_components = 0
+    xp, _ = get_namespace(spectrum)
+
+    ll = xp.empty_like(spectrum)
+    ll[0] = -xp.inf  # we don't want to return n_components = 0
     for rank in range(1, spectrum.shape[0]):
         ll[rank] = _assess_dimension(spectrum, rank, n_samples)
     return ll.argmax()
@@ -471,6 +474,7 @@ def fit_transform(self, X, y=None):
 
     def _fit(self, X):
         """Dispatch to the right submethod depending on the chosen solver."""
+        xp, is_array_api_compliant = get_namespace(X)
 
         # Raise an error for sparse input.
         # This is more informative than the generic one raised by check_array.
@@ -479,9 +483,14 @@ def _fit(self, X):
                 "PCA does not support sparse input. See "
                 "TruncatedSVD for a possible alternative."
             )
+        # Raise an error for non-Numpy input and arpack solver.
+        if self.svd_solver == "arpack" and is_array_api_compliant:
+            raise ValueError(
+                "PCA with svd_solver='arpack' is not supported for Array API inputs."
+            )
 
         X = self._validate_data(
-            X, dtype=[np.float64, np.float32], ensure_2d=True, copy=self.copy
+            X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=self.copy
         )
 
         # Handle n_components==None
@@ -513,6 +522,8 @@ def _fit(self, X):
 
     def _fit_full(self, X, n_components):
         """Fit the model by computing full SVD on X."""
+        xp, is_array_api_compliant = get_namespace(X)
+
         n_samples, n_features = X.shape
 
         if n_components == "mle":
@@ -528,20 +539,30 @@ def _fit_full(self, X, n_components):
             )
 
         # Center data
-        self.mean_ = np.mean(X, axis=0)
+        self.mean_ = xp.mean(X, axis=0)
         X -= self.mean_
 
-        U, S, Vt = linalg.svd(X, full_matrices=False)
+        if not is_array_api_compliant:
+            # Use scipy.linalg with NumPy/SciPy inputs for the sake of not
+            # introducing unanticipated behavior changes. In the long run we
+            # could instead decide to always use xp.linalg.svd for all inputs,
+            # but that would make this code rely on numpy's SVD instead of
+            # scipy's. It's not 100% clear whether they use the same LAPACK
+            # solver by default though (assuming both are built against the
+            # same BLAS).
+            U, S, Vt = linalg.svd(X, full_matrices=False)
+        else:
+            U, S, Vt = xp.linalg.svd(X, full_matrices=False)
         # flip eigenvectors' sign to enforce deterministic output
         U, Vt = svd_flip(U, Vt)
 
         components_ = Vt
 
         # Get variance explained by singular values
         explained_variance_ = (S**2) / (n_samples - 1)
-        total_var = explained_variance_.sum()
+        total_var = xp.sum(explained_variance_)
         explained_variance_ratio_ = explained_variance_ / total_var
-        singular_values_ = S.copy()  # Store the singular values.
+        singular_values_ = xp.asarray(S, copy=True)  # Store the singular values.
 
         # Postprocess the number of components required
         if n_components == "mle":
@@ -553,16 +574,16 @@ def _fit_full(self, X, n_components):
             # their variance is always greater than n_components float
             # passed. More discussion in issue: #15669
             ratio_cumsum = stable_cumsum(explained_variance_ratio_)
-            n_components = np.searchsorted(ratio_cumsum, n_components, side="right") + 1
+            n_components = xp.searchsorted(ratio_cumsum, n_components, side="right") + 1
         # Compute noise covariance using Probabilistic PCA model
         # The sigma2 maximum likelihood (cf. eq. 12.46)
         if n_components < min(n_features, n_samples):
-            self.noise_variance_ = explained_variance_[n_components:].mean()
+            self.noise_variance_ = xp.mean(explained_variance_[n_components:])
         else:
             self.noise_variance_ = 0.0
 
         self.n_samples_ = n_samples
-        self.components_ = components_[:n_components]
+        self.components_ = components_[:n_components, :]
         self.n_components_ = n_components
         self.explained_variance_ = explained_variance_[:n_components]
         self.explained_variance_ratio_ = explained_variance_ratio_[:n_components]
@@ -574,6 +595,8 @@ def _fit_truncated(self, X, n_components, svd_solver):
         """Fit the model by computing truncated SVD (by ARPACK or randomized)
         on X.
         """
+        xp, _ = get_namespace(X)
+
         n_samples, n_features = X.shape
 
         if isinstance(n_components, str):
@@ -599,7 +622,7 @@ def _fit_truncated(self, X, n_components, svd_solver):
         random_state = check_random_state(self.random_state)
 
         # Center data
-        self.mean_ = np.mean(X, axis=0)
+        self.mean_ = xp.mean(X, axis=0)
         X -= self.mean_
 
         if svd_solver == "arpack":
@@ -633,15 +656,14 @@ def _fit_truncated(self, X, n_components, svd_solver):
         # Workaround in-place variance calculation since at the time numpy
         # did not have a way to calculate variance in-place.
         N = X.shape[0] - 1
-        np.square(X, out=X)
-        np.sum(X, axis=0, out=X[0])
-        total_var = (X[0] / N).sum()
+        X **= 2
+        total_var = xp.sum(xp.sum(X, axis=0) / N)
 
         self.explained_variance_ratio_ = self.explained_variance_ / total_var
-        self.singular_values_ = S.copy()  # Store the singular values.
+        self.singular_values_ = xp.asarray(S, copy=True)  # Store the singular values.
 
         if self.n_components_ < min(n_features, n_samples):
-            self.noise_variance_ = total_var - self.explained_variance_.sum()
+            self.noise_variance_ = total_var - xp.sum(self.explained_variance_)
             self.noise_variance_ /= min(n_features, n_samples) - n_components
         else:
             self.noise_variance_ = 0.0
@@ -666,12 +688,12 @@ def score_samples(self, X):
             Log-likelihood of each sample under the current model.
         """
         check_is_fitted(self)
-
-        X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
+        xp, _ = get_namespace(X)
+        X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
         Xr = X - self.mean_
         n_features = X.shape[1]
         precision = self.get_precision()
-        log_like = -0.5 * (Xr * (np.dot(Xr, precision))).sum(axis=1)
+        log_like = -0.5 * xp.sum(Xr * (Xr @ precision), axis=1)
         log_like -= 0.5 * (n_features * log(2.0 * np.pi) - fast_logdet(precision))
         return log_like
 
@@ -695,7 +717,8 @@ def score(self, X, y=None):
         ll : float
             Average log-likelihood of the samples under the current model.
         """
-        return np.mean(self.score_samples(X))
+        xp, _ = get_namespace(X)
+        return float(xp.mean(self.score_samples(X)))
 
     def _more_tags(self):
-        return {"preserves_dtype": [np.float64, np.float32]}
+        return {"preserves_dtype": [np.float64, np.float32], "array_api_support": True}
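The _fit_full hunk keeps scipy.linalg.svd for NumPy inputs and only routes other namespaces through xp.linalg.svd, so existing NumPy users stay on the historical SciPy/LAPACK path. A stripped-down sketch of that dispatch pattern (standalone, not the actual PCA code path):

import numpy as np
from scipy import linalg

def svd_for_namespace(X, xp, is_array_api_compliant):
    # NumPy/SciPy inputs keep the SciPy SVD to avoid behaviour changes;
    # Array API inputs defer to the namespace's own linalg.svd.
    if not is_array_api_compliant:
        return linalg.svd(X, full_matrices=False)
    return xp.linalg.svd(X, full_matrices=False)

X = np.random.RandomState(0).standard_normal((50, 8))
X = X - X.mean(axis=0)  # PCA centers the data before the SVD
U, S, Vt = svd_for_namespace(X, np, is_array_api_compliant=False)
explained_variance = S**2 / (X.shape[0] - 1)  # as in _fit_full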

0 commit comments
