Skip to content

Commit 366caf9

Browse files
authored
fix: #1212 allow Series.name in pivot_table (#1216)
* fix: #1212 Index.name currently has no typing #1212 (comment) * fix: #1212 use typing from pandas.core.reshape.pivot * fix(comment): #1216 (comment)
1 parent a7f70b7 commit 366caf9

File tree

4 files changed

+38
-18
lines changed

4 files changed

+38
-18
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ from pandas import (
3232
from pandas.core.arraylike import OpsMixin
3333
from pandas.core.generic import NDFrame
3434
from pandas.core.groupby.generic import DataFrameGroupBy
35-
from pandas.core.groupby.grouper import Grouper
3635
from pandas.core.indexers import BaseIndexer
3736
from pandas.core.indexes.base import (
3837
Index,
@@ -50,6 +49,11 @@ from pandas.core.indexing import (
5049
_LocIndexer,
5150
)
5251
from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg
52+
from pandas.core.reshape.pivot import (
53+
_PivotTableColumnsTypes,
54+
_PivotTableIndexTypes,
55+
_PivotTableValuesTypes,
56+
)
5357
from pandas.core.series import Series
5458
from pandas.core.window import (
5559
Expanding,
@@ -1287,9 +1291,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
12871291
) -> Self: ...
12881292
def pivot_table(
12891293
self,
1290-
values: _str | None | Sequence[_str] = ...,
1291-
index: _str | Grouper | Sequence | None = ...,
1292-
columns: _str | Grouper | Sequence | None = ...,
1294+
values: _PivotTableValuesTypes = ...,
1295+
index: _PivotTableIndexTypes = ...,
1296+
columns: _PivotTableColumnsTypes = ...,
12931297
aggfunc=...,
12941298
fill_value: Scalar | None = ...,
12951299
margins: _bool = ...,

pandas-stubs/core/indexes/base.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,13 @@ class Index(IndexOpsMixin[S1]):
298298
def to_series(self, index=..., name: Hashable = ...) -> Series: ...
299299
def to_frame(self, index: bool = ..., name=...) -> DataFrame: ...
300300
@property
301-
def name(self): ...
301+
def name(self) -> Hashable | None: ...
302302
@name.setter
303303
def name(self, value) -> None: ...
304304
@property
305-
def names(self) -> list[_str]: ...
305+
def names(self) -> list[Hashable]: ...
306306
@names.setter
307-
def names(self, names: list[_str]): ...
307+
def names(self, names: Sequence[Hashable]) -> None: ...
308308
def set_names(self, names, *, level=..., inplace: bool = ...): ...
309309
@overload
310310
def rename(self, name, inplace: Literal[False] = False) -> Self: ...

pandas-stubs/core/reshape/pivot.pyi

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,24 @@ _NonIterableHashable: TypeAlias = (
5151
| pd.Timedelta
5252
)
5353

54-
_PivotTableIndexTypes: TypeAlias = Label | list[HashableT1] | Series | Grouper | None
55-
_PivotTableColumnsTypes: TypeAlias = Label | list[HashableT2] | Series | Grouper | None
54+
_PivotTableIndexTypes: TypeAlias = (
55+
Label | Sequence[HashableT1] | Series | Grouper | None
56+
)
57+
_PivotTableColumnsTypes: TypeAlias = (
58+
Label | Sequence[HashableT2] | Series | Grouper | None
59+
)
60+
_PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None
5661

5762
_ExtendedAnyArrayLike: TypeAlias = AnyArrayLike | ArrayLike
5863

5964
@overload
6065
def pivot_table(
6166
data: DataFrame,
62-
values: Label | list[HashableT3] | None = ...,
67+
values: _PivotTableValuesTypes = ...,
6368
index: _PivotTableIndexTypes = ...,
6469
columns: _PivotTableColumnsTypes = ...,
6570
aggfunc: (
66-
_PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
71+
_PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
6772
) = ...,
6873
fill_value: Scalar | None = ...,
6974
margins: bool = ...,
@@ -77,12 +82,12 @@ def pivot_table(
7782
@overload
7883
def pivot_table(
7984
data: DataFrame,
80-
values: Label | list[HashableT3] | None = ...,
85+
values: _PivotTableValuesTypes = ...,
8186
*,
8287
index: Grouper,
8388
columns: _PivotTableColumnsTypes | Index | npt.NDArray = ...,
8489
aggfunc: (
85-
_PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
90+
_PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
8691
) = ...,
8792
fill_value: Scalar | None = ...,
8893
margins: bool = ...,
@@ -94,12 +99,12 @@ def pivot_table(
9499
@overload
95100
def pivot_table(
96101
data: DataFrame,
97-
values: Label | list[HashableT3] | None = ...,
102+
values: _PivotTableValuesTypes = ...,
98103
index: _PivotTableIndexTypes | Index | npt.NDArray = ...,
99104
*,
100105
columns: Grouper,
101106
aggfunc: (
102-
_PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
107+
_PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
103108
) = ...,
104109
fill_value: Scalar | None = ...,
105110
margins: bool = ...,
@@ -111,9 +116,9 @@ def pivot_table(
111116
def pivot(
112117
data: DataFrame,
113118
*,
114-
index: _NonIterableHashable | list[HashableT1] = ...,
115-
columns: _NonIterableHashable | list[HashableT2] = ...,
116-
values: _NonIterableHashable | list[HashableT3] = ...,
119+
index: _NonIterableHashable | Sequence[HashableT1] = ...,
120+
columns: _NonIterableHashable | Sequence[HashableT2] = ...,
121+
values: _NonIterableHashable | Sequence[HashableT3] = ...,
117122
) -> DataFrame: ...
118123
@overload
119124
def crosstab(

tests/test_frame.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,17 @@ def test_types_pivot_table() -> None:
13371337
),
13381338
pd.DataFrame,
13391339
)
1340+
check(
1341+
assert_type(
1342+
df.pivot_table(
1343+
index=df["col1"].name,
1344+
columns=df["col3"].name,
1345+
values=[df["col2"].name, df["col4"].name],
1346+
),
1347+
pd.DataFrame,
1348+
),
1349+
pd.DataFrame,
1350+
)
13401351

13411352

13421353
def test_pivot_table_sort():

0 commit comments

Comments
 (0)