Skip to content

Commit 6e6392a

Browse files
committed
Use sgkit.distarray for PCA
1 parent b703e08 commit 6e6392a

File tree

5 files changed

+253
-5
lines changed

5 files changed

+253
-5
lines changed

.github/workflows/cubed.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
3131
- name: Test with pytest
3232
run: |
33-
pytest -v sgkit/tests/test_{aggregation,association,hwe}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
33+
pytest -v sgkit/tests/test_{aggregation,association,hwe,pca}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False]) or (test_pca__array_backend and tsqr)' --use-cubed

sgkit/stats/pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
from typing import Any, Optional, Union
22

3-
import dask.array as da
43
import numpy as np
54
import xarray as xr
6-
from dask_ml.decomposition import TruncatedSVD
75
from sklearn.base import BaseEstimator
86
from sklearn.pipeline import Pipeline
97
from typing_extensions import Literal
108
from xarray import DataArray, Dataset
119

10+
import sgkit.distarray as da
1211
from sgkit import variables
1312

1413
from ..typing import ArrayLike, DType, RandomStateType
1514
from ..utils import conditional_merge_datasets
1615
from .aggregation import count_call_alleles
1716
from .preprocessing import PattersonScaler
17+
from .truncated_svd import TruncatedSVD
1818

1919

