Skip to content

Commit 2e4b31a

Browse files
committed
add more tests
1 parent 351138d commit 2e4b31a

File tree

10 files changed

+165
-75
lines changed

10 files changed

+165
-75
lines changed

Diff for: pandas/core/indexes/interval.py

+6 -2
Original file line number | Diff line number | Diff line change
@@ -1037,7 +1037,7 @@ def overlaps(self, other):
10371037
return self._data.overlaps(other)
10381038

10391039
def _setop(op_name):
1040-
def func(self, other):
1040+
def func(self, other, sort=True):
10411041
other = self._as_like_interval_index(other)
10421042

10431043
# GH 19016: ensure set op will not return a prohibited dtype
@@ -1048,7 +1048,11 @@ def func(self, other):
10481048
'objects that have compatible dtypes')
10491049
raise TypeError(msg.format(op=op_name))
10501050

1051-
result = getattr(self._multiindex, op_name)(other._multiindex)
1051+
if op_name == 'difference':
1052+
result = getattr(self._multiindex, op_name)(other._multiindex,
1053+
sort)
1054+
else:
1055+
result = getattr(self._multiindex, op_name)(other._multiindex)
10521056
result_name = get_op_result_name(self, other)
10531057

10541058
# GH 19101: ensure empty results have correct dtype

Diff for: pandas/core/indexes/multi.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -2796,8 +2796,14 @@ def difference(self, other, sort=True):
27962796
labels=[[]] * self.nlevels,
27972797
names=result_names, verify_integrity=False)
27982798

2799-
difference = set(self._ndarray_values) - set(other._ndarray_values)
2799+
this = self._get_unique_index()
28002800

2801+
indexer = this.get_indexer(other)
2802+
indexer = indexer.take((indexer != -1).nonzero()[0])
2803+
2804+
label_diff = np.setdiff1d(np.arange(this.size), indexer,
2805+
assume_unique=True)
2806+
difference = this.values.take(label_diff)
28012807
if sort:
28022808
difference = sorted(difference)
28032809

Diff for: pandas/tests/indexes/common.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -666,12 +666,13 @@ def test_union_base(self):
666666
with pytest.raises(TypeError, match=msg):
667667
first.union([1, 2, 3])
668668

669-
def test_difference_base(self):
669+
@pytest.mark.parametrize("sort", [True, False])
670+
def test_difference_base(self, sort):
670671
for name, idx in compat.iteritems(self.indices):
671672
first = idx[2:]
672673
second = idx[:4]
673674
answer = idx[4:]
674-
result = first.difference(second)
675+
result = first.difference(second, sort)
675676

676677
if isinstance(idx, CategoricalIndex):
677678
pass
@@ -685,21 +686,21 @@ def test_difference_base(self):
685686
if isinstance(idx, PeriodIndex):
686687
msg = "can only call with other PeriodIndex-ed objects"
687688
with pytest.raises(ValueError, match=msg):
688-
first.difference(case)
689+
first.difference(case, sort)
689690
elif isinstance(idx, CategoricalIndex):
690691
pass
691692
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
692693
assert result.__class__ == answer.__class__
693694
tm.assert_numpy_array_equal(result.sort_values().asi8,
694695
answer.sort_values().asi8)
695696
else:
696-
result = first.difference(case)
697+
result = first.difference(case, sort)
697698
assert tm.equalContents(result, answer)
698699

699700
if isinstance(idx, MultiIndex):
700701
msg = "other must be a MultiIndex or a list of tuples"
701702
with pytest.raises(TypeError, match=msg):
702-
first.difference([1, 2, 3])
703+
first.difference([1, 2, 3], sort)
703704

704705
def test_symmetric_difference(self):
705706
for name, idx in compat.iteritems(self.indices):

Diff for: pandas/tests/indexes/datetimes/test_setops.py

+21 -13
Original file line number | Diff line number | Diff line change
@@ -209,47 +209,55 @@ def test_intersection_bug_1708(self):
209209
assert len(result) == 0
210210

211211
@pytest.mark.parametrize("tz", tz)
212-
def test_difference(self, tz):
213-
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
212+
@pytest.mark.parametrize("sort", [True, False])
213+
def test_difference(self, tz, sort):
214+
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
215+
'1/5/2000']
216+
217+
rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
214218
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
215-
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
219+
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)
216220

