@@ -22,6 +22,7 @@
 from ..base import _fit_context
 from ..utils import check_random_state
 from ..utils._arpack import _init_arpack_v0
+from ..utils._array_api import get_namespace
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
 from ..utils.deprecation import deprecated
 from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip
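`get_namespace` is scikit-learn's private array-API dispatch helper: given one or more arrays, it returns the namespace (`xp`) to compute in, plus a flag saying whether the inputs are Array API (non-NumPy) arrays. A minimal sketch of the calling convention, assuming only what this diff shows plus plain NumPy:

    import numpy as np
    from sklearn.utils._array_api import get_namespace  # private helper used throughout this diff

    x = np.asarray([1.0, 2.0, 3.0])
    # For NumPy inputs, xp is a NumPy-compatible namespace and the flag is False;
    # for e.g. array_api_strict inputs it would be that library's namespace and True.
    xp, is_array_api_compliant = get_namespace(x)
    print(xp.sum(x), is_array_api_compliant)  # 6.0 False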
@@ -108,8 +109,10 @@ def _infer_dimension(spectrum, n_samples):
 
     The returned value will be in [1, n_features - 1].
     """
-    ll = np.empty_like(spectrum)
-    ll[0] = -np.inf  # we don't want to return n_components = 0
+    xp, _ = get_namespace(spectrum)
+
+    ll = xp.empty_like(spectrum)
+    ll[0] = -xp.inf  # we don't want to return n_components = 0
     for rank in range(1, spectrum.shape[0]):
         ll[rank] = _assess_dimension(spectrum, rank, n_samples)
     return ll.argmax()
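For reference, the rank-selection loop above as a self-contained sketch. `assess` is a hypothetical stand-in for the private `_assess_dimension`, and the final reduction uses `xp.argmax` rather than the `.argmax()` method form, since the Array API standard only guarantees the function form:

    from sklearn.utils._array_api import get_namespace

    def infer_dimension_sketch(spectrum, n_samples, assess):
        """Pick the rank with the highest log-likelihood under the PPCA model."""
        xp, _ = get_namespace(spectrum)
        ll = xp.empty_like(spectrum)
        ll[0] = -xp.inf  # rank 0 must never win the argmax
        for rank in range(1, spectrum.shape[0]):
            ll[rank] = assess(spectrum, rank, n_samples)
        return int(xp.argmax(ll))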
@@ -471,6 +474,7 @@ def fit_transform(self, X, y=None):
 
     def _fit(self, X):
         """Dispatch to the right submethod depending on the chosen solver."""
+        xp, is_array_api_compliant = get_namespace(X)
 
         # Raise an error for sparse input.
         # This is more informative than the generic one raised by check_array.
@@ -479,9 +483,14 @@ def _fit(self, X):
                 "PCA does not support sparse input. See "
                 "TruncatedSVD for a possible alternative."
             )
+        # Raise an error for non-NumPy input and arpack solver.
+        if self.svd_solver == "arpack" and is_array_api_compliant:
+            raise ValueError(
+                "PCA with svd_solver='arpack' is not supported for Array API inputs."
+            )
 
         X = self._validate_data(
-            X, dtype=[np.float64, np.float32], ensure_2d=True, copy=self.copy
+            X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=self.copy
         )
 
         # Handle n_components==None
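The arpack guard exists because that solver is reached through `scipy.sparse.linalg.svds`, which only understands NumPy/SciPy data, so failing fast with a clear message beats a confusing error from deep inside SciPy. A hypothetical repro of the new error path (assumes PyTorch and scikit-learn's `array_api_dispatch` config flag, neither of which appears in this diff):

    import torch
    import sklearn
    from sklearn.decomposition import PCA

    with sklearn.config_context(array_api_dispatch=True):
        X = torch.rand(100, 10)
        # get_namespace(X) reports an Array API input, so _fit raises:
        # ValueError: PCA with svd_solver='arpack' is not supported for Array API inputs.
        PCA(n_components=3, svd_solver="arpack").fit(X)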
@@ -513,6 +522,8 @@ def _fit(self, X):
 
     def _fit_full(self, X, n_components):
         """Fit the model by computing full SVD on X."""
+        xp, is_array_api_compliant = get_namespace(X)
+
         n_samples, n_features = X.shape
 
         if n_components == "mle":
@@ -528,20 +539,30 @@ def _fit_full(self, X, n_components):
             )
 
         # Center data
-        self.mean_ = np.mean(X, axis=0)
+        self.mean_ = xp.mean(X, axis=0)
         X -= self.mean_
 
-        U, S, Vt = linalg.svd(X, full_matrices=False)
+        if not is_array_api_compliant:
+            # Use scipy.linalg with NumPy/SciPy inputs for the sake of not
+            # introducing unanticipated behavior changes. In the long run we
+            # could instead decide to always use xp.linalg.svd for all inputs,
+            # but that would make this code rely on numpy's SVD instead of
+            # scipy's. It's not 100% clear whether they use the same LAPACK
+            # solver by default though (assuming both are built against the
+            # same BLAS).
+            U, S, Vt = linalg.svd(X, full_matrices=False)
+        else:
+            U, S, Vt = xp.linalg.svd(X, full_matrices=False)
         # flip eigenvectors' sign to enforce deterministic output
         U, Vt = svd_flip(U, Vt)
 
         components_ = Vt
 
         # Get variance explained by singular values
         explained_variance_ = (S**2) / (n_samples - 1)
-        total_var = explained_variance_.sum()
+        total_var = xp.sum(explained_variance_)
         explained_variance_ratio_ = explained_variance_ / total_var
-        singular_values_ = S.copy()  # Store the singular values.
+        singular_values_ = xp.asarray(S, copy=True)  # Store the singular values.
 
         # Postprocess the number of components required
         if n_components == "mle":
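The NumPy branch deliberately keeps `scipy.linalg.svd` so existing results do not change; only genuine Array API inputs go through `xp.linalg.svd`. On NumPy data the two solvers should agree up to sign, which a quick check makes concrete (a sketch, not part of the diff):

    import numpy as np
    from scipy import linalg

    rng = np.random.default_rng(0)
    X = rng.standard_normal((20, 5))
    X -= X.mean(axis=0)

    _, S_scipy, _ = linalg.svd(X, full_matrices=False)     # path taken for NumPy inputs
    _, S_numpy, _ = np.linalg.svd(X, full_matrices=False)  # what xp.linalg.svd maps to on NumPy
    # Singular values match; singular vectors may differ in sign, hence svd_flip above.
    assert np.allclose(S_scipy, S_numpy)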
@@ -553,16 +574,16 @@ def _fit_full(self, X, n_components):
             # their variance is always greater than n_components float
             # passed. More discussion in issue: #15669
             ratio_cumsum = stable_cumsum(explained_variance_ratio_)
-            n_components = np.searchsorted(ratio_cumsum, n_components, side="right") + 1
+            n_components = xp.searchsorted(ratio_cumsum, n_components, side="right") + 1
         # Compute noise covariance using Probabilistic PCA model
         # The sigma2 maximum likelihood (cf. eq. 12.46)
         if n_components < min(n_features, n_samples):
-            self.noise_variance_ = explained_variance_[n_components:].mean()
+            self.noise_variance_ = xp.mean(explained_variance_[n_components:])
         else:
             self.noise_variance_ = 0.0
 
         self.n_samples_ = n_samples
-        self.components_ = components_[:n_components]
+        self.components_ = components_[:n_components, :]
         self.n_components_ = n_components
         self.explained_variance_ = explained_variance_[:n_components]
         self.explained_variance_ratio_ = explained_variance_ratio_[:n_components]
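The `noise_variance_` assignment implements the probabilistic PCA maximum-likelihood estimate: the average of the eigenvalues the model leaves out, sigma^2 = (1 / (p - q)) * sum_{j=q+1..p} lambda_j (the eq. 12.46 cited in the comment). A tiny worked example with invented numbers:

    import numpy as np

    explained_variance = np.array([5.0, 2.0, 0.5, 0.3, 0.2])  # eigenvalues of the covariance
    n_components = 2
    # Mean of the discarded eigenvalues: (0.5 + 0.3 + 0.2) / 3
    noise_variance = float(np.mean(explained_variance[n_components:]))
    print(noise_variance)  # 0.333...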
@@ -574,6 +595,8 @@ def _fit_truncated(self, X, n_components, svd_solver):
         """Fit the model by computing truncated SVD (by ARPACK or randomized)
         on X.
         """
+        xp, _ = get_namespace(X)
+
         n_samples, n_features = X.shape
 
         if isinstance(n_components, str):
@@ -599,7 +622,7 @@ def _fit_truncated(self, X, n_components, svd_solver):
         random_state = check_random_state(self.random_state)
 
         # Center data
-        self.mean_ = np.mean(X, axis=0)
+        self.mean_ = xp.mean(X, axis=0)
         X -= self.mean_
 
         if svd_solver == "arpack":
@@ -633,15 +656,14 @@ def _fit_truncated(self, X, n_components, svd_solver):
         # Workaround in-place variance calculation since at the time numpy
         # did not have a way to calculate variance in-place.
         N = X.shape[0] - 1
-        np.square(X, out=X)
-        np.sum(X, axis=0, out=X[0])
-        total_var = (X[0] / N).sum()
+        X **= 2
+        total_var = xp.sum(xp.sum(X, axis=0) / N)
 
         self.explained_variance_ratio_ = self.explained_variance_ / total_var
-        self.singular_values_ = S.copy()  # Store the singular values.
+        self.singular_values_ = xp.asarray(S, copy=True)  # Store the singular values.
 
         if self.n_components_ < min(n_features, n_samples):
-            self.noise_variance_ = total_var - self.explained_variance_.sum()
+            self.noise_variance_ = total_var - xp.sum(self.explained_variance_)
             self.noise_variance_ /= min(n_features, n_samples) - n_components
         else:
             self.noise_variance_ = 0.0
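The old workaround relied on NumPy's `out=` keyword, which the Array API standard does not define; the replacement squares in place with `**=` and reduces with nested `xp.sum` calls. After centering, summing the squared columns and dividing by N = n_samples - 1 is exactly the per-feature sample variance, so the total matches `X.var(axis=0, ddof=1).sum()`. A quick NumPy check, assuming nothing beyond this diff:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 4))
    X -= X.mean(axis=0)  # data is already centered at this point in _fit_truncated
    N = X.shape[0] - 1

    expected = X.var(axis=0, ddof=1).sum()

    X **= 2  # in-place square, no `out=` keyword required
    total_var = np.sum(np.sum(X, axis=0) / N)
    assert np.isclose(total_var, expected)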
@@ -666,12 +688,12 @@ def score_samples(self, X):
             Log-likelihood of each sample under the current model.
         """
         check_is_fitted(self)
-
-        X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
+        xp, _ = get_namespace(X)
+        X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
         Xr = X - self.mean_
         n_features = X.shape[1]
         precision = self.get_precision()
-        log_like = -0.5 * (Xr * (np.dot(Xr, precision))).sum(axis=1)
+        log_like = -0.5 * xp.sum(Xr * (Xr @ precision), axis=1)
         log_like -= 0.5 * (n_features * log(2.0 * np.pi) - fast_logdet(precision))
         return log_like
 
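`np.dot` becomes the `@` operator, which the Array API standard defines, and the reduction goes through `xp.sum`. The expression is the row-wise quadratic form x_i^T P x_i of the Gaussian log-likelihood; an equivalence check against an einsum reference (a NumPy-only sketch):

    import numpy as np

    rng = np.random.default_rng(0)
    Xr = rng.standard_normal((10, 3))  # centered samples
    A = rng.standard_normal((3, 3))
    precision = A @ A.T                # symmetric positive definite

    row_quadratic = np.sum(Xr * (Xr @ precision), axis=1)    # x_i^T P x_i per row
    reference = np.einsum("ij,jk,ik->i", Xr, precision, Xr)
    assert np.allclose(row_quadratic, reference)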
@@ -695,7 +717,8 @@ def score(self, X, y=None):
         ll : float
             Average log-likelihood of the samples under the current model.
         """
-        return np.mean(self.score_samples(X))
+        xp, _ = get_namespace(X)
+        return float(xp.mean(self.score_samples(X)))
 
     def _more_tags(self):
-        return {"preserves_dtype": [np.float64, np.float32]}
+        return {"preserves_dtype": [np.float64, np.float32], "array_api_support": True}