Skip to content

bpo-32856: Optimize the idiom for assignment in comprehensions. #5695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Lib/test/test_dictcomps.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def test_illegal_assignment(self):
compile("{x: y for y, x in ((1, 2), (3, 4))} += 5", "<test>",
"exec")

def test_assignment_idiom_in_comprehesions(self):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: comprehesions

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

expected = {1: 1, 2: 4, 3: 9, 4: 16}
actual = {j: j*j for i in range(4) for j in [i+1]}
self.assertEqual(actual, expected)
expected = {3: 2, 5: 6, 7: 12, 9: 20}
actual = {j+k: j*k for i in range(4) for j in [i+1] for k in [j+1]}
self.assertEqual(actual, expected)
expected = {3: 2, 5: 6, 7: 12, 9: 20}
actual = {j+k: j*k for i in range(4) for j, k in [(i+1, i+2)]}
self.assertEqual(actual, expected)


if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions Lib/test/test_genexps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
>>> list((i,j) for i in range(4) for j in range(i) )
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]

Test the idiom for temporary variable assignment in comprehensions.

>>> list((j*j for i in range(4) for j in [i+1]))
[1, 4, 9, 16]
>>> list((j*k for i in range(4) for j in [i+1] for k in [j+1]))
[2, 6, 12, 20]
>>> list((j*k for i in range(4) for j, k in [(i+1, i+2)]))
[2, 6, 12, 20]

Make sure the induction variable is not exposed

>>> i = 20
Expand Down
9 changes: 9 additions & 0 deletions Lib/test/test_listcomps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
>>> [(i,j) for i in range(4) for j in range(i)]
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]

Test the idiom for temporary variable assignment in comprehensions.

>>> [j*j for i in range(4) for j in [i+1]]
[1, 4, 9, 16]
>>> [j*k for i in range(4) for j in [i+1] for k in [j+1]]
[2, 6, 12, 20]
>>> [j*k for i in range(4) for j, k in [(i+1, i+2)]]
[2, 6, 12, 20]

Make sure the induction variable is not exposed

>>> i = 20
Expand Down
27 changes: 27 additions & 0 deletions Lib/test/test_peepholer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@

from test.bytecode_helper import BytecodeTestCase

def count_instr_recursively(f, opname):
count = 0
for instr in dis.get_instructions(f):
if instr.opname == opname:
count += 1
if hasattr(f, '__code__'):
f = f.__code__
for c in f.co_consts:
if hasattr(c, 'co_code'):
count += count_instr_recursively(c, opname)
return count


class TestTranforms(BytecodeTestCase):

def test_unot(self):
Expand Down Expand Up @@ -311,6 +324,20 @@ def test_constant_folding(self):
self.assertFalse(instr.opname.startswith('BINARY_'))
self.assertFalse(instr.opname.startswith('BUILD_'))

def test_assignment_idiom_in_comprehesions(self):
def listcomp():
return [y for x in a for y in [f(x)]]
self.assertEqual(count_instr_recursively(listcomp, 'FOR_ITER'), 1)
def setcomp():
return {y for x in a for y in [f(x)]}
self.assertEqual(count_instr_recursively(setcomp, 'FOR_ITER'), 1)
def dictcomp():
return {y: y for x in a for y in [f(x)]}
self.assertEqual(count_instr_recursively(dictcomp, 'FOR_ITER'), 1)
def genexpr():
return (y for x in a for y in [f(x)])
self.assertEqual(count_instr_recursively(genexpr, 'FOR_ITER'), 1)


class TestBuglets(unittest.TestCase):

Expand Down
9 changes: 9 additions & 0 deletions Lib/test/test_setcomps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
>>> list(sorted({(i,j) for i in range(4) for j in range(i)}))
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]

Test the idiom for temporary variable assignment in comprehensions.

>>> list(sorted({j*j for i in range(4) for j in [i+1]}))
[1, 4, 9, 16]
>>> list(sorted({j*k for i in range(4) for j in [i+1] for k in [j+1]}))
[2, 6, 12, 20]
>>> list(sorted({j*k for i in range(4) for j, k in [(i+1, i+2)]}))
[2, 6, 12, 20]

Make sure the induction variable is not exposed

>>> i = 20
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Optimized the idiom for assignment a temporary variable in comprehensions.
Now ``for y in [expr]`` in comprehensions is so fast as a simple assignment
``y = expr``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the NEWS entry with some benchmark results.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what should I provide. The optimized code always is the part of complex comprehension expression.

70 changes: 52 additions & 18 deletions Python/compile.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,13 @@ static int compiler_set_qualname(struct compiler *);
static int compiler_sync_comprehension_generator(
struct compiler *c,
asdl_seq *generators, int gen_index,
int depth,
expr_ty elt, expr_ty val, int type);

static int compiler_async_comprehension_generator(
struct compiler *c,
asdl_seq *generators, int gen_index,
int depth,
expr_ty elt, expr_ty val, int type);

static PyCodeObject *assemble(struct compiler *, int addNone);
Expand Down Expand Up @@ -3782,22 +3784,24 @@ compiler_call_helper(struct compiler *c,
static int
compiler_comprehension_generator(struct compiler *c,
asdl_seq *generators, int gen_index,
int depth,
expr_ty elt, expr_ty val, int type)
{
comprehension_ty gen;
gen = (comprehension_ty)asdl_seq_GET(generators, gen_index);
if (gen->is_async) {
return compiler_async_comprehension_generator(
c, generators, gen_index, elt, val, type);
c, generators, gen_index, depth, elt, val, type);
} else {
return compiler_sync_comprehension_generator(
c, generators, gen_index, elt, val, type);
c, generators, gen_index, depth, elt, val, type);
}
}