217-
rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
221+
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
218222
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
219-
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
223+
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)
220224

221-
rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
225+
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
222226
other3 = pd.DatetimeIndex([], tz=tz)
223-
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
227+
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)
224228

225229
for rng, other, expected in [(rng1, other1, expected1),
226230
(rng2, other2, expected2),
227231
(rng3, other3, expected3)]:
228-
result_diff = rng.difference(other)
232+
result_diff = rng.difference(other, sort)
233+
if sort:
234+
expected = expected.sort_values()
229235
tm.assert_index_equal(result_diff, expected)
230236

231-
def test_difference_freq(self):
237+
@pytest.mark.parametrize("sort", [True, False])
238+
def test_difference_freq(self, sort):
232239
# GH14323: difference of DatetimeIndex should not preserve frequency
233240

234241
index = date_range("20160920", "20160925", freq="D")
235242
other = date_range("20160921", "20160924", freq="D")
236243
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
237-
idx_diff = index.difference(other)
244+
idx_diff = index.difference(other, sort)
238245
tm.assert_index_equal(idx_diff, expected)
239246
tm.assert_attr_equal('freq', idx_diff, expected)
240247

241248
other = date_range("20160922", "20160925", freq="D")
242-
idx_diff = index.difference(other)
249+
idx_diff = index.difference(other, sort)
243250
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
244251
tm.assert_index_equal(idx_diff, expected)
245252
tm.assert_attr_equal('freq', idx_diff, expected)
246253

247-
def test_datetimeindex_diff(self):
254+
@pytest.mark.parametrize("sort", [True, False])
255+
def test_datetimeindex_diff(self, sort):
248256
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
249257
periods=100)
250258
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
251259
periods=98)
252-
assert len(dti1.difference(dti2)) == 2
260+
assert len(dti1.difference(dti2, sort)) == 2
253261

254262
def test_datetimeindex_union_join_empty(self):
255263
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')

Diff for: pandas/tests/indexes/interval/test_interval.py

+12 -5
Original file line number | Diff line number | Diff line change
@@ -801,19 +801,26 @@ def test_intersection(self, closed):
801801
result = index.intersection(other)
802802
tm.assert_index_equal(result, expected)
803803

804-
def test_difference(self, closed):
805-
index = self.create_index(closed=closed)
806-
tm.assert_index_equal(index.difference(index[:1]), index[1:])
804+
@pytest.mark.parametrize("sort", [True, False])
805+
def test_difference(self, closed, sort):
806+
index = IntervalIndex.from_arrays([1, 0, 3, 2],
807+
[1, 2, 3, 4],
808+
closed=closed)
809+
result = index.difference(index[:1], sort)
810+
expected = index[1:]
811+
if sort:
812+
expected = expected.sort_values()
813+
tm.assert_index_equal(result, expected)
807814

808815
# GH 19101: empty result, same dtype
809-
result = index.difference(index)
816+
result = index.difference(index, sort)
810817
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
811818
tm.assert_index_equal(result, expected)
812819

813820
# GH 19101: empty result, different dtypes
814821
other = IntervalIndex.from_arrays(index.left.astype('float64'),
815822
index.right, closed=closed)
816-
result = index.difference(other)
823+
result = index.difference(other, sort)
817824
tm.assert_index_equal(result, expected)
818825

819826
def test_symmetric_difference(self, closed):

Diff for: pandas/tests/indexes/multi/test_set_ops.py

+23 -15
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pandas.util.testing as tm
88
from pandas import MultiIndex, Series
9+
import pytest
910

1011

1112
@pytest.mark.parametrize("case", [0.5, "xxx"])
@@ -56,24 +57,25 @@ def test_union_base(idx):
5657
first.union([1, 2, 3])
5758

5859

59-
def test_difference_base(idx):
60+
@pytest.mark.parametrize("sort", [True, False])
61+
def test_difference_base(idx, sort):
6062
first = idx[2:]
6163
second = idx[:4]
6264
answer = idx[4:]
63-
result = first.difference(second)
65+
result = first.difference(second, sort)
6466

6567
assert tm.equalContents(result, answer)
6668

6769
# GH 10149
6870
cases = [klass(second.values)
6971
for klass in [np.array, Series, list]]
7072
for case in cases:
71-
result = first.difference(case)
73+
result = first.difference(case, sort)
7274
assert tm.equalContents(result, answer)
7375

