Skip to content

BUG: Maintain column order with groupby.nth #22811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 20, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
@@ -288,6 +288,7 @@ Other Enhancements
- Added :meth:`Interval.overlaps`, :meth:`IntervalArray.overlaps`, and :meth:`IntervalIndex.overlaps` for determining overlaps between interval-like objects (:issue:`21998`)
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`)
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
- :meth:`DataFrame.to_stata` and :class:` pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`)
@@ -1408,6 +1409,7 @@ Groupby/Resample/Rolling
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
- Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` which caused missing values when the input function can accept a :class:`DataFrame` but renames it (:issue:`23455`).
- Bug in :func:`pandas.core.groupby.GroupBy.nth` where column order was not always preserved (:issue:`20760`)

Reshaping
^^^^^^^^^
3 changes: 2 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
@@ -494,7 +494,8 @@ def _set_group_selection(self):

if len(groupers):
# GH12839 clear selected obj cache when group selection changes
self._group_selection = ax.difference(Index(groupers)).tolist()
self._group_selection = ax.difference(Index(groupers),
sort=False).tolist()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Index.difference tries to sort its result by default and this means that sometimes the order of the columns was changed from the original DataFrame. I added a new sort parameter to Index.difference with a default of True to control this.

self._reset_cache('_selected_obj')

def _set_result_index_ordered(self, result):
20 changes: 13 additions & 7 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
@@ -2944,17 +2944,20 @@ def intersection(self, other):
taken.name = None
return taken

def difference(self, other):
def difference(self, other, sort=True):
"""
Return a new Index with elements from the index that are not in
`other`.

This is the set difference of two Index objects.
It's sorted if sorting is possible.

Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible

.. versionadded:: 0.24.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make sure this is added to all subclasses as well (mutli, interval) I think have there own impl.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you do this (in this PR), can ideally update the tests for .difference for all types to parameterize it where appropriate


Returns
-------
@@ -2963,10 +2966,12 @@ def difference(self, other):
Examples
--------

>>> idx1 = pd.Index([1, 2, 3, 4])
>>> idx1 = pd.Index([2, 1, 3, 4])
>>> idx2 = pd.Index([3, 4, 5, 6])
>>> idx1.difference(idx2)
Int64Index([1, 2], dtype='int64')
>>> idx1.difference(idx2, sort=False)
Int64Index([2, 1], dtype='int64')

"""
self._assert_can_do_setop(other)
@@ -2985,10 +2990,11 @@ def difference(self, other):
label_diff = np.setdiff1d(np.arange(this.size), indexer,
assume_unique=True)
the_diff = this.values.take(label_diff)
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add some tests in the index tests to exercise this (prob just parameterize the parameter in the tests)

if sort:
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass

return this._shallow_copy(the_diff, name=result_name, freq=None)

8 changes: 6 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
@@ -1037,7 +1037,7 @@ def overlaps(self, other):
return self._data.overlaps(other)

def _setop(op_name):
def func(self, other):
def func(self, other, sort=True):
other = self._as_like_interval_index(other)

# GH 19016: ensure set op will not return a prohibited dtype
@@ -1048,7 +1048,11 @@ def func(self, other):
'objects that have compatible dtypes')
raise TypeError(msg.format(op=op_name))

result = getattr(self._multiindex, op_name)(other._multiindex)
if op_name == 'difference':
result = getattr(self._multiindex, op_name)(other._multiindex,
Copy link
Contributor Author

@reidy-p reidy-p Oct 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit awkward at the moment because difference is the only set operation with the sort parameter. But if we add a sort parameter to the other set operations I think we can get rid of the if statement

sort)
else:
result = getattr(self._multiindex, op_name)(other._multiindex)
result_name = get_op_result_name(self, other)

# GH 19101: ensure empty results have correct dtype
22 changes: 19 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
@@ -2798,10 +2798,18 @@ def intersection(self, other):
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)