static int
compiler_sync_comprehension_generator(struct compiler *c,
asdl_seq *generators, int gen_index,
int depth,
expr_ty elt, expr_ty val, int type)
{
/* generate code for the iterator, then each of the ifs,
Expand Down Expand Up @@ -3825,12 +3829,38 @@ compiler_sync_comprehension_generator(struct compiler *c,
}
else {
/* Sub-iter - calculate on the fly */
VISIT(c, expr, gen->iter);
ADDOP(c, GET_ITER);
/* Fast path for the temporary variable assignment idiom:
for y in [f(x)]
*/
asdl_seq *elts;
switch (gen->iter->kind) {
case List_kind:
elts = gen->iter->v.List.elts;
break;
case Tuple_kind:
elts = gen->iter->v.Tuple.elts;
break;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we handle for x in {a} too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will change the behavior if a is not hashable.

default:
elts = NULL;
}
if (asdl_seq_LEN(elts) == 1) {
expr_ty elt = asdl_seq_GET(elts, 0);
if (elt->kind != Starred_kind) {
VISIT(c, expr, elt);
start = NULL;
}
}
if (start) {
VISIT(c, expr, gen->iter);
ADDOP(c, GET_ITER);
}
}
if (start) {
depth++;
compiler_use_next_block(c, start);
ADDOP_JREL(c, FOR_ITER, anchor);
NEXT_BLOCK(c);
}
compiler_use_next_block(c, start);
ADDOP_JREL(c, FOR_ITER, anchor);
NEXT_BLOCK(c);
VISIT(c, expr, gen->target);

/* XXX this needs to be cleaned up...a lot! */
Expand All @@ -3844,7 +3874,7 @@ compiler_sync_comprehension_generator(struct compiler *c,

if (++gen_index < asdl_seq_LEN(generators))
if (!compiler_comprehension_generator(c,
generators, gen_index,
generators, gen_index, depth,
elt, val, type))
return 0;

Expand All @@ -3859,18 +3889,18 @@ compiler_sync_comprehension_generator(struct compiler *c,
break;
case COMP_LISTCOMP:
VISIT(c, expr, elt);
ADDOP_I(c, LIST_APPEND, gen_index + 1);
ADDOP_I(c, LIST_APPEND, depth + 1);
break;
case COMP_SETCOMP:
VISIT(c, expr, elt);
ADDOP_I(c, SET_ADD, gen_index + 1);
ADDOP_I(c, SET_ADD, depth + 1);
break;
case COMP_DICTCOMP:
/* With 'd[k] = v', v is evaluated before k, so we do
the same. */
VISIT(c, expr, val);
VISIT(c, expr, elt);
ADDOP_I(c, MAP_ADD, gen_index + 1);
ADDOP_I(c, MAP_ADD, depth + 1);
break;
default:
return 0;
Expand All @@ -3879,15 +3909,18 @@ compiler_sync_comprehension_generator(struct compiler *c,
compiler_use_next_block(c, skip);
}
compiler_use_next_block(c, if_cleanup);
ADDOP_JABS(c, JUMP_ABSOLUTE, start);
compiler_use_next_block(c, anchor);
if (start) {
ADDOP_JABS(c, JUMP_ABSOLUTE, start);
compiler_use_next_block(c, anchor);
}

return 1;
}

static int
compiler_async_comprehension_generator(struct compiler *c,
asdl_seq *generators, int gen_index,
int depth,
expr_ty elt, expr_ty val, int type)
{
_Py_IDENTIFIER(StopAsyncIteration);
Expand Down Expand Up @@ -3971,9 +4004,10 @@ compiler_async_comprehension_generator(struct compiler *c,
NEXT_BLOCK(c);
}

depth++;
if (++gen_index < asdl_seq_LEN(generators))
if (!compiler_comprehension_generator(c,
generators, gen_index,
generators, gen_index, depth,
elt, val, type))
return 0;

Expand All @@ -3988,18 +4022,18 @@ compiler_async_comprehension_generator(struct compiler *c,
break;
case COMP_LISTCOMP:
VISIT(c, expr, elt);
ADDOP_I(c, LIST_APPEND, gen_index + 1);
ADDOP_I(c, LIST_APPEND, depth + 1);
break;
case COMP_SETCOMP:
VISIT(c, expr, elt);
ADDOP_I(c, SET_ADD, gen_index + 1);
ADDOP_I(c, SET_ADD, depth + 1);
break;
case COMP_DICTCOMP:
/* With 'd[k] = v', v is evaluated before k, so we do
the same. */
VISIT(c, expr, val);
VISIT(c, expr, elt);
ADDOP_I(c, MAP_ADD, gen_index + 1);
ADDOP_I(c, MAP_ADD, depth + 1);
break;
default:
return 0;
Expand Down Expand Up @@ -4067,7 +4101,7 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type,
ADDOP_I(c, op, 0);
}

if (!compiler_comprehension_generator(c, generators, 0, elt,
if (!compiler_comprehension_generator(c, generators, 0, 0, elt,
val, type))
goto error_in_scope;

Expand Down