diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst index d6f9bb66e1e28..02828ec431aac 100644 --- a/doc/source/whatsnew/v0.24.0.rst +++ b/doc/source/whatsnew/v0.24.0.rst @@ -288,6 +288,7 @@ Other Enhancements - Added :meth:`Interval.overlaps`, :meth:`IntervalArray.overlaps`, and :meth:`IntervalIndex.overlaps` for determining overlaps between interval-like objects (:issue:`21998`) - :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`) - :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`) +- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`) - :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`) - :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object. - :meth:`DataFrame.to_stata` and :class:` pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`) @@ -1408,6 +1409,7 @@ Groupby/Resample/Rolling - Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`) - Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`) - Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` which caused missing values when the input function can accept a :class:`DataFrame` but renames it (:issue:`23455`). +- Bug in :func:`pandas.core.groupby.GroupBy.nth` where column order was not always preserved (:issue:`20760`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 96aff09126772..d2dc5f16de7f8 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -494,7 +494,8 @@ def _set_group_selection(self): if len(groupers): # GH12839 clear selected obj cache when group selection changes - self._group_selection = ax.difference(Index(groupers)).tolist() + self._group_selection = ax.difference(Index(groupers), + sort=False).tolist() self._reset_cache('_selected_obj') def _set_result_index_ordered(self, result): diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 0632198c77262..0fa6973b717e9 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2944,17 +2944,20 @@ def intersection(self, other): taken.name = None return taken - def difference(self, other): + def difference(self, other, sort=True): """ Return a new Index with elements from the index that are not in `other`. This is the set difference of two Index objects. - It's sorted if sorting is possible. Parameters ---------- other : Index or array-like + sort : bool, default True + Sort the resulting index if possible + + .. versionadded:: 0.24.0 Returns ------- @@ -2963,10 +2966,12 @@ def difference(self, other): Examples -------- - >>> idx1 = pd.Index([1, 2, 3, 4]) + >>> idx1 = pd.Index([2, 1, 3, 4]) >>> idx2 = pd.Index([3, 4, 5, 6]) >>> idx1.difference(idx2) Int64Index([1, 2], dtype='int64') + >>> idx1.difference(idx2, sort=False) + Int64Index([2, 1], dtype='int64') """ self._assert_can_do_setop(other) @@ -2985,10 +2990,11 @@ def difference(self, other): label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True) the_diff = this.values.take(label_diff) - try: - the_diff = sorting.safe_sort(the_diff) - except TypeError: - pass + if sort: + try: + the_diff = sorting.safe_sort(the_diff) + except TypeError: + pass return this._shallow_copy(the_diff, name=result_name, freq=None) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 2b157bf91c5a2..c64a179a299e9 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1037,7 +1037,7 @@ def overlaps(self, other): return self._data.overlaps(other) def _setop(op_name): - def func(self, other): + def func(self, other, sort=True): other = self._as_like_interval_index(other) # GH 19016: ensure set op will not return a prohibited dtype @@ -1048,7 +1048,11 @@ def func(self, other): 'objects that have compatible dtypes') raise TypeError(msg.format(op=op_name)) - result = getattr(self._multiindex, op_name)(other._multiindex) + if op_name == 'difference': + result = getattr(self._multiindex, op_name)(other._multiindex, + sort) + else: + result = getattr(self._multiindex, op_name)(other._multiindex) result_name = get_op_result_name(self, other) # GH 19101: ensure empty results have correct dtype diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index dbb1b8e196bf7..619e1ae866a1b 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -2798,10 +2798,18 @@ def intersection(self, other): return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0, names=result_names) - def difference(self, other): + def difference(self, other, sort=True): """ Compute sorted set difference of two MultiIndex objects + Parameters + ---------- + other : MultiIndex + sort : bool, default True + Sort the resulting MultiIndex if possible + + .. versionadded:: 0.24.0 + Returns ------- diff : MultiIndex @@ -2817,8 +2825,16 @@ def difference(self, other): labels=[[]] * self.nlevels, names=result_names, verify_integrity=False) - difference = sorted(set(self._ndarray_values) - - set(other._ndarray_values)) + this = self._get_unique_index() + + indexer = this.get_indexer(other) + indexer = indexer.take((indexer != -1).nonzero()[0]) + + label_diff = np.setdiff1d(np.arange(this.size), indexer, + assume_unique=True) + difference = this.values.take(label_diff) + if sort: + difference = sorted(difference) if len(difference) == 0: return MultiIndex(levels=[[]] * self.nlevels, diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index a1b748cd50e8f..4ea4b580a2c3f 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -390,3 +390,27 @@ def test_nth_empty(): names=['a', 'b']), columns=['c']) assert_frame_equal(result, expected) + + +def test_nth_column_order(): + # GH 20760 + # Check that nth preserves column order + df = DataFrame([[1, 'b', 100], + [1, 'a', 50], + [1, 'a', np.nan], + [2, 'c', 200], + [2, 'd', 150]], + columns=['A', 'C', 'B']) + result = df.groupby('A').nth(0) + expected = DataFrame([['b', 100.0], + ['c', 200.0]], + columns=['C', 'B'], + index=Index([1, 2], name='A')) + assert_frame_equal(result, expected) + + result = df.groupby('A').nth(-1, dropna='any') + expected = DataFrame([['a', 50.0], + ['d', 150.0]], + columns=['C', 'B'], + index=Index([1, 2], name='A')) + assert_frame_equal(result, expected) diff --git a/pandas/tests/indexes/common.py b/pandas/tests/indexes/common.py index 4b0daac34c2e3..7f1cf143a3a6e 100644 --- a/pandas/tests/indexes/common.py +++ b/pandas/tests/indexes/common.py @@ -666,12 +666,13 @@ def test_union_base(self): with pytest.raises(TypeError, match=msg): first.union([1, 2, 3]) - def test_difference_base(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_base(self, sort): for name, idx in compat.iteritems(self.indices): first = idx[2:] second = idx[:4] answer = idx[4:] - result = first.difference(second) + result = first.difference(second, sort) if isinstance(idx, CategoricalIndex): pass @@ -685,7 +686,7 @@ def test_difference_base(self): if isinstance(idx, PeriodIndex): msg = "can only call with other PeriodIndex-ed objects" with pytest.raises(ValueError, match=msg): - first.difference(case) + first.difference(case, sort) elif isinstance(idx, CategoricalIndex): pass elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)): @@ -693,13 +694,13 @@ def test_difference_base(self): tm.assert_numpy_array_equal(result.sort_values().asi8, answer.sort_values().asi8) else: - result = first.difference(case) + result = first.difference(case, sort) assert tm.equalContents(result, answer) if isinstance(idx, MultiIndex): msg = "other must be a MultiIndex or a list of tuples" with pytest.raises(TypeError, match=msg): - first.difference([1, 2, 3]) + first.difference([1, 2, 3], sort) def test_symmetric_difference(self): for name, idx in compat.iteritems(self.indices): diff --git a/pandas/tests/indexes/datetimes/test_setops.py b/pandas/tests/indexes/datetimes/test_setops.py index d72bf275463ac..7c1f753dbeaaa 100644 --- a/pandas/tests/indexes/datetimes/test_setops.py +++ b/pandas/tests/indexes/datetimes/test_setops.py @@ -209,47 +209,55 @@ def test_intersection_bug_1708(self): assert len(result) == 0 @pytest.mark.parametrize("tz", tz) - def test_difference(self, tz): - rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz) + @pytest.mark.parametrize("sort", [True, False]) + def test_difference(self, tz, sort): + rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000', + '1/5/2000'] + + rng1 = pd.DatetimeIndex(rng_dates, tz=tz) other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz) - expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz) + expected1 = pd.DatetimeIndex(rng_dates, tz=tz) - rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz) + rng2 = pd.DatetimeIndex(rng_dates, tz=tz) other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz) - expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz) + expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz) - rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz) + rng3 = pd.DatetimeIndex(rng_dates, tz=tz) other3 = pd.DatetimeIndex([], tz=tz) - expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz) + expected3 = pd.DatetimeIndex(rng_dates, tz=tz) for rng, other, expected in [(rng1, other1, expected1), (rng2, other2, expected2), (rng3, other3, expected3)]: - result_diff = rng.difference(other) + result_diff = rng.difference(other, sort) + if sort: + expected = expected.sort_values() tm.assert_index_equal(result_diff, expected) - def test_difference_freq(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_freq(self, sort): # GH14323: difference of DatetimeIndex should not preserve frequency index = date_range("20160920", "20160925", freq="D") other = date_range("20160921", "20160924", freq="D") expected = DatetimeIndex(["20160920", "20160925"], freq=None) - idx_diff = index.difference(other) + idx_diff = index.difference(other, sort) tm.assert_index_equal(idx_diff, expected) tm.assert_attr_equal('freq', idx_diff, expected) other = date_range("20160922", "20160925", freq="D") - idx_diff = index.difference(other) + idx_diff = index.difference(other, sort) expected = DatetimeIndex(["20160920", "20160921"], freq=None) tm.assert_index_equal(idx_diff, expected) tm.assert_attr_equal('freq', idx_diff, expected) - def test_datetimeindex_diff(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_datetimeindex_diff(self, sort): dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31), periods=100) dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31), periods=98) - assert len(dti1.difference(dti2)) == 2 + assert len(dti1.difference(dti2, sort)) == 2 def test_datetimeindex_union_join_empty(self): dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D') diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index d5f62429ddb73..da3b3253ecbd1 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -801,19 +801,26 @@ def test_intersection(self, closed): result = index.intersection(other) tm.assert_index_equal(result, expected) - def test_difference(self, closed): - index = self.create_index(closed=closed) - tm.assert_index_equal(index.difference(index[:1]), index[1:]) + @pytest.mark.parametrize("sort", [True, False]) + def test_difference(self, closed, sort): + index = IntervalIndex.from_arrays([1, 0, 3, 2], + [1, 2, 3, 4], + closed=closed) + result = index.difference(index[:1], sort) + expected = index[1:] + if sort: + expected = expected.sort_values() + tm.assert_index_equal(result, expected) # GH 19101: empty result, same dtype - result = index.difference(index) + result = index.difference(index, sort) expected = IntervalIndex(np.array([], dtype='int64'), closed=closed) tm.assert_index_equal(result, expected) # GH 19101: empty result, different dtypes other = IntervalIndex.from_arrays(index.left.astype('float64'), index.right, closed=closed) - result = index.difference(other) + result = index.difference(other, sort) tm.assert_index_equal(result, expected) def test_symmetric_difference(self, closed): diff --git a/pandas/tests/indexes/multi/test_set_ops.py b/pandas/tests/indexes/multi/test_set_ops.py index 34da3df4fb16e..91edf11e77f10 100644 --- a/pandas/tests/indexes/multi/test_set_ops.py +++ b/pandas/tests/indexes/multi/test_set_ops.py @@ -56,11 +56,12 @@ def test_union_base(idx): first.union([1, 2, 3]) -def test_difference_base(idx): +@pytest.mark.parametrize("sort", [True, False]) +def test_difference_base(idx, sort): first = idx[2:] second = idx[:4] answer = idx[4:] - result = first.difference(second) + result = first.difference(second, sort) assert tm.equalContents(result, answer) @@ -68,12 +69,12 @@ def test_difference_base(idx): cases = [klass(second.values) for klass in [np.array, Series, list]] for case in cases: - result = first.difference(case) + result = first.difference(case, sort) assert tm.equalContents(result, answer) msg = "other must be a MultiIndex or a list of tuples" with pytest.raises(TypeError, match=msg): - first.difference([1, 2, 3]) + first.difference([1, 2, 3], sort) def test_symmetric_difference(idx): @@ -101,11 +102,17 @@ def test_empty(idx): assert idx[:0].empty -def test_difference(idx): +@pytest.mark.parametrize("sort", [True, False]) +def test_difference(idx, sort): first = idx - result = first.difference(idx[-3:]) - expected = MultiIndex.from_tuples(sorted(idx[:-3].values), + result = first.difference(idx[-3:], sort) + vals = idx[:-3].values + + if sort: + vals = sorted(vals) + + expected = MultiIndex.from_tuples(vals, sortorder=0, names=idx.names) @@ -114,19 +121,19 @@ def test_difference(idx): assert result.names == idx.names # empty difference: reflexive - result = idx.difference(idx) + result = idx.difference(idx, sort) expected = idx[:0] assert result.equals(expected) assert result.names == idx.names # empty difference: superset - result = idx[-3:].difference(idx) + result = idx[-3:].difference(idx, sort) expected = idx[:0] assert result.equals(expected) assert result.names == idx.names # empty difference: degenerate - result = idx[:0].difference(idx) + result = idx[:0].difference(idx, sort) expected = idx[:0] assert result.equals(expected) assert result.names == idx.names @@ -134,24 +141,24 @@ def test_difference(idx): # names not the same chunklet = idx[-3:] chunklet.names = ['foo', 'baz'] - result = first.difference(chunklet) + result = first.difference(chunklet, sort) assert result.names == (None, None) # empty, but non-equal - result = idx.difference(idx.sortlevel(1)[0]) + result = idx.difference(idx.sortlevel(1)[0], sort) assert len(result) == 0 # raise Exception called with non-MultiIndex - result = first.difference(first.values) + result = first.difference(first.values, sort) assert result.equals(first[:0]) # name from empty array - result = first.difference([]) + result = first.difference([], sort) assert first.equals(result) assert first.names == result.names # name from non-empty array - result = first.difference([('foo', 'one')]) + result = first.difference([('foo', 'one')], sort) expected = pd.MultiIndex.from_tuples([('bar', 'one'), ('baz', 'two'), ( 'foo', 'two'), ('qux', 'one'), ('qux', 'two')]) expected.names = first.names diff --git a/pandas/tests/indexes/period/test_period.py b/pandas/tests/indexes/period/test_period.py index ddb3fe686534a..5d78333016f74 100644 --- a/pandas/tests/indexes/period/test_period.py +++ b/pandas/tests/indexes/period/test_period.py @@ -72,7 +72,8 @@ def test_no_millisecond_field(self): with pytest.raises(AttributeError): DatetimeIndex([]).millisecond - def test_difference_freq(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_freq(self, sort): # GH14323: difference of Period MUST preserve frequency # but the ability to union results must be preserved @@ -80,12 +81,12 @@ def test_difference_freq(self): other = period_range("20160921", "20160924", freq="D") expected = PeriodIndex(["20160920", "20160925"], freq='D') - idx_diff = index.difference(other) + idx_diff = index.difference(other, sort) tm.assert_index_equal(idx_diff, expected) tm.assert_attr_equal('freq', idx_diff, expected) other = period_range("20160922", "20160925", freq="D") - idx_diff = index.difference(other) + idx_diff = index.difference(other, sort) expected = PeriodIndex(["20160920", "20160921"], freq='D') tm.assert_index_equal(idx_diff, expected) tm.assert_attr_equal('freq', idx_diff, expected) diff --git a/pandas/tests/indexes/period/test_setops.py b/pandas/tests/indexes/period/test_setops.py index c8b7d82855519..565e64607350f 100644 --- a/pandas/tests/indexes/period/test_setops.py +++ b/pandas/tests/indexes/period/test_setops.py @@ -203,37 +203,49 @@ def test_intersection_cases(self): result = rng.intersection(rng[0:0]) assert len(result) == 0 - def test_difference(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference(self, sort): # diff - rng1 = pd.period_range('1/1/2000', freq='D', periods=5) + period_rng = ['1/3/2000', '1/2/2000', '1/1/2000', '1/5/2000', + '1/4/2000'] + rng1 = pd.PeriodIndex(period_rng, freq='D') other1 = pd.period_range('1/6/2000', freq='D', periods=5) - expected1 = pd.period_range('1/1/2000', freq='D', periods=5) + expected1 = rng1 - rng2 = pd.period_range('1/1/2000', freq='D', periods=5) + rng2 = pd.PeriodIndex(period_rng, freq='D') other2 = pd.period_range('1/4/2000', freq='D', periods=5) - expected2 = pd.period_range('1/1/2000', freq='D', periods=3) + expected2 = pd.PeriodIndex(['1/3/2000', '1/2/2000', '1/1/2000'], + freq='D') - rng3 = pd.period_range('1/1/2000', freq='D', periods=5) + rng3 = pd.PeriodIndex(period_rng, freq='D') other3 = pd.PeriodIndex([], freq='D') - expected3 = pd.period_range('1/1/2000', freq='D', periods=5) + expected3 = rng3 - rng4 = pd.period_range('2000-01-01 09:00', freq='H', periods=5) + period_rng = ['2000-01-01 10:00', '2000-01-01 09:00', + '2000-01-01 12:00', '2000-01-01 11:00', + '2000-01-01 13:00'] + rng4 = pd.PeriodIndex(period_rng, freq='H') other4 = pd.period_range('2000-01-02 09:00', freq='H', periods=5) expected4 = rng4 - rng5 = pd.PeriodIndex(['2000-01-01 09:01', '2000-01-01 09:03', + rng5 = pd.PeriodIndex(['2000-01-01 09:03', '2000-01-01 09:01', '2000-01-01 09:05'], freq='T') other5 = pd.PeriodIndex( ['2000-01-01 09:01', '2000-01-01 09:05'], freq='T') expected5 = pd.PeriodIndex(['2000-01-01 09:03'], freq='T') - rng6 = pd.period_range('2000-01-01', freq='M', periods=7) + period_rng = ['2000-02-01', '2000-01-01', '2000-06-01', + '2000-07-01', '2000-05-01', '2000-03-01', + '2000-04-01'] + rng6 = pd.PeriodIndex(period_rng, freq='M') other6 = pd.period_range('2000-04-01', freq='M', periods=7) - expected6 = pd.period_range('2000-01-01', freq='M', periods=3) + expected6 = pd.PeriodIndex(['2000-02-01', '2000-01-01', '2000-03-01'], + freq='M') - rng7 = pd.period_range('2003-01-01', freq='A', periods=5) + period_rng = ['2003', '2007', '2006', '2005', '2004'] + rng7 = pd.PeriodIndex(period_rng, freq='A') other7 = pd.period_range('1998-01-01', freq='A', periods=8) - expected7 = pd.period_range('2006-01-01', freq='A', periods=2) + expected7 = pd.PeriodIndex(['2007', '2006'], freq='A') for rng, other, expected in [(rng1, other1, expected1), (rng2, other2, expected2), @@ -242,5 +254,7 @@ def test_difference(self): (rng5, other5, expected5), (rng6, other6, expected6), (rng7, other7, expected7), ]: - result_union = rng.difference(other) + result_union = rng.difference(other, sort) + if sort: + expected = expected.sort_values() tm.assert_index_equal(result_union, expected) diff --git a/pandas/tests/indexes/test_base.py b/pandas/tests/indexes/test_base.py index 424f6b1f9a77a..1b3b48075e292 100644 --- a/pandas/tests/indexes/test_base.py +++ b/pandas/tests/indexes/test_base.py @@ -29,6 +29,7 @@ from pandas.core.indexes.datetimes import _to_m8 from pandas.tests.indexes.common import Base from pandas.util.testing import assert_almost_equal +from pandas.core.sorting import safe_sort class TestIndex(Base): @@ -1119,7 +1120,8 @@ def test_iadd_string(self): @pytest.mark.parametrize("second_name,expected", [ (None, None), ('name', 'name')]) - def test_difference_name_preservation(self, second_name, expected): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_name_preservation(self, second_name, expected, sort): # TODO: replace with fixturesult first = self.strIndex[5:20] second = self.strIndex[:10] @@ -1127,7 +1129,7 @@ def test_difference_name_preservation(self, second_name, expected): first.name = 'name' second.name = second_name - result = first.difference(second) + result = first.difference(second, sort) assert tm.equalContents(result, answer) @@ -1136,22 +1138,37 @@ def test_difference_name_preservation(self, second_name, expected): else: assert result.name == expected - def test_difference_empty_arg(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_empty_arg(self, sort): first = self.strIndex[5:20] first.name == 'name' - result = first.difference([]) + result = first.difference([], sort) assert tm.equalContents(result, first) assert result.name == first.name - def test_difference_identity(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_identity(self, sort): first = self.strIndex[5:20] first.name == 'name' - result = first.difference(first) + result = first.difference(first, sort) assert len(result) == 0 assert result.name == first.name + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_sort(self, sort): + first = self.strIndex[5:20] + second = self.strIndex[:10] + + result = first.difference(second, sort) + expected = self.strIndex[10:20] + + if sort: + expected = expected.sort_values() + + tm.assert_index_equal(result, expected) + def test_symmetric_difference(self): # smoke index1 = Index([1, 2, 3, 4], name='index1') @@ -1196,17 +1213,19 @@ def test_symmetric_difference_non_index(self): assert tm.equalContents(result, expected) assert result.name == 'new_name' - def test_difference_type(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_type(self, sort): # GH 20040 # If taking difference of a set and itself, it # needs to preserve the type of the index skip_index_keys = ['repeats'] for key, index in self.generate_index_types(skip_index_keys): - result = index.difference(index) + result = index.difference(index, sort) expected = index.drop(index) tm.assert_index_equal(result, expected) - def test_intersection_difference(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_intersection_difference(self, sort): # GH 20040 # Test that the intersection of an index with an # empty index produces the same index as the difference @@ -1214,7 +1233,7 @@ def test_intersection_difference(self): skip_index_keys = ['repeats'] for key, index in self.generate_index_types(skip_index_keys): inter = index.intersection(index.drop(index)) - diff = index.difference(index) + diff = index.difference(index, sort) tm.assert_index_equal(inter, diff) @pytest.mark.parametrize("attr,expected", [ @@ -2424,14 +2443,17 @@ def test_intersection_different_type_base(self, klass): result = first.intersection(klass(second.values)) assert tm.equalContents(result, second) - def test_difference_base(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_base(self, sort): # (same results for py2 and py3 but sortedness not tested elsewhere) index = self.create_index() first = index[:4] second = index[3:] - result = first.difference(second) - expected = Index([0, 1, 'a']) + result = first.difference(second, sort) + expected = Index([0, 'a', 1]) + if sort: + expected = Index(safe_sort(expected)) tm.assert_index_equal(result, expected) def test_symmetric_difference(self): diff --git a/pandas/tests/indexes/timedeltas/test_timedelta.py b/pandas/tests/indexes/timedeltas/test_timedelta.py index 1d068971fad2d..ee92782a87363 100644 --- a/pandas/tests/indexes/timedeltas/test_timedelta.py +++ b/pandas/tests/indexes/timedeltas/test_timedelta.py @@ -53,23 +53,51 @@ def test_fillna_timedelta(self): [pd.Timedelta('1 day'), 'x', pd.Timedelta('3 day')], dtype=object) tm.assert_index_equal(idx.fillna('x'), exp) - def test_difference_freq(self): + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_freq(self, sort): # GH14323: Difference of TimedeltaIndex should not preserve frequency index = timedelta_range("0 days", "5 days", freq="D") other = timedelta_range("1 days", "4 days", freq="D") expected = TimedeltaIndex(["0 days", "5 days"], freq=None) - idx_diff = index.difference(other) + idx_diff = index.difference(other, sort) tm.assert_index_equal(idx_diff, expected) tm.assert_attr_equal('freq', idx_diff, expected) other = timedelta_range("2 days", "5 days", freq="D") - idx_diff = index.difference(other) + idx_diff = index.difference(other, sort) expected = TimedeltaIndex(["0 days", "1 days"], freq=None) tm.assert_index_equal(idx_diff, expected) tm.assert_attr_equal('freq', idx_diff, expected) + @pytest.mark.parametrize("sort", [True, False]) + def test_difference_sort(self, sort): + + index = pd.TimedeltaIndex(["5 days", "3 days", "2 days", "4 days", + "1 days", "0 days"]) + + other = timedelta_range("1 days", "4 days", freq="D") + idx_diff = index.difference(other, sort) + + expected = TimedeltaIndex(["5 days", "0 days"], freq=None) + + if sort: + expected = expected.sort_values() + + tm.assert_index_equal(idx_diff, expected) + tm.assert_attr_equal('freq', idx_diff, expected) + + other = timedelta_range("2 days", "5 days", freq="D") + idx_diff = index.difference(other, sort) + expected = TimedeltaIndex(["1 days", "0 days"], freq=None) + + if sort: + expected = expected.sort_values() + + tm.assert_index_equal(idx_diff, expected) + tm.assert_attr_equal('freq', idx_diff, expected) + def test_isin(self): index = tm.makeTimedeltaIndex(4)