Skip to content

Commit 1ba16f8

Browse files
committed
Fixing error in CUR test
1 parent b83f986 commit 1ba16f8

File tree

3 files changed

+4
-8
lines changed

3 files changed

+4
-8
lines changed

src/skmatter/decomposition/_pcovr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class PCovR(_BasePCA, LinearModel):
181181
>>> Y = np.array([[0, -5], [-1, 1], [1, -5], [-3, 2]])
182182
>>> pcovr = PCovR(mixing=0.1, n_components=2)
183183
>>> pcovr.fit(X, Y)
184-
PCovR(mixing=0.1, n_components=2, space='sample')
184+
PCovR(mixing=0.1, n_components=2)
185185
>>> pcovr.transform(X)
186186
array([[ 3.2630561 , 0.06663787],
187187
[-2.69395511, -0.41582771],

tests/test_check_estimators.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
from skmatter.feature_selection import PCovCUR as fPCovCUR
77
from skmatter.feature_selection import PCovFPS as fPCovFPS
88
from skmatter.linear_model import RidgeRegression2FoldCV # OrthogonalRegression,
9-
from skmatter.preprocessing import (
10-
KernelNormalizer,
11-
StandardFlexibleScaler,
12-
)
9+
from skmatter.preprocessing import KernelNormalizer, StandardFlexibleScaler
1310

1411

1512
@parametrize_with_checks(

tests/test_sample_simple_cur.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import unittest
22

33
import numpy as np
4-
from sklearn import exceptions
54
from sklearn.datasets import fetch_california_housing as load
65

76
from skmatter.sample_selection import CUR, FPS
@@ -10,12 +9,12 @@
109
class TestCUR(unittest.TestCase):
1110
def setUp(self):
1211
self.X, _ = load(return_X_y=True)
13-
self.X = FPS(n_to_select=100).fit(self.X).transform(self.X)
12+
self.X = self.X[FPS(n_to_select=100).fit(self.X).selected_idx_]
1413
self.n_select = min(20, min(self.X.shape) // 2)
1514

1615
def test_bad_transform(self):
1716
selector = CUR(n_to_select=2)
18-
with self.assertRaises(exceptions.NotFittedError):
17+
with self.assertRaises(ValueError):
1918
_ = selector.transform(self.X)
2019

2120
def test_restart(self):

0 commit comments

Comments
 (0)