Commit c468028

BUG: Fixed DataFrameGroupBy.transform with numba returning the wrong order with non increasing indexes #57069 (#58030)
* Fix #57069: DataFrameGroupBy.transform with numba returning the wrong order with non monotonically increasing indexes

  Fixed a bug that was returning the wrong order unless the index was monotonically increasing while utilizing DataFrameGroupBy.transform with engine='numba'

  Fixed the test "pandas/tests/groupby/transform/test_numba.py::test_index_data_correctly_passed" to expect a result in the correct order

  Added a test "pandas/tests/groupby/transform/test_numba.py::test_index_order_consistency_preserved" to test DataFrameGroupBy.transform with engine='numba' with a decreasing index

  Updated whatsnew to reflect changes

* Apply suggestions from code review

  Co-authored-by: Matthew Roeschke <[email protected]>

* Fixed pre-commit requirements

---------

Co-authored-by: Matthew Roeschke <[email protected]>
1 parent b86eb99 commit c468028

3 files changed, +19 -2 lines changed

Diff for: doc/source/whatsnew/v3.0.0.rst

+1
@@ -327,6 +327,7 @@ Bug fixes
 - Fixed bug in :meth:`DataFrame.cumsum` which was raising ``IndexError`` if dtype is ``timedelta64[ns]`` (:issue:`57956`)
 - Fixed bug in :meth:`DataFrame.join` inconsistently setting result index name (:issue:`55815`)
 - Fixed bug in :meth:`DataFrame.to_string` that raised ``StopIteration`` with nested DataFrames. (:issue:`16098`)
+- Fixed bug in :meth:`DataFrame.transform` that was returning the wrong order unless the index was monotonically increasing. (:issue:`57069`)
 - Fixed bug in :meth:`DataFrame.update` bool dtype being converted to object (:issue:`55509`)
 - Fixed bug in :meth:`DataFrameGroupBy.apply` that was returning a completely empty DataFrame when all return values of ``func`` were ``None`` instead of returning an empty DataFrame with the original columns and dtypes. (:issue:`57775`)
 - Fixed bug in :meth:`Series.diff` allowing non-integer values for the ``periods`` argument. (:issue:`56607`)

Diff for: pandas/core/groupby/groupby.py

+2 -1
@@ -1439,6 +1439,7 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
         data and indices into a Numba jitted function.
         """
         data = self._obj_with_exclusions
+        index_sorting = self._grouper.result_ilocs
         df = data if data.ndim == 2 else data.to_frame()
 
         starts, ends, sorted_index, sorted_data = self._numba_prep(df)
@@ -1456,7 +1457,7 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
         )
         # result values needs to be resorted to their original positions since we
         # evaluated the data sorted by group
-        result = result.take(np.argsort(sorted_index), axis=0)
+        result = result.take(np.argsort(index_sorting), axis=0)
         index = data.index
         if data.ndim == 1:
             result_kwargs = {"name": data.name}
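
To see why the one-line change above matters, here is a minimal pure-Python sketch of the re-sorting step. It is not the pandas implementation. It assumes that `sorted_index` holds the index labels in group-sorted order and that `result_ilocs` holds the original row positions in group-sorted order; neither is shown in this diff. The helpers `argsort`, `take`, and `simulate` are hypothetical stand-ins written for this illustration (the first two mimic `np.argsort` and `ndarray.take`).

```python
def argsort(values):
    """Stable pure-Python stand-in for np.argsort."""
    return sorted(range(len(values)), key=lambda i: values[i])


def take(values, positions):
    """Pure-Python stand-in for ndarray.take along axis 0."""
    return [values[i] for i in positions]


def simulate(values, group_codes, index_labels, func):
    """Return (old_result, new_result) for a per-row transform func(value, label)."""
    # Assumed role of result_ilocs: original row positions, ordered by group.
    group_sorter = argsort(group_codes)
    sorted_values = take(values, group_sorter)
    # Assumed role of sorted_index: index labels in the same group order.
    sorted_labels = take(index_labels, group_sorter)

    # The transform is evaluated on the group-sorted rows.
    result = [func(v, lab) for v, lab in zip(sorted_values, sorted_labels)]

    old = take(result, argsort(sorted_labels))  # before: invert by label order
    new = take(result, argsort(group_sorter))   # after: invert by row position
    return old, new


# Data from test_index_order_consistency_preserved (decreasing index).
old_a, new_a = simulate(
    [0.0, 1.0, 2.0, 3.0], [0, 1, 0, 1], [3, 2, 1, 0], lambda v, lab: v
)
print(old_a)  # [3.0, 2.0, 1.0, 0.0]
print(new_a)  # [0.0, 1.0, 2.0, 3.0]

# Data from test_index_data_correctly_passed (f returns index - 1).
old_b, new_b = simulate(
    [4, 5, 6], [0, 0, 1], [-1, -2, -3], lambda v, lab: float(lab - 1)
)
print(old_b)  # [-4.0, -3.0, -2.0]
print(new_b)  # [-2.0, -3.0, -4.0]
```

Under these assumptions, sorting by label only undoes the group sort when the labels happen to increase with row position, which is consistent with the bug staying hidden for the default `RangeIndex`. The second case reproduces both the old and the corrected `expected` values of the existing test.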

Diff for: pandas/tests/groupby/transform/test_numba.py

+16 -1
@@ -181,10 +181,25 @@ def f(values, index):
 
     df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3])
     result = df.groupby("group").transform(f, engine="numba")
-    expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3])
+    expected = DataFrame([-2.0, -3.0, -4.0], columns=["v"], index=[-1, -2, -3])
     tm.assert_frame_equal(result, expected)
 
 
+def test_index_order_consistency_preserved():
+    # GH 57069
+    pytest.importorskip("numba")
+
+    def f(values, index):
+        return values
+
+    df = DataFrame(
+        {"vals": [0.0, 1.0, 2.0, 3.0], "group": [0, 1, 0, 1]}, index=range(3, -1, -1)
+    )
+    result = df.groupby("group")["vals"].transform(f, engine="numba")
+    expected = Series([0.0, 1.0, 2.0, 3.0], index=range(3, -1, -1), name="vals")
+    tm.assert_series_equal(result, expected)
+
+
 def test_engine_kwargs_not_cached():
     # If the user passes a different set of engine_kwargs don't return the same
     # jitted function
