Skip to content

Commit 3032fcd

Browse files
nineteendo, picnixz, eendebakpt, erlend-aasland, rhettinger
authored
gh-119793: Add optional length-checking to map() (GH-120471)
Co-authored-by: Bénédikt Tran <[email protected]>
Co-authored-by: Pieter Eendebak <[email protected]>
Co-authored-by: Erlend E. Aasland <[email protected]>
Co-authored-by: Raymond Hettinger <[email protected]>
1 parent bfc1d25 commit 3032fcd

File tree

6 files changed

+210
-17
lines changed

6 files changed

+210
-17
lines changed

Doc/library/functions.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,14 +1205,19 @@ are always available. They are listed here in alphabetical order.
12051205
unchanged from previous versions.
12061206

12071207

1208-
.. function:: map(function, iterable, *iterables)
1208+
.. function:: map(function, iterable, /, *iterables, strict=False)
12091209

12101210
Return an iterator that applies *function* to every item of *iterable*,
12111211
yielding the results. If additional *iterables* arguments are passed,
12121212
*function* must take that many arguments and is applied to the items from all
12131213
iterables in parallel. With multiple iterables, the iterator stops when the
1214-
shortest iterable is exhausted. For cases where the function inputs are
1215-
already arranged into argument tuples, see :func:`itertools.starmap`\.
1214+
shortest iterable is exhausted. If *strict* is ``True`` and one of the
1215+
iterables is exhausted before the others, a :exc:`ValueError` is raised. For
1216+
cases where the function inputs are already arranged into argument tuples,
1217+
see :func:`itertools.starmap`.
1218+
1219+
.. versionchanged:: 3.14
1220+
Added the *strict* parameter.
12161221

12171222

12181223
.. function:: max(iterable, *, key=None)

Doc/whatsnew/3.14.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ Improved error messages
175175
Other language changes
176176
======================
177177

178+
* The :func:`map` built-in now has an optional keyword-only *strict* flag
179+
like :func:`zip` to check that all the iterables are of equal length.
180+
(Contributed by Wannes Boeykens in :gh:`119793`.)
181+
178182
* Incorrect usage of :keyword:`await` and asynchronous comprehensions
179183
is now detected even if the code is optimized away by the :option:`-O`
180184
command-line option. For example, ``python -O -c 'assert await 1'``

Lib/test/test_builtin.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def filter_char(arg):
148148
def map_char(arg):
149149
return chr(ord(arg)+1)
150150

151+
def pack(*args):
152+
return args
153+
151154
class BuiltinTest(unittest.TestCase):
152155
# Helper to check picklability
153156
def check_iter_pickle(self, it, seq, proto):
@@ -1269,6 +1272,108 @@ def test_map_pickle(self):
12691272
m2 = map(map_char, "Is this the real life?")
12701273
self.check_iter_pickle(m1, list(m2), proto)
12711274

