Skip to content

Commit e6e1a1d

Browse files
committed
Call PySys_SetArgv when initializing interpreter.
1 parent 1491c94 commit e6e1a1d

File tree

3 files changed

+127
-5
lines changed

3 files changed

+127
-5
lines changed

include/pybind11/embed.h

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,95 @@ struct embedded_module {
8787
}
8888
};
8989

90+
/// Python 2.x/3.x-compatible version of `PySys_SetArgv`
91+
inline void set_interpreter_argv(int argc, char** argv, bool add_current_dir_to_path) {
92+
// Before it was special-cased in python 3.8, passing an empty or null argv
93+
// caused a segfault, so we have to reimplement the special case ourselves.
94+
char** safe_argv = argv;
95+
if (nullptr == argv || argc <= 0) {
96+
safe_argv = new char*[1];
97+
if (nullptr == safe_argv) return;
98+
safe_argv[0] = new char[1];
99+
if (nullptr == safe_argv[0]) {
100+
delete[] safe_argv;
101+
return;
102+
}
103+
safe_argv[0][0] = '\0';
104+
argc = 1;
105+
}
106+
#if PY_MAJOR_VERSION >= 3
107+
// SetArgv* on python 3 takes wchar_t, so we have to convert.
108+
wchar_t** widened_argv = new wchar_t*[argc];
109+
for (int ii = 0; ii < argc; ++ii) {
110+
# if PY_MINOR_VERSION >= 5
111+
// From Python 3.5 onwards, we're supposed to use Py_DecodeLocale to
112+
// generate the wchar_t version of argv.
113+
widened_argv[ii] = Py_DecodeLocale(safe_argv[ii], nullptr);
114+
# define FREE_WIDENED_ARG(X) PyMem_RawFree(X)
115+
# else
116+
// Before Python 3.5, we're stuck with mbstowcs, which may or may not
117+
// actually work. Mercifully, pyconfig.h provides this define:
118+
# ifdef HAVE_BROKEN_MBSTOWCS
119+
size_t count = strlen(safe_argv[ii]);
120+
# else
121+
size_t count = mbstowcs(nullptr, safe_argv[ii], 0);
122+
# endif
123+
widened_argv[ii] = nullptr;
124+
if (count != static_cast<size_t>(-1)) {
125+
widened_argv[ii] = new wchar_t[count + 1];
126+
mbstowcs(widened_argv[ii], safe_argv[ii], count + 1);
127+
}
128+
# define FREE_WIDENED_ARG(X) delete[] X
129+
# endif
130+
if (nullptr == widened_argv[ii]) {
131+
// Either we ran out of memory or had a unicode encoding issue.
132+
// Free what we've encoded so far and bail.
133+
for (--ii; ii >= 0; --ii)
134+
FREE_WIDENED_ARG(widened_argv[ii]);
135+
return;
136+
}
137+
}
138+
139+
# if PY_MINOR_VERSION < 1 || (PY_MINOR_VERSION == 1 && PY_MICRO_VERSION < 3)
140+
# define NEED_PYRUN_TO_SANITIZE_PATH 1
141+
// don't have SetArgvEx yet
142+
PySys_SetArgv(argc, widened_argv);
143+
# else
144+
PySys_SetArgvEx(argc, widened_argv, add_current_dir_to_path ? 1 : 0);
145+
# endif
146+
147+
// PySys_SetArgv makes new PyUnicode objects so we can clean up this memory
148+
if (nullptr != widened_argv) {
149+
for (int ii = 0; ii < argc; ++ii)
150+
if (nullptr != widened_argv[ii])
151+
FREE_WIDENED_ARG(widened_argv[ii]);
152+
delete[] widened_argv;
153+
}
154+
# undef FREE_WIDENED_ARG
155+
#else
156+
// python 2.x
157+
# if PY_MINOR_VERSION < 6 || (PY_MINOR_VERSION == 6 && PY_MICRO_VERSION < 6)
158+
# define NEED_PYRUN_TO_SANITIZE_PATH 1
159+
// don't have SetArgvEx yet
160+
PySys_SetArgv(argc, safe_argv);
161+
# else
162+
PySys_SetArgvEx(argc, safe_argv, add_current_dir_to_path ? 1 : 0);
163+
# endif
164+
#endif
165+
166+
#ifdef NEED_PYRUN_TO_SANITIZE_PATH
167+
# undef NEED_PYRUN_TO_SANITIZE_PATH
168+
if (!add_current_dir_to_path)
169+
PyRun_SimpleString("import sys; sys.path.pop(0)\n");
170+
#endif
171+
172+
// if we allocated new memory to make safe_argv, we need to free it
173+
if (safe_argv != argv) {
174+
delete[] safe_argv[0];
175+
delete[] safe_argv;
176+
}
177+
}
178+
90179
PYBIND11_NAMESPACE_END(detail)
91180