def difference(self, other):
def difference(self, other, sort=True):
"""
Compute sorted set difference of two MultiIndex objects

Parameters
----------
other : MultiIndex
sort : bool, default True
Sort the resulting MultiIndex if possible

.. versionadded:: 0.24.0

Returns
-------
diff : MultiIndex
@@ -2817,8 +2825,16 @@ def difference(self, other):
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)

difference = sorted(set(self._ndarray_values) -
set(other._ndarray_values))
this = self._get_unique_index()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old way of doing this using set did not preserve the original order so I took this code from the difference method in pandas/core/indexes/base.py:

this = self._get_unique_index()
indexer = this.get_indexer(other)
indexer = indexer.take((indexer != -1).nonzero()[0])
label_diff = np.setdiff1d(np.arange(this.size), indexer,
assume_unique=True)
the_diff = this.values.take(label_diff)


indexer = this.get_indexer(other)
indexer = indexer.take((indexer != -1).nonzero()[0])

label_diff = np.setdiff1d(np.arange(this.size), indexer,
assume_unique=True)
difference = this.values.take(label_diff)
if sort:
difference = sorted(difference)

if len(difference) == 0:
return MultiIndex(levels=[[]] * self.nlevels,
24 changes: 24 additions & 0 deletions pandas/tests/groupby/test_nth.py
Original file line number Diff line number Diff line change
@@ -390,3 +390,27 @@ def test_nth_empty():
names=['a', 'b']),
columns=['c'])
assert_frame_equal(result, expected)


def test_nth_column_order():
# GH 20760
# Check that nth preserves column order
df = DataFrame([[1, 'b', 100],
[1, 'a', 50],
[1, 'a', np.nan],
[2, 'c', 200],
[2, 'd', 150]],
columns=['A', 'C', 'B'])
result = df.groupby('A').nth(0)
expected = DataFrame([['b', 100.0],
['c', 200.0]],
columns=['C', 'B'],
index=Index([1, 2], name='A'))
assert_frame_equal(result, expected)

result = df.groupby('A').nth(-1, dropna='any')
expected = DataFrame([['a', 50.0],
['d', 150.0]],
columns=['C', 'B'],
index=Index([1, 2], name='A'))
assert_frame_equal(result, expected)
11 changes: 6 additions & 5 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
@@ -666,12 +666,13 @@ def test_union_base(self):
with pytest.raises(TypeError, match=msg):
first.union([1, 2, 3])

def test_difference_base(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_base(self, sort):
for name, idx in compat.iteritems(self.indices):
first = idx[2:]
second = idx[:4]
answer = idx[4:]
result = first.difference(second)
result = first.difference(second, sort)

if isinstance(idx, CategoricalIndex):
pass
@@ -685,21 +686,21 @@ def test_difference_base(self):
if isinstance(idx, PeriodIndex):
msg = "can only call with other PeriodIndex-ed objects"
with pytest.raises(ValueError, match=msg):
first.difference(case)
first.difference(case, sort)
elif isinstance(idx, CategoricalIndex):
pass
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
assert result.__class__ == answer.__class__
tm.assert_numpy_array_equal(result.sort_values().asi8,
answer.sort_values().asi8)
else:
result = first.difference(case)
result = first.difference(case, sort)
assert tm.equalContents(result, answer)

if isinstance(idx, MultiIndex):
msg = "other must be a MultiIndex or a list of tuples"
with pytest.raises(TypeError, match=msg):
first.difference([1, 2, 3])
first.difference([1, 2, 3], sort)

def test_symmetric_difference(self):
for name, idx in compat.iteritems(self.indices):
34 changes: 21 additions & 13 deletions pandas/tests/indexes/datetimes/test_setops.py
Original file line number Diff line number Diff line change
@@ -209,47 +209,55 @@ def test_intersection_bug_1708(self):
assert len(result) == 0

@pytest.mark.parametrize("tz", tz)
def test_difference(self, tz):
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
@pytest.mark.parametrize("sort", [True, False])
def test_difference(self, tz, sort):
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
Copy link
Contributor Author

@reidy-p reidy-p Oct 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to ensure that the sort parameter was getting a proper test with unsorted data so I have rewritten some tests to have unsorted data (e.g., by manually specifying a list of dates here rather than using date_range). I have made similar changes to other existing tests.

'1/5/2000']

rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)

rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)

rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
other3 = pd.DatetimeIndex([], tz=tz)
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)

for rng, other, expected in [(rng1, other1, expected1),
(rng2, other2, expected2),
(rng3, other3, expected3)]:
result_diff = rng.difference(other)
result_diff = rng.difference(other, sort)
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result_diff, expected)

def test_difference_freq(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_freq(self, sort):
# GH14323: difference of DatetimeIndex should not preserve frequency

index = date_range("20160920", "20160925", freq="D")
other = date_range("20160921", "20160924", freq="D")
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)

other = date_range("20160922", "20160925", freq="D")
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)

def test_datetimeindex_diff(self):
@pytest.mark.parametrize("sort", [True, False])
def test_datetimeindex_diff(self, sort):
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
periods=100)
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
periods=98)
assert len(dti1.difference(dti2)) == 2
assert len(dti1.difference(dti2, sort)) == 2

def test_datetimeindex_union_join_empty(self):
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')
17 changes: 12 additions & 5 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
@@ -801,19 +801,26 @@ def test_intersection(self, closed):
result = index.intersection(other)
tm.assert_index_equal(result, expected)

def test_difference(self, closed):
index = self.create_index(closed=closed)
tm.assert_index_equal(index.difference(index[:1]), index[1:])
@pytest.mark.parametrize("sort", [True, False])
def test_difference(self, closed, sort):
index = IntervalIndex.from_arrays([1, 0, 3, 2],
[1, 2, 3, 4],
closed=closed)
result = index.difference(index[:1], sort)
expected = index[1:]
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result, expected)

# GH 19101: empty result, same dtype
result = index.difference(index)
result = index.difference(index, sort)
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different dtypes
other = IntervalIndex.from_arrays(index.left.astype('float64'),
index.right, closed=closed)
result = index.difference(other)
result = index.difference(other, sort)
tm.assert_index_equal(result, expected)

def test_symmetric_difference(self, closed):
Loading