Skip to content

Commit 766f955

Browse files
committed
Sparse output only if sparse=True. With docs.
1 parent ae885db commit 766f955

File tree

3 files changed

+71
-17
lines changed

3 files changed

+71
-17
lines changed

README.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ For these examples, we'll also use pandas, numpy, and sklearn::
4444
>>> import numpy as np
4545
>>> import sklearn.preprocessing, sklearn.decomposition, \
4646
... sklearn.linear_model, sklearn.pipeline, sklearn.metrics
47+
>>> from sklearn.feature_extraction.text import CountVectorizer
4748

4849
Load some Data
4950
**************
@@ -156,6 +157,20 @@ Only columns that are listed in the DataFrameMapper are kept. To keep a column b
156157
[ 1., 0., 0., 5.],
157158
[ 0., 0., 1., 4.]])
158159

160+
161+
Working with sparse features
162+
****************************
163+
164+
A ``DataFrameMapper`` returns a dense feature array by default. Setting ``sparse=True`` in the mapper makes it return a sparse array whenever any of the extracted features is sparse. Example:
165+
166+
>>> mapper4 = DataFrameMapper([
167+
... ('pet', CountVectorizer()),
168+
... ], sparse=True)
169+
>>> type(mapper4.fit_transform(data))
170+
<class 'scipy.sparse.csr.csr_matrix'>
171+
172+
The stacking of the sparse features is done without ever densifying them.
173+
159174
Cross-Validation
160175
----------------
161176

sklearn_pandas/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,19 @@ class DataFrameMapper(BaseEstimator, TransformerMixin):
6969
sklearn transformation.
7070
"""
7171

72-
def __init__(self, features):
72+
def __init__(self, features, sparse=False):
7373
"""
7474
Params:
7575
7676
features a list of pairs. The first element is the pandas column
7777
selector. This can be a string (for one column) or a list
7878
of strings. The second element is an object that supports
7979
sklearn's transform interface.
80+
sparse will return a sparse matrix if set to True and any of the
81+
extracted features is sparse. Defaults to False.
8082
"""
8183
self.features = features
84+
self.sparse = sparse
8285

8386
def _get_col_subset(self, X, cols):
8487
"""
@@ -154,10 +157,15 @@ def transform(self, X):
154157
# were created from which input columns, so it's
155158
# assumed that that doesn't matter to the model.
156159

157-
# If any of the extracted features is sparse, combine to produce a
158-
# sparse matrix. Otherwise, produce a dense one.
160+
# If any of the extracted features is sparse, combine sparsely.
161+
# Otherwise, combine as normal arrays.
159162
if any(sparse.issparse(fea) for fea in extracted):
160163
stacked = sparse.hstack(extracted).tocsr()
164+
# return a sparse matrix only if the mapper was initialized
165+
# with sparse=True
166+
if not self.sparse:
167+
stacked = stacked.toarray()
161168
else:
162169
stacked = np.hstack(extracted)
170+
163171
return stacked

tests/test_dataframe_mapper.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from sklearn.pipeline import Pipeline
1515
from sklearn.svm import SVC
1616
from sklearn.feature_extraction.text import CountVectorizer
17-
from sklearn.preprocessing import Imputer, StandardScaler, LabelBinarizer
17+
from sklearn.preprocessing import Imputer, StandardScaler
18+
from sklearn.base import BaseEstimator, TransformerMixin
1819
import numpy as np
1920

2021
from sklearn_pandas import (
@@ -24,6 +25,17 @@
2425
)
2526

2627

28+
class ToSparseTransformer(BaseEstimator, TransformerMixin):
29+
"""
30+
Transforms numpy matrix to sparse format.
31+
"""
32+
def fit(self, X):
33+
return self
34+
35+
def transform(self, X):
36+
return sparse.csr_matrix(X)
37+
38+
2739
@pytest.fixture
2840
def iris_dataframe():
2941
iris = load_iris()
@@ -43,6 +55,11 @@ def cars_dataframe():
4355
return pd.read_csv("tests/test_data/cars.csv.gz", compression='gzip')
4456

4557

58+
@pytest.fixture
59+
def simple_dataframe():
60+
return pd.DataFrame({'a': [1, 2, 3]})
61+
62+
4663
def test_nonexistent_columns_explicit_fail(iris_dataframe):
4764
"""
4865
If a nonexistent column is selected, KeyError is raised.
@@ -93,32 +110,32 @@ def test_with_car_dataframe(cars_dataframe):
93110
assert scores.mean() > 0.30
94111

95112

96-
def test_cols_string_array():
113+
def test_cols_string_array(simple_dataframe):
97114
"""
98115
If a string is specified as the columns, the transformer
99116
is called with a 1-d array as input.
100117
"""
101-
dataframe = pd.DataFrame({"a": [1, 2, 3]})
118+
df = simple_dataframe
102119
mock_transformer = Mock()
103120
mock_transformer.transform.return_value = np.array([1, 2, 3]) # do nothing
104121
mapper = DataFrameMapper([("a", mock_transformer)])
105122

106-
mapper.fit_transform(dataframe)
123+
mapper.fit_transform(df)
107124
args, kwargs = mock_transformer.fit.call_args
108125
assert args[0].shape == (3,)
109126

110127

111-
def test_cols_list_column_vector():
128+
def test_cols_list_column_vector(simple_dataframe):
112129
"""
113130
If a one-element list is specified as the columns, the transformer
114131
is called with a column vector as input.
115132
"""
116-
dataframe = pd.DataFrame({"a": [1, 2, 3]})
133+
df = simple_dataframe
117134
mock_transformer = Mock()
118135
mock_transformer.transform.return_value = np.array([1, 2, 3]) # do nothing
119136
mapper = DataFrameMapper([(["a"], mock_transformer)])
120137

121-
mapper.fit_transform(dataframe)
138+
mapper.fit_transform(df)
122139
args, kwargs = mock_transformer.fit.call_args
123140
assert args[0].shape == (3, 1)
124141

@@ -143,15 +160,29 @@ def test_list_transformers():
143160
assert (abs(dmatrix.std(axis=0) - 1) <= 1e-6).all()
144161

145162

146-
def test_sparse_features(cars_dataframe):
163+
def test_sparse_features(simple_dataframe):
147164
"""
148-
If any of the extracted features is sparse, the hstacked
149-
is also sparse.
165+
If any of the extracted features is sparse and "sparse" argument
166+
is true, the hstacked result is also sparse.
150167
"""
168+
df = simple_dataframe
151169
mapper = DataFrameMapper([
152-
("description", CountVectorizer()), # sparse feature
153-
("model", LabelBinarizer()), # dense feature
154-
])
155-
dmatrix = mapper.fit_transform(cars_dataframe)
170+
("a", ToSparseTransformer())
171+
], sparse=True)
172+
dmatrix = mapper.fit_transform(df)
156173

157174
assert type(dmatrix) == sparse.csr.csr_matrix
175+
176+
177+
def test_sparse_off(simple_dataframe):
178+
"""
179+
If the resulting features are sparse but the "sparse" argument
180+
of the mapper is False, return a non-sparse matrix.
181+
"""
182+
df = simple_dataframe
183+
mapper = DataFrameMapper([
184+
("a", ToSparseTransformer())
185+
], sparse=False)
186+
187+
dmatrix = mapper.fit_transform(df)
188+
assert type(dmatrix) != sparse.csr.csr_matrix

0 commit comments

Comments
 (0)