1275+
# strict map tests based on strict zip tests
1276+
1277+
def test_map_pickle_strict(self):
1278+
a = (1, 2, 3)
1279+
b = (4, 5, 6)
1280+
t = [(1, 4), (2, 5), (3, 6)]
1281+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1282+
m1 = map(pack, a, b, strict=True)
1283+
self.check_iter_pickle(m1, t, proto)
1284+
1285+
def test_map_pickle_strict_fail(self):
1286+
a = (1, 2, 3)
1287+
b = (4, 5, 6, 7)
1288+
t = [(1, 4), (2, 5), (3, 6)]
1289+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1290+
m1 = map(pack, a, b, strict=True)
1291+
m2 = pickle.loads(pickle.dumps(m1, proto))
1292+
self.assertEqual(self.iter_error(m1, ValueError), t)
1293+
self.assertEqual(self.iter_error(m2, ValueError), t)
1294+
1295+
def test_map_strict(self):
1296+
self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
1297+
((1, 'a'), (2, 'b'), (3, 'c')))
1298+
self.assertRaises(ValueError, tuple,
1299+
map(pack, (1, 2, 3, 4), 'abc', strict=True))
1300+
self.assertRaises(ValueError, tuple,
1301+
map(pack, (1, 2), 'abc', strict=True))
1302+
self.assertRaises(ValueError, tuple,
1303+
map(pack, (1, 2), (1, 2), 'abc', strict=True))
1304+
1305+
def test_map_strict_iterators(self):
1306+
x = iter(range(5))
1307+
y = [0]
1308+
z = iter(range(5))
1309+
self.assertRaises(ValueError, list,
1310+
(map(pack, x, y, z, strict=True)))
1311+
self.assertEqual(next(x), 2)
1312+
self.assertEqual(next(z), 1)
1313+
1314+
def test_map_strict_error_handling(self):
1315+
1316+
class Error(Exception):
1317+
pass
1318+
1319+
class Iter:
1320+
def __init__(self, size):
1321+
self.size = size
1322+
def __iter__(self):
1323+
return self
1324+
def __next__(self):
1325+
self.size -= 1
1326+
if self.size < 0:
1327+
raise Error
1328+
return self.size
1329+
1330+
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), Error)
1331+
self.assertEqual(l1, [("A", 0)])
1332+
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
1333+
self.assertEqual(l2, [("A", 1, "A")])
1334+
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), Error)
1335+
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
1336+
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
1337+
self.assertEqual(l4, [("A", 2), ("B", 1)])
1338+
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), Error)
1339+
self.assertEqual(l5, [(0, "A")])
1340+
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
1341+
self.assertEqual(l6, [(1, "A")])
1342+
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), Error)
1343+
self.assertEqual(l7, [(1, "A"), (0, "B")])
1344+
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
1345+
self.assertEqual(l8, [(2, "A"), (1, "B")])
1346+
1347+
def test_map_strict_error_handling_stopiteration(self):
1348+
1349+
class Iter:
1350+
def __init__(self, size):
1351+
self.size = size
1352+
def __iter__(self):
1353+
return self
1354+
def __next__(self):
1355+
self.size -= 1
1356+
if self.size < 0:
1357+
raise StopIteration
1358+
return self.size
1359+
1360+
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), ValueError)
1361+
self.assertEqual(l1, [("A", 0)])
1362+
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
1363+
self.assertEqual(l2, [("A", 1, "A")])
1364+
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), ValueError)
1365+
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
1366+
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
1367+
self.assertEqual(l4, [("A", 2), ("B", 1)])
1368+
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), ValueError)
1369+
self.assertEqual(l5, [(0, "A")])
1370+
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
1371+
self.assertEqual(l6, [(1, "A")])
1372+
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), ValueError)
1373+
self.assertEqual(l7, [(1, "A"), (0, "B")])
1374+
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
1375+
self.assertEqual(l8, [(2, "A"), (1, "B")])
1376+
12721377
def test_max(self):
12731378
self.assertEqual(max('123123'), '3')
12741379
self.assertEqual(max(1, 2, 3), 3)

Lib/test/test_itertools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,10 +2433,10 @@ class subclass(cls):
24332433
subclass(*args, newarg=3)
24342434

24352435
for cls, args, result in testcases:
2436-
# Constructors of repeat, zip, compress accept keyword arguments.
2436+
# Constructors of repeat, zip, map, compress accept keyword arguments.
24372437
# Their subclasses need overriding __new__ to support new
24382438
# keyword arguments.
2439-
if cls in [repeat, zip, compress]:
2439+
if cls in [repeat, zip, map, compress]:
24402440
continue
24412441
with self.subTest(cls):
24422442
class subclass_with_init(cls):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
The :func:`map` built-in now has an optional keyword-only *strict* flag
2+
like :func:`zip` to check that all the iterables are of equal length.
3+
Patch by Wannes Boeykens.

Python/bltinmodule.c

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,7 @@ typedef struct {
13111311
PyObject_HEAD
13121312
PyObject *iters;
13131313
PyObject *func;
1314+
int strict;
13141315
} mapobject;
13151316

13161317
static PyObject *
@@ -1319,10 +1320,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
13191320
PyObject *it, *iters, *func;
13201321
mapobject *lz;
13211322
Py_ssize_t numargs, i;
1323+
int strict = 0;
13221324

1323-
if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) &&
1324-
!_PyArg_NoKeywords("map", kwds))
1325-
return NULL;
1325+
if (kwds) {
1326+
PyObject *empty = PyTuple_New(0);
1327+
if (empty == NULL) {
1328+
return NULL;
1329+
}
1330+
static char *kwlist[] = {"strict", NULL};
1331+
int parsed = PyArg_ParseTupleAndKeywords(
1332+
empty, kwds, "|$p:map", kwlist, &strict);
1333+
Py_DECREF(empty);
1334+
if (!parsed) {
1335+
return NULL;
1336+
}
1337+
}
13261338

13271339
numargs = PyTuple_Size(args);
13281340
if (numargs < 2) {
@@ -1354,6 +1366,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
13541366
lz->iters = iters;
13551367
func = PyTuple_GET_ITEM(args, 0);
13561368
lz->func = Py_NewRef(func);
1369+
lz->strict = strict;
13571370

13581371
return (PyObject *)lz;
13591372
}
@@ -1363,11 +1376,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
13631376
size_t nargsf, PyObject *kwnames)
13641377
{
13651378
PyTypeObject *tp = _PyType_CAST(type);
1366-
if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
1367-
return NULL;
1368-
}
13691379

