Skip to content

Commit 922f1eb

Browse files
committed
add more tests
1 parent b13db31 commit 922f1eb

File tree

10 files changed

+165
-75
lines changed

10 files changed

+165
-75
lines changed

Diff for: pandas/core/indexes/interval.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ def equals(self, other):
10291029
self.closed == other.closed)
10301030

10311031
def _setop(op_name):
1032-
def func(self, other):
1032+
def func(self, other, sort=True):
10331033
other = self._as_like_interval_index(other)
10341034

10351035
# GH 19016: ensure set op will not return a prohibited dtype
@@ -1040,7 +1040,11 @@ def func(self, other):
10401040
'objects that have compatible dtypes')
10411041
raise TypeError(msg.format(op=op_name))
10421042

1043-
result = getattr(self._multiindex, op_name)(other._multiindex)
1043+
if op_name == 'difference':
1044+
result = getattr(self._multiindex, op_name)(other._multiindex,
1045+
sort)
1046+
else:
1047+
result = getattr(self._multiindex, op_name)(other._multiindex)
10441048
result_name = self.name if self.name == other.name else None
10451049

10461050
# GH 19101: ensure empty results have correct dtype

Diff for: pandas/core/indexes/multi.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -2791,8 +2791,14 @@ def difference(self, other, sort=True):
27912791
labels=[[]] * self.nlevels,
27922792
names=result_names, verify_integrity=False)
27932793

2794-
difference = set(self._ndarray_values) - set(other._ndarray_values)
2794+
this = self._get_unique_index()
27952795

2796+
indexer = this.get_indexer(other)
2797+
indexer = indexer.take((indexer != -1).nonzero()[0])
2798+
2799+
label_diff = np.setdiff1d(np.arange(this.size), indexer,
2800+
assume_unique=True)
2801+
difference = this.values.take(label_diff)
27962802
if sort:
27972803
difference = sorted(difference)
27982804

Diff for: pandas/tests/indexes/common.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -668,12 +668,13 @@ def test_union_base(self):
668668
with tm.assert_raises_regex(TypeError, msg):
669669
result = first.union([1, 2, 3])
670670

671-
def test_difference_base(self):
671+
@pytest.mark.parametrize("sort", [True, False])
672+
def test_difference_base(self, sort):
672673
for name, idx in compat.iteritems(self.indices):
673674
first = idx[2:]
674675
second = idx[:4]
675676
answer = idx[4:]
676-
result = first.difference(second)
677+
result = first.difference(second, sort)
677678

678679
if isinstance(idx, CategoricalIndex):
679680
pass
@@ -687,21 +688,21 @@ def test_difference_base(self):
687688
if isinstance(idx, PeriodIndex):
688689
msg = "can only call with other PeriodIndex-ed objects"
689690
with tm.assert_raises_regex(ValueError, msg):
690-
result = first.difference(case)
691+
result = first.difference(case, sort)
691692
elif isinstance(idx, CategoricalIndex):
692693
pass
693694
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
694695
assert result.__class__ == answer.__class__
695696
tm.assert_numpy_array_equal(result.sort_values().asi8,
696697
answer.sort_values().asi8)
697698
else:
698-
result = first.difference(case)
699+
result = first.difference(case, sort)
699700
assert tm.equalContents(result, answer)
700701

701702
if isinstance(idx, MultiIndex):
702703
msg = "other must be a MultiIndex or a list of tuples"
703704
with tm.assert_raises_regex(TypeError, msg):
704-
result = first.difference([1, 2, 3])
705+
result = first.difference([1, 2, 3], sort)
705706

706707
def test_symmetric_difference(self):
707708
for name, idx in compat.iteritems(self.indices):

Diff for: pandas/tests/indexes/datetimes/test_setops.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -206,47 +206,55 @@ def test_intersection_bug_1708(self):
206206
assert len(result) == 0
207207

208208
@pytest.mark.parametrize("tz", tz)
209-
def test_difference(self, tz):
210-
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
209+
@pytest.mark.parametrize("sort", [True, False])
210+
def test_difference(self, tz, sort):
211+
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
212+
'1/5/2000']
213+
214+
rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
211215
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
212-
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
216+
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)
213217

214-
rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
218+
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
215219
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
216-
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
220+
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)
217221

