Skip to content

Commit 8478cf6

Browse files
authored
REF: de-duplicate wrap_agged_manager/wrap_aggregate_result (#51201)
1 parent 5b6c5d1 commit 8478cf6

File tree

4 files changed: +19 -67 lines changed

Diff for: pandas/core/groupby/generic.py

+2 -28
Original file line number | Diff line number | Diff line change
@@ -163,15 +163,7 @@ def prop(self):
163163

164164
class SeriesGroupBy(GroupBy[Series]):
165165
def _wrap_agged_manager(self, mgr: Manager) -> Series:
166-
if mgr.ndim == 1:
167-
mgr = cast(SingleManager, mgr)
168-
single = mgr
169-
else:
170-
mgr = cast(Manager2D, mgr)
171-
single = mgr.iget(0)
172-
ser = self.obj._constructor(single, name=self.obj.name)
173-
# NB: caller is responsible for setting ser.index
174-
return ser
166+
return self.obj._constructor(mgr, name=self.obj.name)
175167

176168
def _get_data_to_aggregate(
177169
self, *, numeric_only: bool = False, name: str | None = None
@@ -1902,25 +1894,7 @@ def _indexed_output_to_ndframe(
19021894
return result
19031895

19041896
def _wrap_agged_manager(self, mgr: Manager2D) -> DataFrame:
1905-
if not self.as_index:
1906-
# GH 41998 - empty mgr always gets index of length 0
1907-
rows = mgr.shape[1] if mgr.shape[0] > 0 else 0
1908-
index = Index(range(rows))
1909-
mgr.set_axis(1, index)
1910-
result = self.obj._constructor(mgr)
1911-
1912-
result = self._insert_inaxis_grouper(result)
1913-
result = result._consolidate()
1914-
else:
1915-
index = self.grouper.result_index
1916-
mgr.set_axis(1, index)
1917-
result = self.obj._constructor(mgr)
1918-
1919-
if self.axis == 1:
1920-
result = result.T
1921-
1922-
# Note: we really only care about inferring numeric dtypes here
1923-
return self._reindex_output(result).infer_objects(copy=False)
1897+
return self.obj._constructor(mgr)
19241898

19251899
def _iterate_column_groupbys(self, obj: DataFrame | Series):
19261900
for i, colname in enumerate(obj.columns):

Diff for: pandas/core/groupby/groupby.py

+11 -35
Original file line number | Diff line number | Diff line change
@@ -1501,7 +1501,6 @@ def _cython_agg_general(
15011501
# that goes through SeriesGroupBy
15021502

15031503
data = self._get_data_to_aggregate(numeric_only=numeric_only, name=how)
1504-
is_ser = data.ndim == 1
15051504

15061505
def array_func(values: ArrayLike) -> ArrayLike:
15071506
try:
@@ -1523,16 +1522,12 @@ def array_func(values: ArrayLike) -> ArrayLike:
15231522
return result
15241523

15251524
new_mgr = data.grouped_reduce(array_func)
1526-
15271525
res = self._wrap_agged_manager(new_mgr)
1528-
if is_ser:
1529-
if self.as_index:
1530-
res.index = self.grouper.result_index
1531-
else:
1532-
res = self._insert_inaxis_grouper(res)
1533-
return self._reindex_output(res)
1534-
else:
1535-
return res
1526+
out = self._wrap_aggregated_output(res)
1527+
if data.ndim == 2:
1528+
# TODO: don't special-case DataFrame vs Series
1529+
out = out.infer_objects(copy=False)
1530+
return out
15361531

15371532
def _cython_transform(
15381533
self, how: str, numeric_only: bool = False, axis: AxisInt = 0, **kwargs
@@ -1793,19 +1788,14 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
17931788
return counted
17941789

17951790
new_mgr = data.grouped_reduce(hfunc)
1791+
new_obj = self._wrap_agged_manager(new_mgr)
17961792

17971793
# If we are grouping on categoricals we want unobserved categories to
17981794
# return zero, rather than the default of NaN which the reindexing in
1799-
# _wrap_agged_manager() returns. GH 35028
1795+
# _wrap_aggregated_output() returns. GH 35028
18001796
# e.g. test_dataframe_groupby_on_2_categoricals_when_observed_is_false
18011797
with com.temp_setattr(self, "observed", True):
1802-
result = self._wrap_agged_manager(new_mgr)
1803-
1804-
if result.ndim == 1:
1805-
if self.as_index:
1806-
result.index = self.grouper.result_index
1807-
else:
1808-
result = self._insert_inaxis_grouper(result)
1798+
result = self._wrap_aggregated_output(new_obj)
18091799

18101800
return self._reindex_output(result, fill_value=0)
18111801

@@ -2790,9 +2780,7 @@ def blk_func(values: ArrayLike) -> ArrayLike:
27902780
mgr = obj._mgr
27912781
res_mgr = mgr.apply(blk_func)
27922782

2793-
new_obj = obj._constructor(res_mgr)
2794-
if isinstance(new_obj, Series):
2795-
new_obj.name = obj.name
2783+
new_obj = self._wrap_agged_manager(res_mgr)
27962784

27972785
if self.axis == 1:
27982786
# Only relevant for DataFrameGroupBy
@@ -3197,15 +3185,10 @@ def blk_func(values: ArrayLike) -> ArrayLike:
31973185
out = out.reshape(ncols, ngroups * nqs)
31983186
return post_processor(out, inference, result_mask, orig_vals)
31993187

3200-
obj = self._obj_with_exclusions
3201-
is_ser = obj.ndim == 1
32023188
data = self._get_data_to_aggregate(numeric_only=numeric_only, name="quantile")
32033189
res_mgr = data.grouped_reduce(blk_func)
32043190

3205-
if is_ser:
3206-
res = self._wrap_agged_manager(res_mgr)
3207-
else:
3208-
res = obj._constructor(res_mgr)
3191+
res = self._wrap_agged_manager(res_mgr)
32093192

32103193
if orig_scalar:
32113194
# Avoid expensive MultiIndex construction
@@ -3652,19 +3635,12 @@ def blk_func(values: ArrayLike) -> ArrayLike:
36523635

36533636
return result.T
36543637

3655-
obj = self._obj_with_exclusions
3656-
36573638
# Operate block-wise instead of column-by-column
3658-
is_ser = obj.ndim == 1
36593639
mgr = self._get_data_to_aggregate(numeric_only=numeric_only, name=how)
36603640

36613641
res_mgr = mgr.grouped_reduce(blk_func)
36623642

3663-
if is_ser:
3664-
out = self._wrap_agged_manager(res_mgr)
3665-
else:
3666-
out = obj._constructor(res_mgr)
3667-
3643+
out = self._wrap_agged_manager(res_mgr)
36683644
return self._wrap_aggregated_output(out)
36693645

36703646
@final

Diff for: pandas/core/internals/array_manager.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -947,9 +947,10 @@ def grouped_reduce(self: T, func: Callable) -> T:
947947
result_indices.append(i)
948948

949949
if len(result_arrays) == 0:
950-
index = Index([None]) # placeholder
950+
nrows = 0
951951
else:
952-
index = Index(range(result_arrays[0].shape[0]))
952+
nrows = result_arrays[0].shape[0]
953+
index = Index(range(nrows))
953954

954955
columns = self.items
955956

Diff for: pandas/core/internals/managers.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -1538,9 +1538,10 @@ def grouped_reduce(self: T, func: Callable) -> T:
15381538
result_blocks = extend_blocks(applied, result_blocks)
15391539

15401540
if len(result_blocks) == 0:
1541-
index = Index([None]) # placeholder
1541+
nrows = 0
15421542
else:
1543-
index = Index(range(result_blocks[0].values.shape[-1]))
1543+
nrows = result_blocks[0].values.shape[-1]
1544+
index = Index(range(nrows))
15441545

15451546
return type(self).from_blocks(result_blocks, [self.axes[0], index])
15461547

0 commit comments

Comments (0)