Skip to content

Commit 90ad156

Browse files
andfoypeterbell10
andauthored
BUG: immortalize uarray global strings in order to prevent negative refcounts (#22038)
#### Reference issue Fixes #21214 Closes #21218 #### What does this implement/fix? This PR addresses the segfaults caused by the occurrence of negative refcounts when uarray static strings were being released when the interpreter didn't exist. --------- Co-authored-by: peterbell10 <[email protected]>
1 parent f2e38db commit 90ad156

File tree

1 file changed

+52
-26
lines changed

1 file changed

+52
-26
lines changed

scipy/_lib/_uarray/_uarray_dispatch.cxx

+52-26
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,32 @@
1515

1616
namespace {
1717

18+
template <typename T>
19+
class immortal {
20+
alignas(T) std::byte storage[sizeof(T)];
21+
22+
public:
23+
template <typename... Args>
24+
immortal(Args&&... args) {
25+
// Construct new T in storage
26+
new(&storage) T(std::forward<Args>(args)...);
27+
}
28+
~immortal() {
29+
// Intentionally don't call destructor
30+
}
31+
32+
T* get() { return reinterpret_cast<T*>(&storage); }
33+
const T* get() const { return reinterpret_cast<const T*>(&storage); }
34+
const T* get_const() const { return reinterpret_cast<const T*>(&storage); }
35+
36+
const T* operator ->() const { return get(); }
37+
T* operator ->() { return get(); }
38+
39+
T& operator*() { return *get(); }
40+
const T& operator*() const { return *get(); }
41+
42+
};
43+
1844
/** Handle to a python object that automatically DECREFs */
1945
class py_ref {
2046
explicit py_ref(PyObject * object): obj_(object) {}
@@ -129,8 +155,8 @@ using global_state_t = std::unordered_map<std::string, global_backends>;
129155
using local_state_t = std::unordered_map<std::string, local_backends>;
130156

131157
static py_ref BackendNotImplementedError;
132-
static global_state_t global_domain_map;
133-
thread_local global_state_t * current_global_state = &global_domain_map;
158+
static immortal<global_state_t> global_domain_map;
159+
thread_local global_state_t * current_global_state = global_domain_map.get();
134160
thread_local global_state_t thread_local_domain_map;
135161
thread_local local_state_t local_domain_map;
136162

@@ -140,30 +166,30 @@ Using these with PyObject_GetAttr is faster than PyObject_GetAttrString which
140166
has to create a new python string internally.
141167
*/
142168
struct {
143-
py_ref ua_convert;
144-
py_ref ua_domain;
145-
py_ref ua_function;
169+
immortal<py_ref> ua_convert;
170+
immortal<py_ref> ua_domain;
171+
immortal<py_ref> ua_function;
146172

147173
bool init() {
148-
ua_convert = py_ref::steal(PyUnicode_InternFromString("__ua_convert__"));
149-
if (!ua_convert)
174+
*ua_convert = py_ref::steal(PyUnicode_InternFromString("__ua_convert__"));
175+
if (!*ua_convert)
150176
return false;
151177

152-
ua_domain = py_ref::steal(PyUnicode_InternFromString("__ua_domain__"));
153-
if (!ua_domain)
178+
*ua_domain = py_ref::steal(PyUnicode_InternFromString("__ua_domain__"));
179+
if (!*ua_domain)
154180
return false;
155181

156-
ua_function = py_ref::steal(PyUnicode_InternFromString("__ua_function__"));
157-
if (!ua_function)
182+
*ua_function = py_ref::steal(PyUnicode_InternFromString("__ua_function__"));
183+
if (!*ua_function)
158184
return false;
159185

160186
return true;
161187
}
162188

163189
void clear() {
164-
ua_convert.reset();
165-
ua_domain.reset();
166-
ua_function.reset();
190+
ua_convert->reset();
191+
ua_domain->reset();
192+
ua_function->reset();
167193
}
168194
} identifiers;
169195

@@ -202,7 +228,7 @@ std::string domain_to_string(PyObject * domain) {
202228

203229
Py_ssize_t backend_get_num_domains(PyObject * backend) {
204230
auto domain =
205-
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain.get()));
231+
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain->get()));
206232
if (!domain)
207233
return -1;
208234

@@ -225,7 +251,7 @@ enum class LoopReturn { Continue, Break, Error };
225251
template <typename Func>
226252
LoopReturn backend_for_each_domain(PyObject * backend, Func f) {
227253
auto domain =
228-
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain.get()));
254+
py_ref::steal(PyObject_GetAttr(backend, identifiers.ua_domain->get()));
229255
if (!domain)
230256
return LoopReturn::Error;
231257