218-
rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
222+
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
219223
other3 = pd.DatetimeIndex([], tz=tz)
220-
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
224+
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)
221225

222226
for rng, other, expected in [(rng1, other1, expected1),
223227
(rng2, other2, expected2),
224228
(rng3, other3, expected3)]:
225-
result_diff = rng.difference(other)
229+
result_diff = rng.difference(other, sort)
230+
if sort:
231+
expected = expected.sort_values()
226232
tm.assert_index_equal(result_diff, expected)
227233

228-
def test_difference_freq(self):
234+
@pytest.mark.parametrize("sort", [True, False])
235+
def test_difference_freq(self, sort):
229236
# GH14323: difference of DatetimeIndex should not preserve frequency
230237

231238
index = date_range("20160920", "20160925", freq="D")
232239
other = date_range("20160921", "20160924", freq="D")
233240
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
234-
idx_diff = index.difference(other)
241+
idx_diff = index.difference(other, sort)
235242
tm.assert_index_equal(idx_diff, expected)
236243
tm.assert_attr_equal('freq', idx_diff, expected)
237244

238245
other = date_range("20160922", "20160925", freq="D")
239-
idx_diff = index.difference(other)
246+
idx_diff = index.difference(other, sort)
240247
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
241248
tm.assert_index_equal(idx_diff, expected)
242249
tm.assert_attr_equal('freq', idx_diff, expected)
243250

244-
def test_datetimeindex_diff(self):
251+
@pytest.mark.parametrize("sort", [True, False])
252+
def test_datetimeindex_diff(self, sort):
245253
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
246254
periods=100)
247255
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
248256
periods=98)
249-
assert len(dti1.difference(dti2)) == 2
257+
assert len(dti1.difference(dti2, sort)) == 2
250258

251259
def test_datetimeindex_union_join_empty(self):
252260
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')

Diff for: pandas/tests/indexes/interval/test_interval.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -798,19 +798,26 @@ def test_intersection(self, closed):
798798
result = index.intersection(other)
799799
tm.assert_index_equal(result, expected)
800800

801-
def test_difference(self, closed):
802-
index = self.create_index(closed=closed)
803-
tm.assert_index_equal(index.difference(index[:1]), index[1:])
801+
@pytest.mark.parametrize("sort", [True, False])
802+
def test_difference(self, closed, sort):
803+
index = IntervalIndex.from_arrays([1, 0, 3, 2],
804+
[1, 2, 3, 4],
805+
closed=closed)
806+
result = index.difference(index[:1], sort)
807+
expected = index[1:]
808+
if sort:
809+
expected = expected.sort_values()
810+
tm.assert_index_equal(result, expected)
804811

805812
# GH 19101: empty result, same dtype
806-
result = index.difference(index)
813+
result = index.difference(index, sort)
807814
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
808815
tm.assert_index_equal(result, expected)
809816

810817
# GH 19101: empty result, different dtypes
811818
other = IntervalIndex.from_arrays(index.left.astype('float64'),
812819
index.right, closed=closed)
813-
result = index.difference(other)
820+
result = index.difference(other, sort)
814821
tm.assert_index_equal(result, expected)
815822

816823
def test_symmetric_difference(self, closed):

Diff for: pandas/tests/indexes/multi/test_set_ops.py

+23-15
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import pandas.util.testing as tm
66
from pandas import MultiIndex, Series
7+
import pytest
78

89

910
def test_setops_errorcases(idx):
@@ -58,24 +59,25 @@ def test_union_base(idx):
5859
result = first.union([1, 2, 3])
5960

6061

61-
def test_difference_base(idx):
62+
@pytest.mark.parametrize("sort", [True, False])
63+
def test_difference_base(idx, sort):
6264
first = idx[2:]
6365
second = idx[:4]
6466
answer = idx[4:]
65-
result = first.difference(second)
67+
result = first.difference(second, sort)
6668

6769
assert tm.equalContents(result, answer)
6870

6971
# GH 10149
7072
cases = [klass(second.values)
7173
for klass in [np.array, Series, list]]
7274
for case in cases:
73-
result = first.difference(case)
75+
result = first.difference(case, sort)
7476
assert tm.equalContents(result, answer)
7577

