Skip to content

Commit 05a26c7

Browse files
Add B031: Warn when using groupby() result multiple times
1 parent 9ea5a34 commit 05a26c7

File tree

4 files changed

+144
-0
lines changed

4 files changed

+144
-0
lines changed

README.rst

+3
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ It is therefore recommended to use a stacklevel of 2 or greater to provide more
181181

182182
**B030**: Except handlers should only be exception classes or tuples of exception classes.
183183

184+
**B031**: Using the generator returned from ``itertools.groupby()`` more than once will do nothing on the
185+
second usage. Save the result to a list if the result is needed multiple times.
186+
184187
Opinionated warnings
185188
~~~~~~~~~~~~~~~~~~~~
186189

bugbear.py

+65
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,11 @@ def children_in_scope(node):
265265
yield from children_in_scope(child)
266266

267267

268+
def walk_list(nodes):
    """Yield every AST node reachable from each node in ``nodes``.

    Each root is traversed with ``ast.walk`` in turn, so the roots themselves
    are included in the output alongside all of their descendants.
    """
    for root in nodes:
        for descendant in ast.walk(root):
            yield descendant
271+
272+
268273
def _typesafe_issubclass(cls, class_or_tuple):
269274
try:
270275
return issubclass(cls, class_or_tuple)
@@ -401,6 +406,7 @@ def visit_For(self, node):
401406
self.check_for_b007(node)
402407
self.check_for_b020(node)
403408
self.check_for_b023(node)
409+
self.check_for_b031(node)
404410
self.generic_visit(node)
405411

406412
def visit_AsyncFor(self, node):
@@ -793,6 +799,56 @@ def check_for_b026(self, call: ast.Call):
793799
):
794800
self.errors.append(B026(starred.lineno, starred.col_offset))
795801

802+
def check_for_b031(self, loop_node):  # noqa: C901
    """Check that `itertools.groupby` isn't iterated over more than once.

    We emit a warning when the group iterator yielded by `groupby()` is used
    more than once inside a loop body or when it's used in a nested loop.

    Only the simple form ``for <key>, <group> in groupby(...)`` is analysed,
    where the call is spelled ``groupby(...)`` or ``itertools.groupby(...)``
    and ``<group>`` is a plain name. Every other loop is ignored.

    Args:
        loop_node: The ``for`` loop node being visited.
    """
    # for <loop_node.target> in <loop_node.iter>: ...
    call = loop_node.iter
    if not isinstance(call, ast.Call):
        return

    func = call.func
    is_bare_groupby = isinstance(func, ast.Name) and func.id == "groupby"
    is_itertools_groupby = (
        isinstance(func, ast.Attribute)
        and func.attr == "groupby"
        and isinstance(func.value, ast.Name)
        and func.value.id == "itertools"
    )
    if not (is_bare_groupby or is_itertools_groupby):
        return

    # Only handle a simple two-element unpacking. Checking the length first
    # keeps odd targets such as `for (x,) in groupby(...)` from raising
    # IndexError inside the checker.
    target = loop_node.target
    if not (
        isinstance(target, ast.Tuple)
        and len(target.elts) == 2
        and isinstance(target.elts[1], ast.Name)
    ):
        # Ignore any `groupby()` invocation that isn't unpacked
        return
    group_name = target.elts[1].id

    # The same Name node can be reached by both checks below (and by several
    # enclosing nested loops), so remember what has already been reported.
    reported = set()

    def report(name_node):
        position = (name_node.lineno, name_node.col_offset)
        if position in reported:
            return
        reported.add(position)
        self.errors.append(
            B031(name_node.lineno, name_node.col_offset, vars=(name_node.id,))
        )

    num_usages = 0
    for body_node in walk_list(loop_node.body):
        # Handle nested loops: every use of the group inside one is flagged.
        if isinstance(body_node, ast.For):
            for nested_node in walk_list(body_node.body):
                if (
                    isinstance(nested_node, ast.Name)
                    and nested_node.id == group_name
                ):
                    report(nested_node)

        # Handle multiple uses: the second and later uses are flagged.
        if isinstance(body_node, ast.Name) and body_node.id == group_name:
            num_usages += 1
            if num_usages > 1:
                report(body_node)
