Skip to content

Commit b9a31a1

Browse files
ericsnowcurrentlyaisk
authored andcommitted
pythongh-76785: Add Interpreter.prepare_main() (pythongh-113021)
This is one of the last pieces to get test.support.interpreters in sync with PEP 734.
1 parent 5818fae commit b9a31a1

File tree

6 files changed

+146
-17
lines changed

6 files changed

+146
-17
lines changed

Lib/test/support/interpreters/__init__.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,15 @@ def close(self):
130130
"""
131131
return _interpreters.destroy(self._id)
132132

133-
def exec_sync(self, code, /, channels=None):
133+
def prepare_main(self, ns=None, /, **kwargs):
134+
"""Bind the given values into the interpreter's __main__.
135+
136+
The values must be shareable.
137+
"""
138+
ns = dict(ns, **kwargs) if ns is not None else kwargs
139+
_interpreters.set___main___attrs(self._id, ns)
140+
141+
def exec_sync(self, code, /):
134142
"""Run the given source code in the interpreter.
135143
136144
This is essentially the same as calling the builtin "exec"
@@ -148,13 +156,13 @@ def exec_sync(self, code, /, channels=None):
148156
that time, the previous interpreter is allowed to run
149157
in other threads.
150158
"""
151-
excinfo = _interpreters.exec(self._id, code, channels)
159+
excinfo = _interpreters.exec(self._id, code)
152160
if excinfo is not None:
153161
raise ExecFailure(excinfo)
154162

155-
def run(self, code, /, channels=None):
163+
def run(self, code, /):
156164
def task():
157-
self.exec_sync(code, channels=channels)
165+
self.exec_sync(code)
158166
t = threading.Thread(target=task)
159167
t.start()
160168
return t

Lib/test/test__xxinterpchannels.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -586,12 +586,12 @@ def test_run_string_arg_unresolved(self):
586586
cid = channels.create()
587587
interp = interpreters.create()
588588

589+
interpreters.set___main___attrs(interp, dict(cid=cid.send))
589590
out = _run_output(interp, dedent("""
590591
import _xxinterpchannels as _channels
591592
print(cid.end)
592593
_channels.send(cid, b'spam', blocking=False)
593-
"""),
594-
dict(cid=cid.send))
594+
"""))
595595
obj = channels.recv(cid)
596596

597597
self.assertEqual(obj, b'spam')

Lib/test/test__xxsubinterpreters.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ def _captured_script(script):
3333
return wrapped, open(r, encoding="utf-8")
3434

3535

36-
def _run_output(interp, request, shared=None):
36+
def _run_output(interp, request):
3737
script, rpipe = _captured_script(request)
3838
with rpipe:
39-
interpreters.run_string(interp, script, shared)
39+
interpreters.run_string(interp, script)
4040
return rpipe.read()
4141

4242

@@ -630,10 +630,10 @@ def test_shareable_types(self):
630630
]
631631
for obj in objects:
632632
with self.subTest(obj):
633+
interpreters.set___main___attrs(interp, dict(obj=obj))
633634
interpreters.run_string(
634635
interp,
635636
f'assert(obj == {obj!r})',
636-
shared=dict(obj=obj),
637637
)
638638

639639
def test_os_exec(self):
@@ -721,7 +721,8 @@ def test_with_shared(self):
721721
with open({w}, 'wb') as chan:
722722
pickle.dump(ns, chan)
723723
""")
724-
interpreters.run_string(self.id, script, shared)
724+
interpreters.set___main___attrs(self.id, shared)
725+
interpreters.run_string(self.id, script)
725726
with open(r, 'rb') as chan:
726727
ns = pickle.load(chan)
727728

@@ -742,7 +743,8 @@ def test_shared_overwrites(self):
742743
ns2 = dict(vars())
743744
del ns2['__builtins__']
744745
""")
745-
interpreters.run_string(self.id, script, shared)
746+
interpreters.set___main___attrs(self.id, shared)
747+
interpreters.run_string(self.id, script)
746748

747749
r, w = os.pipe()
748750
script = dedent(f"""
@@ -773,7 +775,8 @@ def test_shared_overwrites_default_vars(self):
773775
with open({w}, 'wb') as chan:
774776
pickle.dump(ns, chan)
775777
""")
776-
interpreters.run_string(self.id, script, shared)
778+
interpreters.set___main___attrs(self.id, shared)
779+
interpreters.run_string(self.id, script)
777780
with open(r, 'rb') as chan:
778781
ns = pickle.load(chan)
779782

@@ -1036,7 +1039,8 @@ def script():
10361039
with open(w, 'w', encoding="utf-8") as spipe:
10371040
with contextlib.redirect_stdout(spipe):
10381041
print('it worked!', end='')
1039-
interpreters.run_func(self.id, script, shared=dict(w=w))
1042+
interpreters.set___main___attrs(self.id, dict(w=w))
1043+
interpreters.run_func(self.id, script)
10401044

10411045
with open(r, encoding="utf-8") as outfile:
10421046
out = outfile.read()
@@ -1052,7 +1056,8 @@ def script():
10521056
with contextlib.redirect_stdout(spipe):
10531057
print('it worked!', end='')
10541058
def f():
1055-
interpreters.run_func(self.id, script, shared=dict(w=w))
1059+
interpreters.set___main___attrs(self.id, dict(w=w))
1060+
interpreters.run_func(self.id, script)
10561061
t = threading.Thread(target=f)
10571062
t.start()
10581063
t.join()
@@ -1072,7 +1077,8 @@ def script():
10721077
with contextlib.redirect_stdout(spipe):
10731078
print('it worked!', end='')
10741079
code = script.__code__
1075-
interpreters.run_func(self.id, code, shared=dict(w=w))
1080+
interpreters.set___main___attrs(self.id, dict(w=w))
1081+
interpreters.run_func(self.id, code)
10761082