92181
/** \rst
@@ -102,14 +191,16 @@ PYBIND11_NAMESPACE_END(detail)
102191
103192
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
104193
\endrst */
105-
inline void initialize_interpreter(bool init_signal_handlers = true) {
194+
inline void initialize_interpreter(bool init_signal_handlers = true,
195+
int argc = 0,
196+
char** argv = nullptr,
197+
bool add_current_dir_to_path = true) {
106198
if (Py_IsInitialized())
107199
pybind11_fail("The interpreter is already running");
108200

109201
Py_InitializeEx(init_signal_handlers ? 1 : 0);
110202

111-
// Make .py files in the working directory available by default
112-
module::import("sys").attr("path").cast<list>().append(".");
203+
detail::set_interpreter_argv(argc, argv, add_current_dir_to_path);
113204
}
114205

115206
/** \rst
@@ -182,8 +273,11 @@ inline void finalize_interpreter() {
182273
\endrst */
183274
class scoped_interpreter {
184275
public:
185-
scoped_interpreter(bool init_signal_handlers = true) {
186-
initialize_interpreter(init_signal_handlers);
276+
scoped_interpreter(bool init_signal_handlers = true,
277+
int argc = 0,
278+
char** argv = nullptr,
279+
bool add_current_dir_to_path = true) {
280+
initialize_interpreter(init_signal_handlers, argc, argv, add_current_dir_to_path);
187281
}
188282

189283
scoped_interpreter(const scoped_interpreter &) = delete;

tests/test_embed/test_interpreter.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class Widget {
2222

2323
std::string the_message() const { return message; }
2424
virtual int the_answer() const = 0;
25+
virtual std::string argv0() const = 0;
2526

2627
private:
2728
std::string message;
@@ -31,6 +32,7 @@ class PyWidget final : public Widget {
3132
using Widget::Widget;
3233

3334
int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); }
35+
std::string argv0() const override { PYBIND11_OVERLOAD_PURE(std::string, Widget, argv0); }
3436
};
3537

3638
PYBIND11_EMBEDDED_MODULE(widget_module, m) {
@@ -282,3 +284,25 @@ TEST_CASE("Reload module from file") {
282284
result = module.attr("test")().cast<int>();
283285
REQUIRE(result == 2);
284286
}
287+
288+
TEST_CASE("sys.argv gets initialized properly") {
289+
py::finalize_interpreter();
290+
{
291+
py::scoped_interpreter default_scope;
292+
auto module = py::module::import("test_interpreter");
293+
auto py_widget = module.attr("DerivedWidget")("The question");
294+
const auto &cpp_widget = py_widget.cast<const Widget &>();
295+
REQUIRE(cpp_widget.argv0() == "");
296+
}
297+
298+
{
299+
char* argv[] = { strdup("a.out") };
300+
py::scoped_interpreter argv_scope(true, 1, argv);
301+
free(argv[0]);
302+
auto module = py::module::import("test_interpreter");
303+
auto py_widget = module.attr("DerivedWidget")("The question");
304+
const auto &cpp_widget = py_widget.cast<const Widget &>();
305+
REQUIRE(cpp_widget.argv0() == "a.out");
306+
}
307+
py::initialize_interpreter();
308+
}

tests/test_embed/test_interpreter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from widget_module import Widget
3+
import sys
34

45

56
class DerivedWidget(Widget):
@@ -8,3 +9,6 @@ def __init__(self, message):
89

910
def the_answer(self):
1011
return 42
12+
13+
def argv0(self):
14+
return sys.argv[0]

0 commit comments

Comments
 (0)