2020
def pca_est(

sgkit/stats/preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Hashable, Optional
22

3-
import dask.array as da
43
import numpy as np
54
import xarray as xr
65
from sklearn.base import BaseEstimator, TransformerMixin
76
from xarray import Dataset
87

8+
import sgkit.distarray as da
99
from sgkit import variables
1010

1111
from ..typing import ArrayLike

sgkit/stats/truncated_svd.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from dask.utils import has_keyword
2+
from sklearn.base import BaseEstimator, TransformerMixin
3+
4+
import sgkit.distarray as da
5+
6+
# Based on the implementation in Dask-ML, with minor changes to support the
7+
# array API so it can work with both Dask and Cubed.
8+
9+
10+
class TruncatedSVD(BaseEstimator, TransformerMixin):
11+
def __init__(
12+
self,
13+
n_components=2,
14+
algorithm="tsqr",
15+
n_iter=5,
16+
random_state=None,
17+
tol=0.0,
18+
compute=True,
19+
):
20+
"""Dimensionality reduction using truncated SVD (aka LSA).
21+
22+
This transformer performs linear dimensionality reduction by means of
23+
truncated singular value decomposition (SVD). Contrary to PCA, this
24+
estimator does not center the data before computing the singular value
25+
decomposition.
26+
27+
Parameters
28+
----------
29+
n_components : int, default = 2
30+
Desired dimensionality of output data.
31+
Must be less than or equal to the number of features.
32+
The default value is useful for visualization.
33+
34+
algorithm : {'tsqr', 'randomized'}
35+
SVD solver to use. Both use the `tsqr` (for "tall-and-skinny QR")
36+
algorithm internally. 'randomized' uses an approximate algorithm
37+
that is faster, but not exact. See the References for more.
38+
39+
n_iter : int, optional (default 0)
40+
Number of power iterations, useful when the singular values
41+
decay slowly. Error decreases exponentially as n_power_iter
42+
increases. In practice, set n_power_iter <= 4.
43+
44+
random_state : int, RandomState instance or None, optional
45+
If int, random_state is the seed used by the random number
46+
generator;
47+
If RandomState instance, random_state is the random number
48+
generator;
49+
If None, the random number generator is the RandomState instance
50+
used by `np.random`.
51+
52+
tol : float, optional
53+
Ignored.
54+
55+
compute : bool
56+
Whether or not SVD results should be computed
57+
eagerly, by default True.
58+
59+
Attributes
60+
----------
61+
components_ : array, shape (n_components, n_features)
62+
63+
explained_variance_ : array, shape (n_components,)
64+
The variance of the training samples transformed by a projection to
65+
each component.
66+
67+
explained_variance_ratio_ : array, shape (n_components,)
68+
Percentage of variance explained by each of the selected
69+
components.
70+
71+
singular_values_ : array, shape (n_components,)
72+
The singular values corresponding to each of the selected
73+
components. The singular values are equal to the 2-norms of the
74+
``n_components`` variables in the lower-dimensional space.
75+
76+
See Also
77+
--------
78+
dask.array.linalg.tsqr
79+
dask.array.linalg.svd_compressed
80+
81+
References
82+
----------
83+
84+
Direct QR factorizations for tall-and-skinny matrices in
85+
MapReduce architectures.
86+
A. Benson, D. Gleich, and J. Demmel.
87+
IEEE International Conference on Big Data, 2013.
88+
http://arxiv.org/abs/1301.1071
89+
90+
Notes
91+
-----
92+
SVD suffers from a problem called "sign indeterminacy", which means
93+
the sign of the ``components_`` and the output from transform depend on
94+
the algorithm and random state. To work around this, fit instances of
95+
this class to data once, then keep the instance around to do
96+
transformations.
97+
98+
.. warning::
99+
100+
The implementation currently does not support sparse matrices.
101+
102+
Examples
103+
--------
104+
>>> from dask_ml.decomposition import TruncatedSVD
105+
>>> import dask.array as da
106+
>>> X = da.random.normal(size=(1000, 20), chunks=(100, 20))
107+
>>> svd = TruncatedSVD(n_components=5, n_iter=3, random_state=42)
108+
>>> svd.fit(X) # doctest: +NORMALIZE_WHITESPACE
109+
TruncatedSVD(algorithm='tsqr', n_components=5, n_iter=3,
110+
random_state=42, tol=0.0)
111+
112+
>>> print(svd.explained_variance_ratio_) # doctest: +ELLIPSIS
113+
[0.06386323 0.06176776 0.05901293 0.0576399 0.05726607]
114+
>>> print(svd.explained_variance_ratio_.sum()) # doctest: +ELLIPSIS
115+
0.299...
116+
>>> print(svd.singular_values_) # doctest: +ELLIPSIS
117+
array([35.92469517, 35.32922121, 34.53368856, 34.138..., 34.013...])
118+
119+
Note that ``transform`` returns a ``dask.Array``.
120+
121+
>>> svd.transform(X)
122+
dask.array<sum-agg, shape=(1000, 5), dtype=float64, chunksize=(100, 5)>
123+
"""
124+
self.algorithm = algorithm
125+
self.n_components = n_components
126+
self.n_iter = n_iter
127+
self.random_state = random_state
128+
self.tol = tol
129+
self.compute = compute
130+
131+
def fit(self, X, y=None):
132+
"""Fit truncated SVD on training data X
133+
134+
Parameters
135+
----------
136+
X : array-like, shape (n_samples, n_features)
137+
Training data.
138+
139+
y : Ignored
140+
141+
Returns
142+
-------
143+
self : object
144+
Returns the transformer object.
145+
"""
146+
self.fit_transform(X)
147+
return self
148+
149+
def _check_array(self, X):
150+
if self.n_components >= X.shape[1]:
151+
raise ValueError(
152+
"n_components must be < n_features; "
153+
"got {} >= {}".format(self.n_components, X.shape[1])
154+
)
155+
return X
156+
157+
def fit_transform(self, X, y=None):
158+
"""Fit model to X and perform dimensionality reduction on X.
159+
160+
Parameters
161+
----------
162+
X : array-like, shape (n_samples, n_features)
163+
Training data.
164+
165+
y : Ignored
166+
167+
Returns
168+
-------
169+
X_new : array, shape (n_samples, n_components)
170+
Reduced version of X. This will always be a dense array, of the
171+
same type as the input array. If ``X`` was a ``dask.array``, then
172+
``X_new`` will be a ``dask.array`` with the same chunks along the
173+
first dimension.
174+
"""
175+
X = self._check_array(X)
176+
if self.algorithm not in {"tsqr", "randomized"}:
177+
raise ValueError(
178+
"`algorithm` must be 'tsqr' or 'randomized', not '{}'".format(
179+
self.algorithm
180+
)
181+
)
182+
if self.algorithm == "tsqr":
183+
if has_keyword(da.linalg.svd, "full_matrices"):
184+
u, s, v = da.linalg.svd(X, full_matrices=False)
185+
else:
186+
u, s, v = da.linalg.svd(X)
187+
u = u[:, : self.n_components]
188+
s = s[: self.n_components]
189+
v = v[: self.n_components]
190+
else:
191+
u, s, v = da.linalg.svd_compressed(
192+
X, self.n_components, n_power_iter=self.n_iter, seed=self.random_state
193+
)
194+
195+
X_transformed = u * s
196+
explained_var = da.var(X_transformed, axis=0)
197+
full_var = da.var(X, axis=0)
198+
full_var = da.sum(full_var)
199+
explained_variance_ratio = explained_var / full_var
200+
201+
if self.compute:
202+
v, explained_var, explained_variance_ratio, s = da.compute(
203+
v, explained_var, explained_variance_ratio, s
204+
)
205+
self.components_ = v
206+
self.explained_variance_ = explained_var
207+
self.explained_variance_ratio_ = explained_variance_ratio
208+
self.singular_values_ = s
209+
self.n_features_in_ = X.shape[1]
210+
return X_transformed
211+
212+
def transform(self, X, y=None):
213+
"""Perform dimensionality reduction on X.
214+
215+
Parameters
216+
----------
217+
X : array-like, shape (n_samples, n_features)
218+
Data to be transformed.
219+
220+
y : Ignored
221+
222+
Returns
223+
-------
224+
X_new : array, shape (n_samples, n_components)
225+
Reduced version of X. This will always be a dense array, of the
226+
same type as the input array. If ``X`` was a ``dask.array``, then
227+
``X_new`` will be a ``dask.array`` with the same chunks along the
228+
first dimension.
229+
"""
230+
return X @ self.components_.T
231+
232+
def inverse_transform(self, X):
233+
"""Transform X back to its original space.
234+
235+
Returns an array X_original whose transform would be X.
236+
237+
Parameters
238+
----------
239+
X : array-like, shape (n_samples, n_components)
240+
New data.
241+
242+
Returns
243+
-------
244+
X_original : array, shape (n_samples, n_features)
245+
Note that this is always a dense array.
246+
"""
247+
# X = check_array(X)
248+
return X @ self.components_

sgkit/tests/test_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing import Any, Optional
22

33
import allel
4-
import dask.array as da
54
import numpy as np
65
import pytest
76
import xarray as xr
87
from xarray import Dataset
98

9+
import sgkit.distarray as da
1010
from sgkit.stats import pca
1111
from sgkit.stats.pca import count_call_alternate_alleles
1212
from sgkit.testing import simulate_genotype_call_dataset

0 commit comments

Comments
 (0)