diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index aeafd9de1..1ee85d1c0 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -32,7 +32,6 @@ from pandas import ( from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame from pandas.core.groupby.generic import DataFrameGroupBy -from pandas.core.groupby.grouper import Grouper from pandas.core.indexers import BaseIndexer from pandas.core.indexes.base import ( Index, @@ -50,6 +49,11 @@ from pandas.core.indexing import ( _LocIndexer, ) from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg +from pandas.core.reshape.pivot import ( + _PivotTableColumnsTypes, + _PivotTableIndexTypes, + _PivotTableValuesTypes, +) from pandas.core.series import Series from pandas.core.window import ( Expanding, @@ -1287,9 +1291,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): ) -> Self: ... def pivot_table( self, - values: _str | None | Sequence[_str] = ..., - index: _str | Grouper | Sequence | None = ..., - columns: _str | Grouper | Sequence | None = ..., + values: _PivotTableValuesTypes = ..., + index: _PivotTableIndexTypes = ..., + columns: _PivotTableColumnsTypes = ..., aggfunc=..., fill_value: Scalar | None = ..., margins: _bool = ..., diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index cb0377ab8..d7961af00 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -298,13 +298,13 @@ class Index(IndexOpsMixin[S1]): def to_series(self, index=..., name: Hashable = ...) -> Series: ... def to_frame(self, index: bool = ..., name=...) -> DataFrame: ... @property - def name(self): ... + def name(self) -> Hashable | None: ... @name.setter def name(self, value) -> None: ... @property - def names(self) -> list[_str]: ... + def names(self) -> list[Hashable]: ... @names.setter - def names(self, names: list[_str]): ... + def names(self, names: Sequence[Hashable]) -> None: ... def set_names(self, names, *, level=..., inplace: bool = ...): ... @overload def rename(self, name, inplace: Literal[False] = False) -> Self: ... diff --git a/pandas-stubs/core/reshape/pivot.pyi b/pandas-stubs/core/reshape/pivot.pyi index 042539565..5554ce6fb 100644 --- a/pandas-stubs/core/reshape/pivot.pyi +++ b/pandas-stubs/core/reshape/pivot.pyi @@ -51,19 +51,24 @@ _NonIterableHashable: TypeAlias = ( | pd.Timedelta ) -_PivotTableIndexTypes: TypeAlias = Label | list[HashableT1] | Series | Grouper | None -_PivotTableColumnsTypes: TypeAlias = Label | list[HashableT2] | Series | Grouper | None +_PivotTableIndexTypes: TypeAlias = ( + Label | Sequence[HashableT1] | Series | Grouper | None +) +_PivotTableColumnsTypes: TypeAlias = ( + Label | Sequence[HashableT2] | Series | Grouper | None +) +_PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None _ExtendedAnyArrayLike: TypeAlias = AnyArrayLike | ArrayLike @overload def pivot_table( data: DataFrame, - values: Label | list[HashableT3] | None = ..., + values: _PivotTableValuesTypes = ..., index: _PivotTableIndexTypes = ..., columns: _PivotTableColumnsTypes = ..., aggfunc: ( - _PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] + _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., fill_value: Scalar | None = ..., margins: bool = ..., @@ -77,12 +82,12 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: Label | list[HashableT3] | None = ..., + values: _PivotTableValuesTypes = ..., *, index: Grouper, columns: _PivotTableColumnsTypes | Index | npt.NDArray = ..., aggfunc: ( - _PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] + _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., fill_value: Scalar | None = ..., margins: bool = ..., @@ -94,12 +99,12 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: Label | list[HashableT3] | None = ..., + values: _PivotTableValuesTypes = ..., index: _PivotTableIndexTypes | Index | npt.NDArray = ..., *, columns: Grouper, aggfunc: ( - _PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] + _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., fill_value: Scalar | None = ..., margins: bool = ..., @@ -111,9 +116,9 @@ def pivot_table( def pivot( data: DataFrame, *, - index: _NonIterableHashable | list[HashableT1] = ..., - columns: _NonIterableHashable | list[HashableT2] = ..., - values: _NonIterableHashable | list[HashableT3] = ..., + index: _NonIterableHashable | Sequence[HashableT1] = ..., + columns: _NonIterableHashable | Sequence[HashableT2] = ..., + values: _NonIterableHashable | Sequence[HashableT3] = ..., ) -> DataFrame: ... @overload def crosstab( diff --git a/tests/test_frame.py b/tests/test_frame.py index ce05b7846..c70bdce23 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1337,6 +1337,17 @@ def test_types_pivot_table() -> None: ), pd.DataFrame, ) + check( + assert_type( + df.pivot_table( + index=df["col1"].name, + columns=df["col3"].name, + values=[df["col2"].name, df["col4"].name], + ), + pd.DataFrame, + ), + pd.DataFrame, + ) def test_pivot_table_sort():