851+
796852
def _get_assigned_names(self, loop_node):
797853
loop_targets = (ast.For, ast.AsyncFor, ast.comprehension)
798854
for node in children_in_scope(loop_node):
@@ -1558,8 +1614,17 @@ def visit_Lambda(self, node):
15581614
"anything. Add exceptions to handle."
15591615
)
15601616
)
1617+
15611618
B030 = Error(message="B030 Except handlers should only be names of exception classes")
15621619

1620+
# B031: reported at each offending use of the group iterator produced by
# `itertools.groupby()`.
B031 = Error(
    message=(
        "B031 Using the generator returned from `itertools.groupby()` more than "
        "once will do nothing on the second usage. Save the result to a list, if "
        "the result is needed multiple times."
    )
)
1627+
15631628
# Warnings disabled by default.
15641629
B901 = Error(
15651630
message=(

tests/b031.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Should emit:
3+
B031 - on lines 30, 34, 43
4+
"""
5+
import itertools
6+
from itertools import groupby
7+
8+
shoppers = ["Jane", "Joe", "Sarah"]
9+
items = [
10+
("lettuce", "greens"),
11+
("tomatoes", "greens"),
12+
("cucumber", "greens"),
13+
("chicken breast", "meats & fish"),
14+
("salmon", "meats & fish"),
15+
("ice cream", "frozen items"),
16+
]
17+
18+
carts = {shopper: [] for shopper in shoppers}
19+
20+
21+
def collect_shop_items(shopper, items):
22+
# Imagine this is an expensive database query or calculation that is
23+
# advantageous to batch.
24+
carts[shopper] += items
25+
26+
27+
# Group by shopping section
28+
for _section, section_items in groupby(items, key=lambda p: p[1]):
29+
for shopper in shoppers:
30+
collect_shop_items(shopper, section_items)
31+
32+
for _section, section_items in groupby(items, key=lambda p: p[1]):
33+
collect_shop_items("Jane", section_items)
34+
collect_shop_items("Joe", section_items)
35+
36+
37+
for _section, section_items in groupby(items, key=lambda p: p[1]):
38+
# This is ok
39+
collect_shop_items("Jane", section_items)
40+
41+
for _section, section_items in itertools.groupby(items, key=lambda p: p[1]):
42+
for shopper in shoppers:
43+
collect_shop_items(shopper, section_items)
44+
45+
for group in groupby(items, key=lambda p: p[1]):
46+
# This is bad, but not detected currently
47+
collect_shop_items("Jane", group[1])
48+
collect_shop_items("Joe", group[1])
49+
50+
51+
# Make sure we ignore - but don't fail on more complicated invocations
52+
for _key, (_value1, _value2) in groupby(
53+
[("a", (1, 2)), ("b", (3, 4)), ("a", (5, 6))], key=lambda p: p[1]
54+
):
55+
collect_shop_items("Jane", group[1])
56+
collect_shop_items("Joe", group[1])
57+
58+
# Make sure we ignore - but don't fail on more complicated invocations
59+
for (_key1, _key2), (_value1, _value2) in groupby(
60+
[(("a", "a"), (1, 2)), (("b", "b"), (3, 4)), (("a", "a"), (5, 6))],
61+
key=lambda p: p[1],
62+
):
63+
collect_shop_items("Jane", group[1])
64+
collect_shop_items("Joe", group[1])

tests/test_bugbear.py

+12
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
B028,
4343
B029,
4444
B030,
45+
B031,
4546
B901,
4647
B902,
4748
B903,
@@ -459,6 +460,17 @@ def test_b030(self):
459460
)
460461
self.assertEqual(errors, expected)
461462

463+
def test_b031(self):
    """Run the checker on the ``b031.py`` fixture and compare the B031 hits."""
    fixture = Path(__file__).absolute().with_name("b031.py")
    checker = BugBearChecker(filename=str(fixture))
    reported = list(checker.run())
    expected = self.errors(
        B031(30, 36, vars=("section_items",)),
        B031(34, 30, vars=("section_items",)),
        B031(43, 36, vars=("section_items",)),
    )
    self.assertEqual(reported, expected)
473+
462474
@unittest.skipIf(sys.version_info < (3, 8), "not implemented for <3.8")
463475
def test_b907(self):
464476
filename = Path(__file__).absolute().parent / "b907.py"

0 commit comments

Comments
 (0)