Skip to content

Commit 0ec3600

Browse files
committed
ENH: Sorting of ExtensionArrays
This enables {Series,DataFrame}.sort_values and {Series,DataFrame}.argsort
1 parent 9958ce6 commit 0ec3600

File tree

6 files changed

+158
-8
lines changed

6 files changed

+158
-8
lines changed

pandas/core/arrays/base.py

+21
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,27 @@ def isna(self):
216216
"""
217217
raise AbstractMethodError(self)
218218

219+
def argsort(self, axis=-1, kind='quicksort', order=None):
220+
"""Returns the indices that would sort this array.
221+
222+
Parameters
223+
----------
224+
axis : int or None, optional
225+
Axis along which to sort. ExtensionArrays are 1-dimensional,
226+
so this is only included for compatibility with NumPy.
227+
kind : {'quicksort', 'mergesort', 'heapsort'}, optional
228+
Sorting algorithm.
229+
order : str or list of str, optional
230+
Included for NumPy compatibility.
231+
232+
Returns
233+
-------
234+
index_array : ndarray
235+
Array of indices that sort ``self``.
236+
237+
"""
238+
return np.array(self).argsort(kind=kind)
239+
219240
# ------------------------------------------------------------------------
220241
# Indexing methods
221242
# ------------------------------------------------------------------------

pandas/tests/extension/base/methods.py

+40
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,43 @@ def test_count(self, data_missing):
3131
def test_apply_simple_series(self, data):
3232
result = pd.Series(data).apply(id)
3333
assert isinstance(result, pd.Series)
34+
35+
def test_argsort(self, data_for_sorting):
36+
result = pd.Series(data_for_sorting).argsort()
37+
expected = pd.Series(np.array([2, 0, 1]))
38+
self.assert_series_equal(result, expected)
39+
40+
def test_argsort_missing(self, data_missing_for_sorting):
41+
result = pd.Series(data_missing_for_sorting).argsort()
42+
expected = pd.Series(np.array([1, -1, 0]))
43+
self.assert_series_equal(result, expected)
44+
45+
@pytest.mark.parametrize('ascending', [True, False])
46+
def test_sort_values(self, data_for_sorting, ascending):
47+
ser = pd.Series(data_for_sorting)
48+
result = ser.sort_values(ascending=ascending)
49+
expected = ser.iloc[[2, 0, 1]]
50+
if not ascending:
51+
expected = expected[::-1]
52+
53+
self.assert_series_equal(result, expected)
54+
55+
@pytest.mark.parametrize('ascending', [True, False])
56+
def test_sort_values_missing(self, data_missing_for_sorting, ascending):
57+
ser = pd.Series(data_missing_for_sorting)
58+
result = ser.sort_values(ascending=ascending)
59+
if ascending:
60+
expected = ser.iloc[[2, 0, 1]]
61+
else:
62+
expected = ser.iloc[[0, 2, 1]]
63+
self.assert_series_equal(result, expected)
64+
65+
@pytest.mark.parametrize('ascending', [True, False])
66+
def test_sort_values_frame(self, data_for_sorting, ascending):
67+
df = pd.DataFrame({"A": [1, 2, 1],
68+
"B": data_for_sorting})
69+
result = df.sort_values(['A', 'B'])
70+
expected = pd.DataFrame({"A": [1, 1, 2],
71+
'B': data_for_sorting.take([2, 0, 1])},
72+
index=[2, 0, 1])
73+
self.assert_frame_equal(result, expected)

pandas/tests/extension/category/test_categorical.py

+12
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ def data_missing():
2929
return Categorical([np.nan, 'A'])
3030

3131

32+
@pytest.fixture
33+
def data_for_sorting():
34+
return Categorical(['A', 'B', 'C'], categories=['C', 'A', 'B'],
35+
ordered=True)
36+
37+
38+
@pytest.fixture
39+
def data_missing_for_sorting():
40+
return Categorical(['A', None, 'B'], categories=['B', 'A'],
41+
ordered=True)
42+
43+
3244
@pytest.fixture
3345
def na_value():
3446
return np.nan

pandas/tests/extension/conftest.py

+20
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ def all_data(request, data, data_missing):
3030
return data_missing
3131

3232

33+
@pytest.fixture
34+
def data_for_sorting():
35+
"""Length-3 array with a known sort order.
36+
37+
This should be three items [B, C, A] with
38+
A < B < C
39+
"""
40+
raise NotImplementedError
41+
42+
43+
@pytest.fixture
44+
def data_missing_for_sorting():
45+
"""Length-3 array with a known sort order.
46+
47+
This should be three items [B, NA, A] with
48+
A < B and NA missing.
49+
"""
50+
raise NotImplementedError
51+
52+
3353
@pytest.fixture
3454
def na_cmp():
3555
"""Binary operator for comparing NA values.