10771083
with open(r, encoding="utf-8") as outfile:
10781084
out = outfile.read()

Lib/test/test_interpreters/test_api.py

+57
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,63 @@ def task():
452452
self.assertEqual(os.read(r_interp, 1), FINISHED)
453453

454454

455+
class TestInterpreterPrepareMain(TestBase):
456+
457+
def test_empty(self):
458+
interp = interpreters.create()
459+
with self.assertRaises(ValueError):
460+
interp.prepare_main()
461+
462+
def test_dict(self):
463+
values = {'spam': 42, 'eggs': 'ham'}
464+
interp = interpreters.create()
465+
interp.prepare_main(values)
466+
out = _run_output(interp, dedent("""
467+
print(spam, eggs)
468+
"""))
469+
self.assertEqual(out.strip(), '42 ham')
470+
471+
def test_tuple(self):
472+
values = {'spam': 42, 'eggs': 'ham'}
473+
values = tuple(values.items())
474+
interp = interpreters.create()
475+
interp.prepare_main(values)
476+
out = _run_output(interp, dedent("""
477+
print(spam, eggs)
478+
"""))
479+
self.assertEqual(out.strip(), '42 ham')
480+
481+
def test_kwargs(self):
482+
values = {'spam': 42, 'eggs': 'ham'}
483+
interp = interpreters.create()
484+
interp.prepare_main(**values)
485+
out = _run_output(interp, dedent("""
486+
print(spam, eggs)
487+
"""))
488+
self.assertEqual(out.strip(), '42 ham')
489+
490+
def test_dict_and_kwargs(self):
491+
values = {'spam': 42, 'eggs': 'ham'}
492+
interp = interpreters.create()
493+
interp.prepare_main(values, foo='bar')
494+
out = _run_output(interp, dedent("""
495+
print(spam, eggs, foo)
496+
"""))
497+
self.assertEqual(out.strip(), '42 ham bar')
498+
499+
def test_not_shareable(self):
500+
interp = interpreters.create()
501+
# XXX TypeError?
502+
with self.assertRaises(ValueError):
503+
interp.prepare_main(spam={'spam': 'eggs', 'foo': 'bar'})
504+
505+
# Make sure neither was actually bound.
506+
with self.assertRaises(interpreters.ExecFailure):
507+
interp.exec_sync('print(foo)')
508+
with self.assertRaises(interpreters.ExecFailure):
509+
interp.exec_sync('print(spam)')
510+
511+
455512
class TestInterpreterExecSync(TestBase):
456513

457514
def test_success(self):

Lib/test/test_interpreters/utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ def clean_up_interpreters():
2929
pass # already destroyed
3030

3131

32-
def _run_output(interp, request, channels=None):
32+
def _run_output(interp, request, init=None):
3333
script, rpipe = _captured_script(request)
3434
with rpipe:
35-
interp.exec_sync(script, channels=channels)
35+
if init:
36+
interp.prepare_main(init)
37+
interp.exec_sync(script)
3638
return rpipe.read()
3739

3840

Modules/_xxsubinterpretersmodule.c

+56
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,60 @@ PyDoc_STRVAR(get_main_doc,
685685
\n\
686686
Return the ID of main interpreter.");
687687

688+
static PyObject *
689+
interp_set___main___attrs(PyObject *self, PyObject *args)
690+
{
691+
PyObject *id, *updates;
692+
if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".set___main___attrs",
693+
&id, &updates))
694+
{
695+
return NULL;
696+
}
697+
698+
// Look up the interpreter.
699+
PyInterpreterState *interp = PyInterpreterID_LookUp(id);
700+
if (interp == NULL) {
701+
return NULL;
702+
}
703+
704+
// Check the updates.
705+
if (updates != Py_None) {
706+
Py_ssize_t size = PyObject_Size(updates);
707+
if (size < 0) {
708+
return NULL;
709+
}
710+
if (size == 0) {
711+
PyErr_SetString(PyExc_ValueError,
712+
"arg 2 must be a non-empty mapping");
713+
return NULL;
714+
}
715+
}
716+
717+
_PyXI_session session = {0};
718+
719+
// Prep and switch interpreters, including apply the updates.
720+
if (_PyXI_Enter(&session, interp, updates) < 0) {
721+
if (!PyErr_Occurred()) {
722+
_PyXI_ApplyCapturedException(&session);
723+
assert(PyErr_Occurred());
724+
}
725+
else {
726+
assert(!_PyXI_HasCapturedException(&session));
727+
}
728+
return NULL;
729+
}
730+
731+
// Clean up and switch back.
732+
_PyXI_Exit(&session);
733+
734+
Py_RETURN_NONE;
735+
}
736+
737+
PyDoc_STRVAR(set___main___attrs_doc,
738+
"set___main___attrs(id, ns)\n\
739+
\n\
740+
Bind the given attributes in the interpreter's __main__ module.");
741+
688742
static PyUnicodeObject *
689743
convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
690744
const char *expected)
@@ -1033,6 +1087,8 @@ static PyMethodDef module_functions[] = {
10331087
{"run_func", _PyCFunction_CAST(interp_run_func),
10341088
METH_VARARGS | METH_KEYWORDS, run_func_doc},
10351089

1090+
{"set___main___attrs", _PyCFunction_CAST(interp_set___main___attrs),
1091+
METH_VARARGS, set___main___attrs_doc},
10361092
{"is_shareable", _PyCFunction_CAST(object_is_shareable),
10371093
METH_VARARGS | METH_KEYWORDS, is_shareable_doc},
10381094

0 commit comments

Comments
 (0)