7476
msg = "other must be a MultiIndex or a list of tuples"
7577
with pytest.raises(TypeError, match=msg):
76-
first.difference([1, 2, 3])
78+
first.difference([1, 2, 3], sort)
7779

7880

7981
def test_symmetric_difference(idx):
@@ -101,11 +103,17 @@ def test_empty(idx):
101103
assert idx[:0].empty
102104

103105

104-
def test_difference(idx):
106+
@pytest.mark.parametrize("sort", [True, False])
107+
def test_difference(idx, sort):
105108

106109
first = idx
107-
result = first.difference(idx[-3:])
108-
expected = MultiIndex.from_tuples(sorted(idx[:-3].values),
110+
result = first.difference(idx[-3:], sort)
111+
vals = idx[:-3].values
112+
113+
if sort:
114+
vals = sorted(vals)
115+
116+
expected = MultiIndex.from_tuples(vals,
109117
sortorder=0,
110118
names=idx.names)
111119

@@ -114,44 +122,44 @@ def test_difference(idx):
114122
assert result.names == idx.names
115123

116124
# empty difference: reflexive
117-
result = idx.difference(idx)
125+
result = idx.difference(idx, sort)
118126
expected = idx[:0]
119127
assert result.equals(expected)
120128
assert result.names == idx.names
121129

122130
# empty difference: superset
123-
result = idx[-3:].difference(idx)
131+
result = idx[-3:].difference(idx, sort)
124132
expected = idx[:0]
125133
assert result.equals(expected)
126134
assert result.names == idx.names
127135

128136
# empty difference: degenerate
129-
result = idx[:0].difference(idx)
137+
result = idx[:0].difference(idx, sort)
130138
expected = idx[:0]
131139
assert result.equals(expected)
132140
assert result.names == idx.names
133141

134142
# names not the same
135143
chunklet = idx[-3:]
136144
chunklet.names = ['foo', 'baz']
137-
result = first.difference(chunklet)
145+
result = first.difference(chunklet, sort)
138146
assert result.names == (None, None)
139147

140148
# empty, but non-equal
141-
result = idx.difference(idx.sortlevel(1)[0])
149+
result = idx.difference(idx.sortlevel(1)[0], sort)
142150
assert len(result) == 0
143151

144152
# raise Exception called with non-MultiIndex
145-
result = first.difference(first.values)
153+
result = first.difference(first.values, sort)
146154
assert result.equals(first[:0])
147155

148156
# name from empty array
149-
result = first.difference([])
157+
result = first.difference([], sort)
150158
assert first.equals(result)
151159
assert first.names == result.names
152160

153161
# name from non-empty array
154-
result = first.difference([('foo', 'one')])
162+
result = first.difference([('foo', 'one')], sort)
155163
expected = pd.MultiIndex.from_tuples([('bar', 'one'), ('baz', 'two'), (
156164
'foo', 'two'), ('qux', 'one'), ('qux', 'two')])
157165
expected.names = first.names

Diff for: pandas/tests/indexes/period/test_period.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -72,20 +72,21 @@ def test_no_millisecond_field(self):
7272
with pytest.raises(AttributeError):
7373
DatetimeIndex([]).millisecond
7474

75-
def test_difference_freq(self):
75+
@pytest.mark.parametrize("sort", [True, False])
76+
def test_difference_freq(self, sort):
7677
# GH14323: difference of Period MUST preserve frequency
7778
# but the ability to union results must be preserved
7879

7980
index = period_range("20160920", "20160925", freq="D")
8081

8182
other = period_range("20160921", "20160924", freq="D")
8283
expected = PeriodIndex(["20160920", "20160925"], freq='D')
83-
idx_diff = index.difference(other)
84+
idx_diff = index.difference(other, sort)
8485
tm.assert_index_equal(idx_diff, expected)
8586
tm.assert_attr_equal('freq', idx_diff, expected)
8687

8788
other = period_range("20160922", "20160925", freq="D")
88-
idx_diff = index.difference(other)
89+
idx_diff = index.difference(other, sort)
8990
expected = PeriodIndex(["20160920", "20160921"], freq='D')
9091
tm.assert_index_equal(idx_diff, expected)
9192
tm.assert_attr_equal('freq', idx_diff, expected)

Diff for: pandas/tests/indexes/period/test_setops.py

+28 -14
Original file line number | Diff line number | Diff line change
@@ -203,37 +203,49 @@ def test_intersection_cases(self):
203203
result = rng.intersection(rng[0:0])
204204
assert len(result) == 0
205205

206-
def test_difference(self):
206+
@pytest.mark.parametrize("sort", [True, False])
207+
def test_difference(self, sort):
207208
# diff
208-
rng1 = pd.period_range('1/1/2000', freq='D', periods=5)
209+
period_rng = ['1/3/2000', '1/2/2000', '1/1/2000', '1/5/2000',
210+
'1/4/2000']
211+
rng1 = pd.PeriodIndex(period_rng, freq='D')
209212
other1 = pd.period_range('1/6/2000', freq='D', periods=5)
210-
expected1 = pd.period_range('1/1/2000', freq='D', periods=5)
213+
expected1 = rng1
211214

212-
rng2 = pd.period_range('1/1/2000', freq='D', periods=5)
215+
rng2 = pd.PeriodIndex(period_rng, freq='D')
213216
other2 = pd.period_range('1/4/2000', freq='D', periods=5)
214-
expected2 = pd.period_range('1/1/2000', freq='D', periods=3)
217+
expected2 = pd.PeriodIndex(['1/3/2000', '1/2/2000', '1/1/2000'],
218+
freq='D')
215219

216-
rng3 = pd.period_range('1/1/2000', freq='D', periods=5)
220+
rng3 = pd.PeriodIndex(period_rng, freq='D')
217221
other3 = pd.PeriodIndex([], freq='D')
218-
expected3 = pd.period_range('1/1/2000', freq='D', periods=5)
222+
expected3 = rng3
219223

220-
rng4 = pd.period_range('2000-01-01 09:00', freq='H', periods=5)
224+
period_rng = ['2000-01-01 10:00', '2000-01-01 09:00',
225+
'2000-01-01 12:00', '2000-01-01 11:00',
226+
'2000-01-01 13:00']
227+
rng4 = pd.PeriodIndex(period_rng, freq='H')
221228
other4 = pd.period_range('2000-01-02 09:00', freq='H', periods=5)
222229
expected4 = rng4
223230

224-
rng5 = pd.PeriodIndex(['2000-01-01 09:01', '2000-01-01 09:03',
231+
rng5 = pd.PeriodIndex(['2000-01-01 09:03', '2000-01-01 09:01',
225232
'2000-01-01 09:05'], freq='T')
226233
other5 = pd.PeriodIndex(
227234
['2000-01-01 09:01', '2000-01-01 09:05'], freq='T')
228235
expected5 = pd.PeriodIndex(['2000-01-01 09:03'], freq='T')
229236

230-
rng6 = pd.period_range('2000-01-01', freq='M', periods=7)
237+
period_rng = ['2000-02-01', '2000-01-01', '2000-06-01',
238+
'2000-07-01', '2000-05-01', '2000-03-01',
239+
'2000-04-01']
240+
rng6 = pd.PeriodIndex(period_rng, freq='M')
231241
other6 = pd.period_range('2000-04-01', freq='M', periods=7)
232-
expected6 = pd.period_range('2000-01-01', freq='M', periods=3)
242+
expected6 = pd.PeriodIndex(['2000-02-01', '2000-01-01', '2000-03-01'],
243+
freq='M')
233244

234-
rng7 = pd.period_range('2003-01-01', freq='A', periods=5)
245+
period_rng = ['2003', '2007', '2006', '2005', '2004']
246+
rng7 = pd.PeriodIndex(period_rng, freq='A')
235247
other7 = pd.period_range('1998-01-01', freq='A', periods=8)
236-
expected7 = pd.period_range('2006-01-01', freq='A', periods=2)
248+
expected7 = pd.PeriodIndex(['2007', '2006'], freq='A')
237249

238250
for rng, other, expected in [(rng1, other1, expected1),
239251
(rng2, other2, expected2),
@@ -242,5 +254,7 @@ def test_difference(self):
242254
(rng5, other5, expected5),
243255
(rng6, other6, expected6),
244256
(rng7, other7, expected7), ]:
245-
result_union = rng.difference(other)
257+
result_union = rng.difference(other, sort)
258+
if sort:
259+
expected = expected.sort_values()
246260
tm.assert_index_equal(result_union, expected)

0 commit comments

Comments
 (0)