Skip to content

Commit beb7c97

Browse files
committed
BUG: Fixes issue pandas-dev#3334: brittle margin computation in pivot_table
Adds support for margin computation when all columns are used in rows and cols.
1 parent 527db38 commit beb7c97

File tree

4 files changed

+117
-24
lines changed

4 files changed

+117
-24
lines changed

Diff for: doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ pandas 0.13
102102
set _ref_locs (:issue:`4403`)
103103
- Fixed an issue where hist subplots were being overwritten when they were
104104
called using the top level matplotlib API (:issue:`4408`)
105+
- Fixed bug in ``pivot_table`` where margins were not computed when all columns are used in ``rows`` and ``cols``, leaving no ``values`` columns (:issue:`3334`)
105106

106107
pandas 0.12
107108
===========

Diff for: doc/source/v0.13.0.txt

+4
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ Bug Fixes
5252

5353
- Fixed bug in ``PeriodIndex.map`` where using ``str`` would return the str
5454
representation of the index (:issue:`4136`)
55+
56+
- Fixed bug in ``pivot_table`` where margins were not computed when all columns are used in ``rows`` and ``cols``, leaving no ``values`` columns (:issue:`3334`)
57+
58+
5559

5660
- Fixed test failure ``test_time_series_plot_color_with_empty_kwargs`` when
5761
using custom matplotlib default colors (:issue:`4345`)

Diff for: pandas/tools/pivot.py

+90-24
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pandas import Series, DataFrame
44
from pandas.core.index import MultiIndex
5-
from pandas.core.reshape import _unstack_multiple
65
from pandas.tools.merge import concat
76
from pandas.tools.util import cartesian_product
87
from pandas.compat import range, lrange, zip
@@ -149,17 +148,64 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
149148
DataFrame.pivot_table = pivot_table
150149

151150

152-
def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
153-
grand_margin = {}
154-
for k, v in compat.iteritems(data[values]):
155-
try:
156-
if isinstance(aggfunc, compat.string_types):
157-
grand_margin[k] = getattr(v, aggfunc)()
158-
else:
159-
grand_margin[k] = aggfunc(v)
160-
except TypeError:
161-
pass
151+
def _add_margins(table, data, values, rows, cols, aggfunc):
    """Append the 'All' margins to an already-aggregated pivot ``table``.

    Parameters
    ----------
    table : Series or DataFrame
        The aggregated pivot result to extend.
    data : DataFrame
        The un-aggregated input the margins are recomputed from.
    values : list
        Value column names; may be empty when every column of ``data`` is
        used as a row/column key.
    rows, cols : list
        Row and column grouping keys.
    aggfunc : callable or str
        Aggregation applied to compute the margins.

    Returns
    -------
    Series or DataFrame
        ``table`` extended with an 'All' row (and, where the helper adds
        one, an 'All' column). When the helper returns a finished result
        instead of a ``(result, margin_keys, row_margin)`` tuple, that
        result is returned unchanged.
    """
    grand_margin = _compute_grand_margin(data, values, aggfunc)

    # Label of the appended margin row; padded with '' so that it fits a
    # multi-level row index.
    key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'

    if not values and isinstance(table, Series):
        # If there are no values and the table is a series, then there is
        # only one column in the data. Compute grand margin and return it.
        return table.append(Series({key: grand_margin['All']}))

    if values:
        marginal_result_set = _generate_marginal_results(
            table, data, values, rows, cols, aggfunc, grand_margin)
    else:
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc)

    # The helpers hand back a bare result (not a tuple) when there is no
    # margin row left to append.
    if not isinstance(marginal_result_set, tuple):
        return marginal_result_set
    result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns)
    # populate grand margin
    for k in margin_keys:
        # compat.string_types (not the Python 2-only ``basestring``) keeps
        # this check working on Python 3, matching the code it replaced.
        if isinstance(k, compat.string_types):
            row_margin[k] = grand_margin[k]
        else:
            # Tuple keys come from multi-level columns; the first level
            # names the grand-margin entry.
            row_margin[k] = grand_margin[k[0]]

    margin_dummy = DataFrame(row_margin, columns=[key]).T

    # append() does not carry the index names over, so restore them.
    row_names = result.index.names
    result = result.append(margin_dummy)
    result.index.names = row_names

    return result
