forked from pandas-dev/pandas
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_groupby_subclass.py
78 lines (63 loc) · 2.31 KB
/
test_groupby_subclass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from datetime import datetime
import numpy as np
import pytest
from pandas import DataFrame, Series
import pandas._testing as tm
@pytest.mark.parametrize(
    "obj",
    [
        tm.SubclassedDataFrame({"A": np.arange(0, 10)}),
        tm.SubclassedSeries(np.arange(0, 10), name="A"),
    ],
)
@pytest.mark.filterwarnings("ignore:tshift is deprecated:FutureWarning")
def test_groupby_preserves_subclass(obj, groupby_func):
    """Check that groupby kernels hand back the caller's subclass.

    ``obj`` is a subclassed frame or series with ten rows; it is grouped by
    ten distinct keys, so every group holds exactly one row.
    ``groupby_func`` is the name of the groupby method to exercise
    (presumably supplied by a fixture defined outside this module -- confirm).

    The kernel is invoked twice, once as a method and once through ``.agg``,
    and both results must be equal.
    """
    # GH28330 -- preserve subclass through groupby operations
    if groupby_func == "corrwith" and isinstance(obj, Series):
        pytest.skip("Not applicable")

    by_row = obj.groupby(np.arange(0, 10))

    # Each individual group must already be of the subclass type
    first_group = by_row.get_group(0)
    assert isinstance(first_group, type(obj))

    # Positional arguments that a handful of kernels require
    required_args = {
        "fillna": (0,),
        "nth": (0,),
        "corrwith": (obj,),
        "tshift": (0, 0),
    }
    call_args = required_args.get(groupby_func, ())

    via_method = getattr(by_row, groupby_func)(*call_args)
    via_agg = by_row.agg(groupby_func, *call_args)

    # Reduction or transformation kernels should preserve type; for a frame,
    # these three kernels are checked against the sliced constructor instead
    sliced_kernels = {"ngroup", "cumcount", "size"}
    yields_sliced = isinstance(obj, DataFrame) and groupby_func in sliced_kernels
    expected_type = obj._constructor_sliced if yields_sliced else type(obj)
    assert isinstance(via_method, expected_type)

    # Confirm .agg() groupby operations return same results
    assert_equal = (
        tm.assert_frame_equal
        if isinstance(via_method, DataFrame)
        else tm.assert_series_equal
    )
    assert_equal(via_method, via_agg)
@pytest.mark.parametrize("obj", [DataFrame, tm.SubclassedDataFrame])
def test_groupby_resample_preserves_subclass(obj):
    """Check that ``groupby(...).resample(...).sum()`` keeps the frame type.

    ``obj`` is the frame class under test (plain ``DataFrame`` or the
    subclassed variant). A small purchases table is built with it, indexed by
    its ``Date`` column, grouped by ``Buyer`` and resampled into 5-day bins.
    """
    # GH28330 -- preserve subclass through groupby.resample()
    purchase_dates = [
        datetime(2013, 9, 1, 13, 0),
        datetime(2013, 9, 1, 13, 5),
        datetime(2013, 10, 1, 20, 0),
        datetime(2013, 10, 3, 10, 0),
        datetime(2013, 12, 2, 12, 0),
        datetime(2013, 9, 2, 14, 0),
    ]
    purchases = {
        "Buyer": "Carl Carl Carl Carl Joe Carl".split(),
        "Quantity": [18, 3, 5, 1, 9, 3],
        "Date": purchase_dates,
    }
    frame = obj(purchases).set_index("Date")

    # Confirm groupby.resample() preserves dataframe type
    per_buyer = frame.groupby("Buyer")
    totals = per_buyer.resample("5D").sum()
    assert isinstance(totals, obj)