Skip to content

Commit 84c9792

Browse files
committed
BUG: pivot_table always returns a DataFrame
Before this commit, if * `values` is not list like * `columns` is `None` * `aggfunc` is not instance of `list` `pivot_table` returns a `Series`. This commit adds checking for `columns.nlevels` is greater than 1 to prevent from casting `table` to a `Series`. This will fix #4386.
1 parent a01644c commit 84c9792

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

Diff for: pandas/tools/pivot.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
158158
margins_name=margins_name)
159159

160160
# discard the top level
161-
if values_passed and not values_multi and not table.empty:
161+
if values_passed and not values_multi and not table.empty and \
162+
(table.columns.nlevels > 1):
162163
table = table[values[0]]
163164

164165
if len(index) == 0 and len(columns) > 0:

Diff for: pandas/tools/tests/test_pivot.py

+38
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,44 @@ def test_pivot_table_margins_name_with_aggfunc_list(self):
801801
expected = pd.DataFrame(table.values, index=ix, columns=cols)
802802
tm.assert_frame_equal(table, expected)
803803

804+
def test_pivot_table_not_series(self):
805+
# GH 4386
806+
# pivot_table always returns a DataFrame
807+
# when values is not list like and columns is None
808+
# and aggfunc is not instance of list
809+
df = DataFrame({'col1': [3, 4, 5],
810+
'col2': ['C', 'D', 'E'],
811+
'col3': [1, 3, 9]})
812+
813+
result = df.pivot_table('col1', index=['col3', 'col2'], aggfunc=np.sum)
814+
m = MultiIndex.from_arrays([[1, 3, 9],
815+
['C', 'D', 'E']],
816+
names=['col3', 'col2'])
817+
expected = DataFrame([3, 4, 5],
818+
index=m, columns=['col1'])
819+
820+
tm.assert_frame_equal(result, expected)
821+
822+
result = df.pivot_table(
823+
'col1', index='col3', columns='col2', aggfunc=np.sum
824+
)
825+
expected = DataFrame([[3, np.NaN, np.NaN],
826+
[np.NaN, 4, np.NaN],
827+
[np.NaN, np.NaN, 5]],
828+
index=Index([1, 3, 9], name='col3'),
829+
columns=Index(['C', 'D', 'E'], name='col2'))
830+
831+
tm.assert_frame_equal(result, expected)
832+
833+
result = df.pivot_table('col1', index='col3', aggfunc=[np.sum])
834+
m = MultiIndex.from_arrays([['sum'],
835+
['col1']])
836+
expected = DataFrame([3, 4, 5],
837+
index=Index([1, 3, 9], name='col3'),
838+
columns=m)
839+
840+
tm.assert_frame_equal(result, expected)
841+
804842

805843
class TestCrosstab(tm.TestCase):
806844

0 commit comments

Comments
 (0)