@@ -537,7 +563,7 @@ struct BackendState {
537563

538564
/** Clean up global python references when the module is finalized. */
539565
void globals_free(void * /* self */) {
540-
global_domain_map.clear();
566+
global_domain_map->clear();
541567
BackendNotImplementedError.reset();
542568
identifiers.clear();
543569
}
@@ -550,7 +576,7 @@ void globals_free(void * /* self */) {
550576
* cleanup.
551577
*/
552578
int globals_traverse(PyObject * self, visitproc visit, void * arg) {
553-
for (const auto & kv : global_domain_map) {
579+
for (const auto & kv : *global_domain_map) {
554580
const auto & globals = kv.second;
555581
PyObject * backend = globals.global.backend.get();
556582
Py_VISIT(backend);
@@ -563,7 +589,7 @@ int globals_traverse(PyObject * self, visitproc visit, void * arg) {
563589
}
564590

565591
int globals_clear(PyObject * /* self */) {
566-
global_domain_map.clear();
592+
global_domain_map->clear();
567593
return 0;
568594
}
569595

@@ -1170,7 +1196,7 @@ py_ref Function::canonicalize_kwargs(PyObject * kwargs) {
11701196

11711197
py_func_args Function::replace_dispatchables(
11721198
PyObject * backend, PyObject * args, PyObject * kwargs, PyObject * coerce) {
1173-
auto has_ua_convert = PyObject_HasAttr(backend, identifiers.ua_convert.get());
1199+
auto has_ua_convert = PyObject_HasAttr(backend, identifiers.ua_convert->get());
11741200
if (!has_ua_convert) {
11751201
return {py_ref::ref(args), py_ref::ref(kwargs)};
11761202
}
@@ -1182,7 +1208,7 @@ py_func_args Function::replace_dispatchables(
11821208

11831209
PyObject * convert_args[] = {backend, dispatchables.get(), coerce};
11841210
auto res = py_ref::steal(Q_PyObject_VectorcallMethod(
1185-
identifiers.ua_convert.get(), convert_args,
1211+
identifiers.ua_convert->get(), convert_args,
11861212
array_size(convert_args) | Q_PY_VECTORCALL_ARGUMENTS_OFFSET, nullptr));
11871213
if (!res) {
11881214
return {};
@@ -1287,7 +1313,7 @@ PyObject * Function::call(PyObject * args_, PyObject * kwargs_) {
12871313
backend, reinterpret_cast<PyObject *>(this), new_args.args.get(),
12881314
new_args.kwargs.get()};
12891315
result = py_ref::steal(Q_PyObject_VectorcallMethod(
1290-
identifiers.ua_function.get(), args,
1316+
identifiers.ua_function->get(), args,
12911317
array_size(args) | Q_PY_VECTORCALL_ARGUMENTS_OFFSET, nullptr));
12921318

12931319
// raise BackendNotImplemeted is equivalent to return NotImplemented
@@ -1499,7 +1525,7 @@ PyObject * get_state(PyObject * /* self */, PyObject * /* args */) {
14991525

15001526
output->locals = local_domain_map;
15011527
output->use_thread_local_globals =
1502-
(current_global_state != &global_domain_map);
1528+
(current_global_state != global_domain_map.get());
15031529
output->globals = *current_global_state;
15041530

15051531
return ref.release();
@@ -1523,7 +1549,7 @@ PyObject * set_state(PyObject * /* self */, PyObject * args) {
15231549
bool use_thread_local_globals =
15241550
(!reset_allowed) || state->use_thread_local_globals;
15251551
current_global_state =
1526-
use_thread_local_globals ? &thread_local_domain_map : &global_domain_map;
1552+
use_thread_local_globals ? &thread_local_domain_map : global_domain_map.get();
15271553

15281554
if (use_thread_local_globals)
15291555
thread_local_domain_map = state->globals;
@@ -1554,7 +1580,7 @@ PyObject * determine_backend(PyObject * /*self*/, PyObject * args) {
15541580
auto result = for_each_backend_in_domain(
15551581
domain, [&](PyObject * backend, bool coerce_backend) {
15561582
auto has_ua_convert =
1557-
PyObject_HasAttr(backend, identifiers.ua_convert.get());
1583+
PyObject_HasAttr(backend, identifiers.ua_convert->get());
15581584

15591585
if (!has_ua_convert) {
15601586
// If no __ua_convert__, assume it won't accept the type
@@ -1566,7 +1592,7 @@ PyObject * determine_backend(PyObject * /*self*/, PyObject * args) {
15661592
(coerce && coerce_backend) ? Py_True : Py_False};
15671593

15681594
auto res = py_ref::steal(Q_PyObject_VectorcallMethod(
1569-
identifiers.ua_convert.get(), convert_args,
1595+
identifiers.ua_convert->get(), convert_args,
15701596
array_size(convert_args) | Q_PY_VECTORCALL_ARGUMENTS_OFFSET,
15711597
nullptr));
15721598
if (!res) {

0 commit comments

Comments
 (0)