pandas/tests/extension/decimal/test_decimal.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ def data_missing():
2525
return DecimalArray([decimal.Decimal('NaN'), decimal.Decimal(1)])
2626

2727

28+
@pytest.fixture
29+
def data_for_sorting():
30+
return DecimalArray([decimal.Decimal('1'),
31+
decimal.Decimal('2'),
32+
decimal.Decimal('0')])
33+
34+
35+
@pytest.fixture
36+
def data_missing_for_sorting():
37+
return DecimalArray([decimal.Decimal('1'),
38+
decimal.Decimal('NaN'),
39+
decimal.Decimal('0')])
40+
41+
2842
@pytest.fixture
2943
def na_cmp():
3044
return lambda x, y: x.is_nan() and y.is_nan()
@@ -35,19 +49,32 @@ def na_value():
3549
return decimal.Decimal("NaN")
3650

3751

38-
class TestDtype(base.BaseDtypeTests):
52+
class BaseDecimal(object):
53+
@staticmethod
54+
def assert_series_equal(left, right, *args, **kwargs):
55+
56+
left_na = left.isna()
57+
right_na = right.isna()
58+
59+
tm.assert_series_equal(left_na, right_na)
60+
return tm.assert_series_equal(left[~left_na],
61+
right[~right_na],
62+
*args, **kwargs)
63+
64+
65+
class TestDtype(BaseDecimal, base.BaseDtypeTests):
3966
pass
4067

4168

42-
class TestInterface(base.BaseInterfaceTests):
69+
class TestInterface(BaseDecimal, base.BaseInterfaceTests):
4370
pass
4471

4572

46-
class TestConstructors(base.BaseConstructorsTests):
73+
class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
4774
pass
4875

4976

50-
class TestReshaping(base.BaseReshapingTests):
77+
class TestReshaping(BaseDecimal, base.BaseReshapingTests):
5178

5279
def test_align(self, data, na_value):
5380
# Have to override since assert_series_equal doesn't
@@ -88,15 +115,15 @@ def test_align_frame(self, data, na_value):
88115
assert e2.loc[0, 'A'].is_nan()
89116

90117

91-
class TestGetitem(base.BaseGetitemTests):
118+
class TestGetitem(BaseDecimal, base.BaseGetitemTests):
92119
pass
93120

94121

95-
class TestMissing(base.BaseMissingTests):
122+
class TestMissing(BaseDecimal, base.BaseMissingTests):
96123
pass
97124

98125

99-
class TestMethods(base.BaseMethodsTests):
126+
class TestMethods(BaseDecimal, base.BaseMethodsTests):
100127
@pytest.mark.parametrize('dropna', [True, False])
101128
@pytest.mark.xfail(reason="value_counts not implemented yet.")
102129
def test_value_counts(self, all_data, dropna):
@@ -112,7 +139,7 @@ def test_value_counts(self, all_data, dropna):
112139
tm.assert_series_equal(result, expected)
113140

114141

115-
class TestCasting(base.BaseCastingTests):
142+
class TestCasting(BaseDecimal, base.BaseCastingTests):
116143
pass
117144

118145

pandas/tests/extension/json/test_json.py

+30
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ def data_missing():
2929
return JSONArray([{}, {'a': 10}])
3030

3131

32+
@pytest.fixture
33+
def data_for_sorting():
34+
return JSONArray([{'b': 1}, {'c': 4}, {'a': 2, 'c': 3}])
35+
36+
37+
@pytest.fixture
38+
def data_missing_for_sorting():
39+
return JSONArray([{'b': 1}, {}, {'c': 4}])
40+
41+
3242
@pytest.fixture
3343
def na_value():
3444
return {}
@@ -68,6 +78,26 @@ class TestMethods(base.BaseMethodsTests):
6878
def test_value_counts(self, all_data, dropna):
6979
pass
7080

81+
@pytest.mark.skip(reason="Dictionaries are not orderable.")
82+
def test_argsort(self):
83+
pass
84+
85+
@pytest.mark.skip(reason="Dictionaries are not orderable.")
86+
def test_argsort_missing(self):
87+
pass
88+
89+
@pytest.mark.skip(reason="Dictionaries are not orderable.")
90+
def test_sort_values(self):
91+
pass
92+
93+
@pytest.mark.skip(reason="Dictionaries are not orderable.")
94+
def test_sort_values_missing(self):
95+
pass
96+
97+
@pytest.mark.skip(reason="Dictionaries are not orderable.")
98+
def test_sort_values_frame(self):
99+
pass
100+
71101

72102
class TestCasting(base.BaseCastingTests):
73103
pass

0 commit comments

Comments
 (0)