Skip to content

Commit 62dbc5e

Browse files
mroeschkeyehoshuadimarsky
authored andcommitted
TST: Use more pytest idioms in test_reductions, test_generic (pandas-dev#45906)
* TST: Parameterize tests * Use pytest idioms in pandas/tests/frame/test_reductions.py
1 parent 3ff3bb0 commit 62dbc5e

File tree

2 files changed

+189
-212
lines changed

2 files changed

+189
-212
lines changed

pandas/tests/frame/test_reductions.py

+143-186
Original file line numberDiff line numberDiff line change
@@ -141,155 +141,57 @@ def wrapper(x):
141141
tm.assert_series_equal(r1, expected)
142142

143143

144-
def assert_stat_op_api(opname, float_frame, float_string_frame, has_numeric_only=True):
145-
"""
146-
Check that API for operator opname works as advertised on frame
147-
148-
Parameters
149-
----------
150-
opname : str
151-
Name of the operator to test on frame
152-
float_frame : DataFrame
153-
DataFrame with columns of type float
154-
float_string_frame : DataFrame
155-
DataFrame with both float and string columns
156-
has_numeric_only : bool, default False
157-
Whether the method "opname" has the kwarg "numeric_only"
158-
"""
159-
# make sure works on mixed-type frame
160-
getattr(float_string_frame, opname)(axis=0)
161-
getattr(float_string_frame, opname)(axis=1)
162-
163-
if has_numeric_only:
164-
getattr(float_string_frame, opname)(axis=0, numeric_only=True)
165-
getattr(float_string_frame, opname)(axis=1, numeric_only=True)
166-
getattr(float_frame, opname)(axis=0, numeric_only=False)
167-
getattr(float_frame, opname)(axis=1, numeric_only=False)
168-
169-
170-
def assert_bool_op_calc(opname, alternative, frame, has_skipna=True):
171-
"""
172-
Check that bool operator opname works as advertised on frame
173-
174-
Parameters
175-
----------
176-
opname : str
177-
Name of the operator to test on frame
178-
alternative : function
179-
Function that opname is tested against; i.e. "frame.opname()" should
180-
equal "alternative(frame)".
181-
frame : DataFrame
182-
The object that the tests are executed on
183-
has_skipna : bool, default True
184-
Whether the method "opname" has the kwarg "skip_na"
185-
"""
186-
f = getattr(frame, opname)
187-
188-
if has_skipna:
189-
190-
def skipna_wrapper(x):
191-
nona = x.dropna().values
192-
return alternative(nona)
193-
194-
def wrapper(x):
195-
return alternative(x.values)
196-
197-
result0 = f(axis=0, skipna=False)
198-
result1 = f(axis=1, skipna=False)
199-
200-
tm.assert_series_equal(result0, frame.apply(wrapper))
201-
tm.assert_series_equal(result1, frame.apply(wrapper, axis=1))
202-
else:
203-
skipna_wrapper = alternative
204-
wrapper = alternative
205-
206-
result0 = f(axis=0)
207-
result1 = f(axis=1)
208-
209-
tm.assert_series_equal(result0, frame.apply(skipna_wrapper))
210-
tm.assert_series_equal(
211-
result1, frame.apply(skipna_wrapper, axis=1), check_dtype=False
212-
)
213-
214-
# bad axis
215-
with pytest.raises(ValueError, match="No axis named 2"):
216-
f(axis=2)
217-
218-
# all NA case
219-
if has_skipna:
220-
all_na = frame * np.NaN
221-
r0 = getattr(all_na, opname)(axis=0)
222-
r1 = getattr(all_na, opname)(axis=1)
223-
if opname == "any":
224-
assert not r0.any()
225-
assert not r1.any()
226-
else:
227-
assert r0.all()
228-
assert r1.all()
229-
230-
231-
def assert_bool_op_api(
232-
opname, bool_frame_with_na, float_string_frame, has_bool_only=False
233-
):
234-
"""
235-
Check that API for boolean operator opname works as advertised on frame
236-
237-
Parameters
238-
----------
239-
opname : str
240-
Name of the operator to test on frame
241-
bool_frame_with_na : DataFrame
242-
DataFrame with columns of type float
243-
float_string_frame : DataFrame
244-
DataFrame with both float and string columns
245-
has_bool_only : bool, default False
246-
Whether the method "opname" has the kwarg "bool_only"
247-
"""
248-
# make sure op works on mixed-type frame
249-
mixed = float_string_frame
250-
mixed["_bool_"] = np.random.randn(len(mixed)) > 0.5
251-
252-
getattr(mixed, opname)(axis=0)
253-
getattr(mixed, opname)(axis=1)
254-
255-
if has_bool_only:
256-
getattr(mixed, opname)(axis=0, bool_only=True)
257-
getattr(mixed, opname)(axis=1, bool_only=True)
258-
getattr(bool_frame_with_na, opname)(axis=0, bool_only=False)
259-
getattr(bool_frame_with_na, opname)(axis=1, bool_only=False)
260-
261-
262144
class TestDataFrameAnalytics:
263145

264146
# ---------------------------------------------------------------------
265147
# Reductions
266-
267148
@pytest.mark.filterwarnings("ignore:Dropping of nuisance:FutureWarning")
268-
def test_stat_op_api(self, float_frame, float_string_frame):
269-
assert_stat_op_api("count", float_frame, float_string_frame)
270-
assert_stat_op_api("sum", float_frame, float_string_frame)
271-
272-
assert_stat_op_api(
273-
"nunique", float_frame, float_string_frame, has_numeric_only=False
274-
)
275-
assert_stat_op_api("mean", float_frame, float_string_frame)
276-
assert_stat_op_api("product", float_frame, float_string_frame)
277-
assert_stat_op_api("median", float_frame, float_string_frame)
278-
assert_stat_op_api("min", float_frame, float_string_frame)
279-
assert_stat_op_api("max", float_frame, float_string_frame)
280-
assert_stat_op_api(
281-
"mad", float_frame, float_string_frame, has_numeric_only=False
282-
)
283-
assert_stat_op_api("var", float_frame, float_string_frame)
284-
assert_stat_op_api("std", float_frame, float_string_frame)
285-
assert_stat_op_api("sem", float_frame, float_string_frame)
286-
assert_stat_op_api("median", float_frame, float_string_frame)
149+
@pytest.mark.parametrize("axis", [0, 1])
150+
@pytest.mark.parametrize(
151+
"opname",
152+
[
153+
"count",
154+
"sum",
155+
"mean",
156+
"product",
157+
"median",
158+
"min",
159+
"max",
160+
"nunique",
161+
"mad",
162+
"var",
163+
"std",
164+
"sem",
165+
pytest.param("skew", marks=td.skip_if_no_scipy),
166+
pytest.param("kurt", marks=td.skip_if_no_scipy),
167+
],
168+
)
169+
def test_stat_op_api_float_string_frame(self, float_string_frame, axis, opname):
170+
getattr(float_string_frame, opname)(axis=axis)
171+
if opname not in ("nunique", "mad"):
172+
getattr(float_string_frame, opname)(axis=axis, numeric_only=True)
287173

288174
@pytest.mark.filterwarnings("ignore:Dropping of nuisance:FutureWarning")
289-
@td.skip_if_no_scipy
290-
def test_stat_op_api_skew_kurt(self, float_frame, float_string_frame):
291-
assert_stat_op_api("skew", float_frame, float_string_frame)
292-
assert_stat_op_api("kurt", float_frame, float_string_frame)
175+
@pytest.mark.parametrize("axis", [0, 1])
176+
@pytest.mark.parametrize(
177+
"opname",
178+
[
179+
"count",
180+
"sum",
181+
"mean",
182+
"product",
183+
"median",
184+
"min",
185+
"max",
186+
"var",
187+
"std",
188+
"sem",
189+
pytest.param("skew", marks=td.skip_if_no_scipy),
190+
pytest.param("kurt", marks=td.skip_if_no_scipy),
191+
],
192+
)
193+
def test_stat_op_api_float_frame(self, float_frame, axis, opname):
194+
getattr(float_frame, opname)(axis=axis, numeric_only=False)
293195

294196
def test_stat_op_calc(self, float_frame_with_na, mixed_float_frame):
295197
def count(s):
@@ -388,32 +290,37 @@ def wrapper(x):
388290
@pytest.mark.parametrize(
389291
"method", ["sum", "mean", "prod", "var", "std", "skew", "min", "max"]
390292
)
391-
def test_stat_operators_attempt_obj_array(self, method):
293+
@pytest.mark.parametrize(
294+
"df",
295+
[
296+
DataFrame(
297+
{
298+
"a": [
299+
-0.00049987540199591344,
300+
-0.0016467257772919831,
301+
0.00067695870775883013,
302+
],
303+
"b": [-0, -0, 0.0],
304+
"c": [
305+
0.00031111847529610595,
306+
0.0014902627951905339,
307+
-0.00094099200035979691,
308+
],
309+
},
310+
index=["foo", "bar", "baz"],
311+
dtype="O",
312+
),
313+
DataFrame({0: [np.nan, 2], 1: [np.nan, 3], 2: [np.nan, 4]}, dtype=object),
314+
],
315+
)
316+
def test_stat_operators_attempt_obj_array(self, method, df):
392317
# GH#676
393-
data = {
394-
"a": [
395-
-0.00049987540199591344,
396-
-0.0016467257772919831,
397-
0.00067695870775883013,
398-
],
399-
"b": [-0, -0, 0.0],
400-
"c": [
401-
0.00031111847529610595,
402-
0.0014902627951905339,
403-
-0.00094099200035979691,
404-
],
405-
}
406-
df1 = DataFrame(data, index=["foo", "bar", "baz"], dtype="O")
407-
408-
df2 = DataFrame({0: [np.nan, 2], 1: [np.nan, 3], 2: [np.nan, 4]}, dtype=object)
318+
assert df.values.dtype == np.object_
319+
result = getattr(df, method)(1)
320+
expected = getattr(df.astype("f8"), method)(1)
409321

410-
for df in [df1, df2]:
411-
assert df.values.dtype == np.object_
412-
result = getattr(df, method)(1)
413-
expected = getattr(df.astype("f8"), method)(1)
414-
415-
if method in ["sum", "prod"]:
416-
tm.assert_series_equal(result, expected)
322+
if method in ["sum", "prod"]:
323+
tm.assert_series_equal(result, expected)
417324

418325
@pytest.mark.parametrize("op", ["mean", "std", "var", "skew", "kurt", "sem"])
419326
def test_mixed_ops(self, op):
@@ -968,32 +875,36 @@ def test_sum_bools(self):
968875
# ----------------------------------------------------------------------
969876
# Index of max / min
970877

971-
def test_idxmin(self, float_frame, int_frame):
878+
@pytest.mark.parametrize("skipna", [True, False])
879+
@pytest.mark.parametrize("axis", [0, 1])
880+
def test_idxmin(self, float_frame, int_frame, skipna, axis):
972881
frame = float_frame
973882
frame.iloc[5:10] = np.nan
974883
frame.iloc[15:20, -2:] = np.nan
975-
for skipna in [True, False]:
976-
for axis in [0, 1]:
977-
for df in [frame, int_frame]:
978-
result = df.idxmin(axis=axis, skipna=skipna)
979-
expected = df.apply(Series.idxmin, axis=axis, skipna=skipna)
980-
tm.assert_series_equal(result, expected)
884+
for df in [frame, int_frame]:
885+
result = df.idxmin(axis=axis, skipna=skipna)
886+
expected = df.apply(Series.idxmin, axis=axis, skipna=skipna)
887+
tm.assert_series_equal(result, expected)
981888

889+
def test_idxmin_axis_2(self, float_frame):
890+
frame = float_frame
982891
msg = "No axis named 2 for object type DataFrame"
983892
with pytest.raises(ValueError, match=msg):
984893
frame.idxmin(axis=2)
985894

986-
def test_idxmax(self, float_frame, int_frame):
895+
@pytest.mark.parametrize("skipna", [True, False])
896+
@pytest.mark.parametrize("axis", [0, 1])
897+
def test_idxmax(self, float_frame, int_frame, skipna, axis):
987898
frame = float_frame
988899
frame.iloc[5:10] = np.nan
989900
frame.iloc[15:20, -2:] = np.nan
990-
for skipna in [True, False]:
991-
for axis in [0, 1]:
992-
for df in [frame, int_frame]:
993-
result = df.idxmax(axis=axis, skipna=skipna)
994-
expected = df.apply(Series.idxmax, axis=axis, skipna=skipna)
995-
tm.assert_series_equal(result, expected)
901+
for df in [frame, int_frame]:
902+
result = df.idxmax(axis=axis, skipna=skipna)
903+
expected = df.apply(Series.idxmax, axis=axis, skipna=skipna)
904+
tm.assert_series_equal(result, expected)
996905

906+
def test_idxmax_axis_2(self, float_frame):
907+
frame = float_frame
997908
msg = "No axis named 2 for object type DataFrame"
998909
with pytest.raises(ValueError, match=msg):
999910
frame.idxmax(axis=2)
@@ -1077,17 +988,63 @@ def test_idxmax_dt64_multicolumn_axis1(self):
1077988
# Logical reductions
1078989

1079990
@pytest.mark.parametrize("opname", ["any", "all"])
1080-
def test_any_all(self, opname, bool_frame_with_na, float_string_frame):
1081-
assert_bool_op_api(
1082-
opname, bool_frame_with_na, float_string_frame, has_bool_only=True
1083-
)
991+
@pytest.mark.parametrize("axis", [0, 1])
992+
@pytest.mark.parametrize("bool_only", [False, True])
993+
def test_any_all_mixed_float(self, opname, axis, bool_only, float_string_frame):
994+
# make sure op works on mixed-type frame
995+
mixed = float_string_frame
996+
mixed["_bool_"] = np.random.randn(len(mixed)) > 0.5
997+
998+
getattr(mixed, opname)(axis=axis, bool_only=bool_only)
999+
1000+
@pytest.mark.parametrize("opname", ["any", "all"])
1001+
@pytest.mark.parametrize("axis", [0, 1])
1002+
def test_any_all_bool_with_na(self, opname, axis, bool_frame_with_na):
1003+
getattr(bool_frame_with_na, opname)(axis=axis, bool_only=False)
10841004

10851005
@pytest.mark.parametrize("opname", ["any", "all"])
10861006
def test_any_all_bool_frame(self, opname, bool_frame_with_na):
10871007
# GH#12863: numpy gives back non-boolean data for object type
10881008
# so fill NaNs to compare with pandas behavior
1089-
df = bool_frame_with_na.fillna(True)
1090-
assert_bool_op_calc(opname, getattr(np, opname), df, has_skipna=True)
1009+
frame = bool_frame_with_na.fillna(True)
1010+
alternative = getattr(np, opname)
1011+
f = getattr(frame, opname)
1012+
1013+
def skipna_wrapper(x):
1014+
nona = x.dropna().values
1015+
return alternative(nona)
1016+
1017+
def wrapper(x):
1018+
return alternative(x.values)
1019+
1020+
result0 = f(axis=0, skipna=False)
1021+
result1 = f(axis=1, skipna=False)
1022+
1023+
tm.assert_series_equal(result0, frame.apply(wrapper))
1024+
tm.assert_series_equal(result1, frame.apply(wrapper, axis=1))
1025+
1026+
result0 = f(axis=0)
1027+
result1 = f(axis=1)
1028+
1029+
tm.assert_series_equal(result0, frame.apply(skipna_wrapper))
1030+
tm.assert_series_equal(
1031+
result1, frame.apply(skipna_wrapper, axis=1), check_dtype=False
1032+
)
1033+
1034+
# bad axis
1035+
with pytest.raises(ValueError, match="No axis named 2"):
1036+
f(axis=2)
1037+
1038+
# all NA case
1039+
all_na = frame * np.NaN
1040+
r0 = getattr(all_na, opname)(axis=0)
1041+
r1 = getattr(all_na, opname)(axis=1)
1042+
if opname == "any":
1043+
assert not r0.any()
1044+
assert not r1.any()
1045+
else:
1046+
assert r0.all()
1047+
assert r1.all()
10911048

10921049
def test_any_all_extra(self):
10931050
df = DataFrame(

0 commit comments

Comments
 (0)