189+
190+
191+
def _compute_grand_margin(data, values, aggfunc):
192+
193+
if values:
194+
grand_margin = {}
195+
for k, v in data[values].iteritems():
196+
try:
197+
if isinstance(aggfunc, basestring):
198+
grand_margin[k] = getattr(v, aggfunc)()
199+
else:
200+
grand_margin[k] = aggfunc(v)
201+
except TypeError:
202+
pass
203+
return grand_margin
204+
else:
205+
return {'All': aggfunc(data.index)}
206+
207+
208+
def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
163209
if len(cols) > 0:
164210
# need to "interleave" the margins
165211
table_pieces = []
@@ -203,23 +249,43 @@ def _all_key(key):
203249
else:
204250
row_margin = Series(np.nan, index=result.columns)
205251

206-
key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
252+
return result, margin_keys, row_margin
207253

208-
row_margin = row_margin.reindex(result.columns)
209-
# populate grand margin
210-
for k in margin_keys:
211-
if len(cols) > 0:
212-
row_margin[k] = grand_margin[k[0]]
213-
else:
214-
row_margin[k] = grand_margin[k]
215254

216-
margin_dummy = DataFrame(row_margin, columns=[key]).T
255+
def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc):
256+
if len(cols) > 0:
257+
# need to "interleave" the margins
258+
margin_keys = []
217259

218-
row_names = result.index.names
219-
result = result.append(margin_dummy)
220-
result.index.names = row_names
260+
def _all_key():
261+
if len(cols) == 1:
262+
return 'All'
263+
return ('All', ) + ('', ) * (len(cols) - 1)
221264

222-
return result
265+
if len(rows) > 0:
266+
margin = data[rows].groupby(rows).apply(aggfunc)
267+
all_key = _all_key()
268+
table[all_key] = margin
269+
result = table
270+
margin_keys.append(all_key)
271+
272+
else:
273+
margin = data.groupby(level=0, axis=0).apply(aggfunc)
274+
all_key = _all_key()
275+
table[all_key] = margin
276+
result = table
277+
margin_keys.append(all_key)
278+
return result
279+
else:
280+
result = table
281+
margin_keys = table.columns
282+
283+
if len(cols):
284+
row_margin = data[cols].groupby(cols).apply(aggfunc)
285+
else:
286+
row_margin = Series(np.nan, index=result.columns)
287+
288+
return result, margin_keys, row_margin
223289

224290

225291
def _convert_by(by):

Diff for: pandas/tools/tests/test_pivot.py

+22
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,28 @@ def test_pivot_complex_aggfunc(self):
296296

297297
tm.assert_frame_equal(result, expected)
298298

299+
def test_margins_no_values_no_cols(self):
    """Regression: margins with neither values nor cols passed."""
    frame = self.data[['A', 'B']]
    pivoted = frame.pivot_table(rows=['A', 'B'], aggfunc=len, margins=True)
    counts = pivoted.tolist()
    # The trailing margin entry must equal the sum of the group counts.
    group_total = sum(counts[:-1])
    self.assertEqual(group_total, counts[-1])
304+
305+
def test_margins_no_values_two_rows(self):
    """Regression: margins with no values and a two-level row index."""
    frame = self.data[['A', 'B', 'C']]
    pivoted = frame.pivot_table(rows=['A', 'B'], cols='C', aggfunc=len,
                                margins=True)
    expected_all = [3.0, 1.0, 4.0, 3.0, 11.0]
    self.assertEqual(pivoted.All.tolist(), expected_all)
309+
310+
def test_margins_no_values_one_row_one_col(self):
    """Regression: margins with no values, one row key and one col key."""
    frame = self.data[['A', 'B']]
    pivoted = frame.pivot_table(rows='A', cols='B', aggfunc=len,
                                margins=True)
    expected_all = [4.0, 7.0, 11.0]
    self.assertEqual(pivoted.All.tolist(), expected_all)
314+
315+
def test_margins_no_values_two_row_two_cols(self):
    """Regression: margins with no values, multi-level rows and cols."""
    # One distinct label per record, added as the second column key.
    self.data['D'] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k']
    frame = self.data[['A', 'B', 'C', 'D']]
    pivoted = frame.pivot_table(rows=['A', 'B'], cols=['C', 'D'],
                                aggfunc=len, margins=True)
    expected_all = [3.0, 1.0, 4.0, 3.0, 11.0]
    self.assertEqual(pivoted.All.tolist(), expected_all)
320+
299321

300322
class TestCrosstab(unittest.TestCase):
301323

0 commit comments

Comments
 (0)