Skip to content

Commit 2b76841

Browse files
andremcorreia and mroeschke
authored and committed
BUG: Fixed DataFrameGroupBy.transform with engine='numba' returning the wrong order for non-monotonically-increasing indexes (pandas-dev#57069) (pandas-dev#58030)
* Fix pandas-dev#57069: DataFrameGroupBy.transform with numba returning the wrong order with non-monotonically-increasing indexes
  - Fixed a bug that returned results in the wrong order unless the index was monotonically increasing when using DataFrameGroupBy.transform with engine='numba'.
  - Fixed the test "pandas/tests/groupby/transform/test_numba.py::test_index_data_correctly_passed" to expect a result in the correct order.
  - Added a test "pandas/tests/groupby/transform/test_numba.py::test_index_order_consistency_preserved" to test DataFrameGroupBy.transform with engine='numba' with a decreasing index.
  - Updated whatsnew to reflect the changes.
* Apply suggestions from code review
  Co-authored-by: Matthew Roeschke <[email protected]>
* Fixed pre-commit requirements
---------
Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 5fa9476 commit 2b76841

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

pandas/core/groupby/groupby.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -1439,6 +1439,7 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
14391439
data and indices into a Numba jitted function.
14401440
"""
14411441
data = self._obj_with_exclusions
1442+
index_sorting = self._grouper.result_ilocs
14421443
df = data if data.ndim == 2 else data.to_frame()
14431444

14441445
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
@@ -1456,7 +1457,7 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
14561457
)
14571458
# result values needs to be resorted to their original positions since we
14581459
# evaluated the data sorted by group
1459-
result = result.take(np.argsort(sorted_index), axis=0)
1460+
result = result.take(np.argsort(index_sorting), axis=0)
14601461
index = data.index
14611462
if data.ndim == 1:
14621463
result_kwargs = {"name": data.name}

pandas/tests/groupby/transform/test_numba.py

+16 -1
Original file line number | Diff line number | Diff line change
@@ -181,10 +181,25 @@ def f(values, index):
181181

182182
df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3])
183183
result = df.groupby("group").transform(f, engine="numba")
184-
expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3])
184+
expected = DataFrame([-2.0, -3.0, -4.0], columns=["v"], index=[-1, -2, -3])
185185
tm.assert_frame_equal(result, expected)
186186

187187

188+
def test_index_order_consistency_preserved():
189+
# GH 57069
190+
pytest.importorskip("numba")
191+
192+
def f(values, index):
193+
return values
194+
195+
df = DataFrame(
196+
{"vals": [0.0, 1.0, 2.0, 3.0], "group": [0, 1, 0, 1]}, index=range(3, -1, -1)
197+
)
198+
result = df.groupby("group")["vals"].transform(f, engine="numba")
199+
expected = Series([0.0, 1.0, 2.0, 3.0], index=range(3, -1, -1), name="vals")
200+
tm.assert_series_equal(result, expected)
201+
202+
188203
def test_engine_kwargs_not_cached():
189204
# If the user passes a different set of engine_kwargs don't return the same
190205
# jitted function

0 commit comments

Comments
 (0)