7678
msg = "other must be a MultiIndex or a list of tuples"
7779
with tm.assert_raises_regex(TypeError, msg):
78-
result = first.difference([1, 2, 3])
80+
result = first.difference([1, 2, 3], sort)
7981

8082

8183
def test_symmetric_difference(idx):
@@ -103,11 +105,17 @@ def test_empty(idx):
103105
assert idx[:0].empty
104106

105107

106-
def test_difference(idx):
108+
@pytest.mark.parametrize("sort", [True, False])
109+
def test_difference(idx, sort):
107110

108111
first = idx
109-
result = first.difference(idx[-3:])
110-
expected = MultiIndex.from_tuples(sorted(idx[:-3].values),
112+
result = first.difference(idx[-3:], sort)
113+
vals = idx[:-3].values
114+
115+
if sort:
116+
vals = sorted(vals)
117+
118+
expected = MultiIndex.from_tuples(vals,
111119
sortorder=0,
112120
names=idx.names)
113121

@@ -116,44 +124,44 @@ def test_difference(idx):
116124
assert result.names == idx.names
117125

118126
# empty difference: reflexive
119-
result = idx.difference(idx)
127+
result = idx.difference(idx, sort)
120128
expected = idx[:0]
121129
assert result.equals(expected)
122130
assert result.names == idx.names
123131

124132
# empty difference: superset
125-
result = idx[-3:].difference(idx)
133+
result = idx[-3:].difference(idx, sort)
126134
expected = idx[:0]
127135
assert result.equals(expected)
128136
assert result.names == idx.names
129137

130138
# empty difference: degenerate
131-
result = idx[:0].difference(idx)
139+
result = idx[:0].difference(idx, sort)
132140
expected = idx[:0]
133141
assert result.equals(expected)
134142
assert result.names == idx.names
135143

136144
# names not the same
137145
chunklet = idx[-3:]
138146
chunklet.names = ['foo', 'baz']
139-
result = first.difference(chunklet)
147+
result = first.difference(chunklet, sort)
140148
assert result.names == (None, None)
141149

142150
# empty, but non-equal
143-
result = idx.difference(idx.sortlevel(1)[0])
151+
result = idx.difference(idx.sortlevel(1)[0], sort)
144152
assert len(result) == 0
145153

146154
# raise Exception called with non-MultiIndex
147-
result = first.difference(first.values)
155+
result = first.difference(first.values, sort)
148156
assert result.equals(first[:0])
149157

150158
# name from empty array
151-
result = first.difference([])
159+
result = first.difference([], sort)
152160
assert first.equals(result)
153161
assert first.names == result.names
154162

155163
# name from non-empty array
156-
result = first.difference([('foo', 'one')])
164+
result = first.difference([('foo', 'one')], sort)
157165
expected = pd.MultiIndex.from_tuples([('bar', 'one'), ('baz', 'two'), (
158166
'foo', 'two'), ('qux', 'one'), ('qux', 'two')])
159167
expected.names = first.names

Diff for: pandas/tests/indexes/period/test_period.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,21 @@ def test_no_millisecond_field(self):
7272
with pytest.raises(AttributeError):
7373
DatetimeIndex([]).millisecond
7474

75-
def test_difference_freq(self):
75+
@pytest.mark.parametrize("sort", [True, False])
76+
def test_difference_freq(self, sort):
7677
# GH14323: difference of Period MUST preserve frequency
7778
# but the ability to union results must be preserved
7879

7980
index = period_range("20160920", "20160925", freq="D")
8081

8182
other = period_range("20160921", "20160924", freq="D")
8283
expected = PeriodIndex(["20160920", "20160925"], freq='D')
83-
idx_diff = index.difference(other)
84+
idx_diff = index.difference(other, sort)
8485
tm.assert_index_equal(idx_diff, expected)
8586
tm.assert_attr_equal('freq', idx_diff, expected)
8687

8788
other = period_range("20160922", "20160925", freq="D")
88-
idx_diff = index.difference(other)
89+
idx_diff = index.difference(other, sort)
8990
expected = PeriodIndex(["20160920", "20160921"], freq='D')
9091
tm.assert_index_equal(idx_diff, expected)
9192
tm.assert_attr_equal('freq', idx_diff, expected)

Diff for: pandas/tests/indexes/period/test_setops.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -204,37 +204,49 @@ def test_intersection_cases(self):
204204
result = rng.intersection(rng[0:0])
205205
assert len(result) == 0
206206

207-
def test_difference(self):
207+
@pytest.mark.parametrize("sort", [True, False])
208+
def test_difference(self, sort):
208209
# diff
209-
rng1 = pd.period_range('1/1/2000', freq='D', periods=5)
210+
period_rng = ['1/3/2000', '1/2/2000', '1/1/2000', '1/5/2000',
211+
'1/4/2000']
212+
rng1 = pd.PeriodIndex(period_rng, freq='D')
210213
other1 = pd.period_range('1/6/2000', freq='D', periods=5)
211-
expected1 = pd.period_range('1/1/2000', freq='D', periods=5)
214+
expected1 = rng1
212215

213-
rng2 = pd.period_range('1/1/2000', freq='D', periods=5)
216+
rng2 = pd.PeriodIndex(period_rng, freq='D')
214217
other2 = pd.period_range('1/4/2000', freq='D', periods=5)
215-
expected2 = pd.period_range('1/1/2000', freq='D', periods=3)
218+
expected2 = pd.PeriodIndex(['1/3/2000', '1/2/2000', '1/1/2000'],
219+
freq='D')
216220

217-
rng3 = pd.period_range('1/1/2000', freq='D', periods=5)
221+
rng3 = pd.PeriodIndex(period_rng, freq='D')
218222
other3 = pd.PeriodIndex([], freq='D')
219-
expected3 = pd.period_range('1/1/2000', freq='D', periods=5)
223+
expected3 = rng3
220224

221-
rng4 = pd.period_range('2000-01-01 09:00', freq='H', periods=5)
225+
period_rng = ['2000-01-01 10:00', '2000-01-01 09:00',
226+
'2000-01-01 12:00', '2000-01-01 11:00',
227+
'2000-01-01 13:00']
228+
rng4 = pd.PeriodIndex(period_rng, freq='H')
222229
other4 = pd.period_range('2000-01-02 09:00', freq='H', periods=5)
223230
expected4 = rng4
224231

225-
rng5 = pd.PeriodIndex(['2000-01-01 09:01', '2000-01-01 09:03',
232+
rng5 = pd.PeriodIndex(['2000-01-01 09:03', '2000-01-01 09:01',
226233
'2000-01-01 09:05'], freq='T')
227234
other5 = pd.PeriodIndex(
228235
['2000-01-01 09:01', '2000-01-01 09:05'], freq='T')
229236
expected5 = pd.PeriodIndex(['2000-01-01 09:03'], freq='T')
230237

231-
rng6 = pd.period_range('2000-01-01', freq='M', periods=7)
238+
period_rng = ['2000-02-01', '2000-01-01', '2000-06-01',
239+
'2000-07-01', '2000-05-01', '2000-03-01',
240+
'2000-04-01']
241+
rng6 = pd.PeriodIndex(period_rng, freq='M')
232242
other6 = pd.period_range('2000-04-01', freq='M', periods=7)
233-
expected6 = pd.period_range('2000-01-01', freq='M', periods=3)
243+
expected6 = pd.PeriodIndex(['2000-02-01', '2000-01-01', '2000-03-01'],
244+
freq='M')
234245

235-
rng7 = pd.period_range('2003-01-01', freq='A', periods=5)
246+
period_rng = ['2003', '2007', '2006', '2005', '2004']
247+
rng7 = pd.PeriodIndex(period_rng, freq='A')
236248
other7 = pd.period_range('1998-01-01', freq='A', periods=8)
237-
expected7 = pd.period_range('2006-01-01', freq='A', periods=2)
249+
expected7 = pd.PeriodIndex(['2007', '2006'], freq='A')
238250

239251
for rng, other, expected in [(rng1, other1, expected1),
240252
(rng2, other2, expected2),
@@ -243,5 +255,7 @@ def test_difference(self):
243255
(rng5, other5, expected5),
244256
(rng6, other6, expected6),
245257
(rng7, other7, expected7), ]:
246-
result_union = rng.difference(other)
258+
result_union = rng.difference(other, sort)
259+
if sort:
260+
expected = expected.sort_values()
247261
tm.assert_index_equal(result_union, expected)

0 commit comments

Comments
 (0)