13701380
Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
1381+
if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
1382+
// Fallback to map_new()
1383+
PyThreadState *tstate = _PyThreadState_GET();
1384+
return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
1385+
}
1386+
13711387
if (nargs < 2) {
13721388
PyErr_SetString(PyExc_TypeError,
13731389
"map() must have at least two arguments.");
@@ -1395,6 +1411,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
13951411
}
13961412
lz->iters = iters;
13971413
lz->func = Py_NewRef(args[0]);
1414+
lz->strict = 0;
13981415

13991416
return (PyObject *)lz;
14001417
}
@@ -1419,6 +1436,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
14191436
static PyObject *
14201437
map_next(mapobject *lz)
14211438
{
1439+
Py_ssize_t i;
14221440
PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
14231441
PyObject **stack;
14241442
PyObject *result = NULL;
@@ -1437,10 +1455,13 @@ map_next(mapobject *lz)
14371455
}
14381456

14391457
Py_ssize_t nargs = 0;
1440-
for (Py_ssize_t i=0; i < niters; i++) {
1458+
for (i=0; i < niters; i++) {
14411459
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
14421460
PyObject *val = Py_TYPE(it)->tp_iternext(it);
14431461
if (val == NULL) {
1462+
if (lz->strict) {
1463+
goto check;
1464+
}
14441465
goto exit;
14451466
}
14461467
stack[i] = val;
@@ -1450,13 +1471,50 @@ map_next(mapobject *lz)
14501471
result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);
14511472

14521473
exit:
1453-
for (Py_ssize_t i=0; i < nargs; i++) {
1474+
for (i=0; i < nargs; i++) {
14541475
Py_DECREF(stack[i]);
14551476
}
14561477
if (stack != small_stack) {
14571478
PyMem_Free(stack);
14581479
}
14591480
return result;
1481+
check:
1482+
if (PyErr_Occurred()) {
1483+
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
1484+
// next() on argument i raised an exception (not StopIteration)
1485+
return NULL;
1486+
}
1487+
PyErr_Clear();
1488+
}
1489+
if (i) {
1490+
// ValueError: map() argument 2 is shorter than argument 1
1491+
// ValueError: map() argument 3 is shorter than arguments 1-2
1492+
const char* plural = i == 1 ? " " : "s 1-";
1493+
return PyErr_Format(PyExc_ValueError,
1494+
"map() argument %d is shorter than argument%s%d",
1495+
i + 1, plural, i);
1496+
}
1497+
for (i = 1; i < niters; i++) {
1498+
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
1499+
PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
1500+
if (val) {
1501+
Py_DECREF(val);
1502+
const char* plural = i == 1 ? " " : "s 1-";
1503+
return PyErr_Format(PyExc_ValueError,
1504+
"map() argument %d is longer than argument%s%d",
1505+
i + 1, plural, i);
1506+
}
1507+
if (PyErr_Occurred()) {
1508+
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
1509+
// next() on argument i raised an exception (not StopIteration)
1510+
return NULL;
1511+
}
1512+
PyErr_Clear();
1513+
}
1514+
// Argument i is exhausted. So far so good...
1515+
}
1516+
// All arguments are exhausted. Success!
1517+
goto exit;
14601518
}
14611519

14621520
static PyObject *
@@ -1473,21 +1531,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
14731531
PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
14741532
}
14751533

1534+
if (lz->strict) {
1535+
return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
1536+
}
14761537
return Py_BuildValue("ON", Py_TYPE(lz), args);
14771538
}
14781539

1540+
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
1541+
1542+
static PyObject *
1543+
map_setstate(mapobject *lz, PyObject *state)
1544+
{
1545+
int strict = PyObject_IsTrue(state);
1546+
if (strict < 0) {
1547+
return NULL;
1548+
}
1549+
lz->strict = strict;
1550+
Py_RETURN_NONE;
1551+
}
1552+
14791553
static PyMethodDef map_methods[] = {
14801554
{"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
1555+
{"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
14811556
{NULL, NULL} /* sentinel */
14821557
};
14831558

14841559

14851560
PyDoc_STRVAR(map_doc,
1486-
"map(function, iterable, /, *iterables)\n\
1561+
"map(function, iterable, /, *iterables, strict=False)\n\
14871562
--\n\
14881563
\n\
14891564
Make an iterator that computes the function using arguments from\n\
1490-
each of the iterables. Stops when the shortest iterable is exhausted.");
1565+
each of the iterables. Stops when the shortest iterable is exhausted.\n\
1566+
\n\
1567+
If strict is true and one of the arguments is exhausted before the others,\n\
1568+
raise a ValueError.");
14911569

14921570
PyTypeObject PyMap_Type = {
14931571
PyVarObject_HEAD_INIT(&PyType_Type, 0)
@@ -3068,8 +3146,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
30683146
return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
30693147
}
30703148

3071-
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
3072-
30733149
static PyObject *
30743150
zip_setstate(zipobject *lz, PyObject *state)
30753151
{

0 commit comments

Comments
 (0)