Skip to content

Commit ca9c44e

Browse files
committed
fix: #1212 use typing from pandas.core.reshape.pivot
1 parent ae647cc commit ca9c44e

File tree

3 files changed

+37
-19
lines changed

3 files changed

+37
-19
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ from pandas import (
3232
from pandas.core.arraylike import OpsMixin
3333
from pandas.core.generic import NDFrame
3434
from pandas.core.groupby.generic import DataFrameGroupBy
35-
from pandas.core.groupby.grouper import Grouper
3635
from pandas.core.indexers import BaseIndexer
3736
from pandas.core.indexes.base import (
3837
Index,
@@ -50,6 +49,11 @@ from pandas.core.indexing import (
5049
_LocIndexer,
5150
)
5251
from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg
52+
from pandas.core.reshape.pivot import (
53+
PivotTableColumnsTypes,
54+
PivotTableIndexTypes,
55+
PivotTableValuesTypes,
56+
)
5357
from pandas.core.series import Series
5458
from pandas.core.window import (
5559
Expanding,
@@ -1287,9 +1291,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
12871291
) -> Self: ...
12881292
def pivot_table(
12891293
self,
1290-
values: _str | None | Sequence[_str] = ...,
1291-
index: _str | Grouper | Sequence | None = ...,
1292-
columns: _str | Grouper | Sequence | None = ...,
1294+
values: PivotTableValuesTypes = ...,
1295+
index: PivotTableIndexTypes = ...,
1296+
columns: PivotTableColumnsTypes = ...,
12931297
aggfunc=...,
12941298
fill_value: Scalar | None = ...,
12951299
margins: _bool = ...,

pandas-stubs/core/reshape/pivot.pyi

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,22 @@ _NonIterableHashable: TypeAlias = (
5151
| pd.Timedelta
5252
)
5353

54-
_PivotTableIndexTypes: TypeAlias = Label | list[HashableT1] | Series | Grouper | None
55-
_PivotTableColumnsTypes: TypeAlias = Label | list[HashableT2] | Series | Grouper | None
54+
PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None
55+
PivotTableIndexTypes: TypeAlias = Label | Sequence[HashableT1] | Series | Grouper | None
56+
PivotTableColumnsTypes: TypeAlias = (
57+
Label | Sequence[HashableT2] | Series | Grouper | None
58+
)
5659

5760
_ExtendedAnyArrayLike: TypeAlias = AnyArrayLike | ArrayLike
5861

5962
@overload
6063
def pivot_table(
6164
data: DataFrame,
62-
values: Label | list[HashableT3] | None = ...,
63-
index: _PivotTableIndexTypes = ...,
64-
columns: _PivotTableColumnsTypes = ...,
65+
values: PivotTableValuesTypes = ...,
66+
index: PivotTableIndexTypes = ...,
67+
columns: PivotTableColumnsTypes = ...,
6568
aggfunc: (
66-
_PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
69+
_PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
6770
) = ...,
6871
fill_value: Scalar | None = ...,
6972
margins: bool = ...,
@@ -77,12 +80,12 @@ def pivot_table(
7780
@overload
7881
def pivot_table(
7982
data: DataFrame,
80-
values: Label | list[HashableT3] | None = ...,
83+
values: PivotTableValuesTypes = ...,
8184
*,
8285
index: Grouper,
83-
columns: _PivotTableColumnsTypes | Index | npt.NDArray = ...,
86+
columns: PivotTableColumnsTypes | Index | npt.NDArray = ...,
8487
aggfunc: (
85-
_PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
88+
_PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
8689
) = ...,
8790
fill_value: Scalar | None = ...,
8891
margins: bool = ...,
@@ -94,12 +97,12 @@ def pivot_table(
9497
@overload
9598
def pivot_table(
9699
data: DataFrame,
97-
values: Label | list[HashableT3] | None = ...,
98-
index: _PivotTableIndexTypes | Index | npt.NDArray = ...,
100+
values: PivotTableValuesTypes = ...,
101+
index: PivotTableIndexTypes | Index | npt.NDArray = ...,
99102
*,
100103
columns: Grouper,
101104
aggfunc: (
102-
_PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
105+
_PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
103106
) = ...,
104107
fill_value: Scalar | None = ...,
105108
margins: bool = ...,
@@ -111,9 +114,9 @@ def pivot_table(
111114
def pivot(
112115
data: DataFrame,
113116
*,
114-
index: _NonIterableHashable | list[HashableT1] = ...,
115-
columns: _NonIterableHashable | list[HashableT2] = ...,
116-
values: _NonIterableHashable | list[HashableT3] = ...,
117+
index: _NonIterableHashable | Sequence[HashableT1] = ...,
118+
columns: _NonIterableHashable | Sequence[HashableT2] = ...,
119+
values: _NonIterableHashable | Sequence[HashableT3] = ...,
117120
) -> DataFrame: ...
118121
@overload
119122
def crosstab(

tests/test_frame.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,17 @@ def test_types_pivot_table() -> None:
13371337
),
13381338
pd.DataFrame,
13391339
)
1340+
check(
1341+
assert_type(
1342+
df.pivot_table(
1343+
index=df["col1"].name,
1344+
columns=df["col3"].name,
1345+
values=[df["col2"].name, df["col4"].name],
1346+
),
1347+
pd.DataFrame,
1348+
),
1349+
pd.DataFrame,
1350+
)
13401351

13411352

13421353
def test_pivot_table_sort():

0 commit comments

Comments
 (0)