@@ -493,6 +493,8 @@ _PyRuntimeState_Init(_PyRuntimeState *runtime)
493
493
return _PyStatus_OK ();
494
494
}
495
495
496
+ static void _xidregistry_clear (struct _xidregistry * );
497
+
496
498
void
497
499
_PyRuntimeState_Fini (_PyRuntimeState * runtime )
498
500
{
@@ -501,6 +503,8 @@ _PyRuntimeState_Fini(_PyRuntimeState *runtime)
501
503
assert (runtime -> object_state .interpreter_leaks == 0 );
502
504
#endif
503
505
506
+ _xidregistry_clear (& runtime -> xidregistry );
507
+
504
508
if (gilstate_tss_initialized (runtime )) {
505
509
gilstate_tss_fini (runtime );
506
510
}
@@ -546,6 +550,11 @@ _PyRuntimeState_ReInitThreads(_PyRuntimeState *runtime)
546
550
for (int i = 0 ; i < NUMLOCKS ; i ++ ) {
547
551
reinit_err += _PyThread_at_fork_reinit (lockptrs [i ]);
548
552
}
553
+ /* PyOS_AfterFork_Child(), which calls this function, later calls
554
+ _PyInterpreterState_DeleteExceptMain(), so we only need to update
555
+ the main interpreter here. */
556
+ assert (runtime -> interpreters .main != NULL );
557
+ runtime -> interpreters .main -> xidregistry .mutex = runtime -> xidregistry .mutex ;
549
558
550
559
PyMem_SetAllocator (PYMEM_DOMAIN_RAW , & old_alloc );
551
560
@@ -695,6 +704,10 @@ init_interpreter(PyInterpreterState *interp,
695
704
interp -> dtoa = (struct _dtoa_state )_dtoa_state_INIT (interp );
696
705
}
697
706
interp -> f_opcode_trace_set = false;
707
+
708
+ assert (runtime -> xidregistry .mutex != NULL );
709
+ interp -> xidregistry .mutex = runtime -> xidregistry .mutex ;
710
+
698
711
interp -> _initialized = 1 ;
699
712
}
700
713
@@ -887,6 +900,10 @@ interpreter_clear(PyInterpreterState *interp, PyThreadState *tstate)
887
900
Py_CLEAR (interp -> builtins );
888
901
Py_CLEAR (interp -> interpreter_trampoline );
889
902
903
+ _xidregistry_clear (& interp -> xidregistry );
904
+ /* The lock is owned by the runtime, so we don't free it here. */
905
+ interp -> xidregistry .mutex = NULL ;
906
+
890
907
if (tstate -> interp == interp ) {
891
908
/* We are now safe to fix tstate->_status.cleared. */
892
909
// XXX Do this (much) earlier?
@@ -2467,23 +2484,27 @@ _PyCrossInterpreterData_Release(_PyCrossInterpreterData *data)
2467
2484
crossinterpdatafunc. It would be simpler and more efficient. */
2468
2485
2469
2486
static int
2470
- _xidregistry_add_type (struct _xidregistry * xidregistry , PyTypeObject * cls ,
2471
- crossinterpdatafunc getdata )
2487
+ _xidregistry_add_type (struct _xidregistry * xidregistry ,
2488
+ PyTypeObject * cls , crossinterpdatafunc getdata )
2472
2489
{
2473
- // Note that we effectively replace already registered classes
2474
- // rather than failing.
2475
2490
struct _xidregitem * newhead = PyMem_RawMalloc (sizeof (struct _xidregitem ));
2476
2491
if (newhead == NULL ) {
2477
2492
return -1 ;
2478
2493
}
2479
- // XXX Assign a callback to clear the entry from the registry?
2480
- newhead -> cls = PyWeakref_NewRef ((PyObject * )cls , NULL );
2481
- if (newhead -> cls == NULL ) {
2482
- PyMem_RawFree (newhead );
2483
- return -1 ;
2494
+ * newhead = (struct _xidregitem ){
2495
+ // We do not keep a reference, to avoid keeping the class alive.
2496
+ .cls = cls ,
2497
+ .refcount = 1 ,
2498
+ .getdata = getdata ,
2499
+ };
2500
+ if (cls -> tp_flags & Py_TPFLAGS_HEAPTYPE ) {
2501
+ // XXX Assign a callback to clear the entry from the registry?
2502
+ newhead -> weakref = PyWeakref_NewRef ((PyObject * )cls , NULL );
2503
+ if (newhead -> weakref == NULL ) {
2504
+ PyMem_RawFree (newhead );
2505
+ return -1 ;
2506
+ }
2484
2507
}
2485
- newhead -> getdata = getdata ;
2486
- newhead -> prev = NULL ;
2487
2508
newhead -> next = xidregistry -> head ;
2488
2509
if (newhead -> next != NULL ) {
2489
2510
newhead -> next -> prev = newhead ;
@@ -2508,37 +2529,78 @@ _xidregistry_remove_entry(struct _xidregistry *xidregistry,
2508
2529
if (next != NULL ) {
2509
2530
next -> prev = entry -> prev ;
2510
2531
}
2511
- Py_DECREF (entry -> cls );
2532
+ Py_XDECREF (entry -> weakref );
2512
2533
PyMem_RawFree (entry );
2513
2534
return next ;
2514
2535
}
2515
2536
2537
+ static void
2538
+ _xidregistry_clear (struct _xidregistry * xidregistry )
2539
+ {
2540
+ struct _xidregitem * cur = xidregistry -> head ;
2541
+ xidregistry -> head = NULL ;
2542
+ while (cur != NULL ) {
2543
+ struct _xidregitem * next = cur -> next ;
2544
+ Py_XDECREF (cur -> weakref );
2545
+ PyMem_RawFree (cur );
2546
+ cur = next ;
2547
+ }
2548
+ }
2549
+
2516
2550
static struct _xidregitem *
2517
2551
_xidregistry_find_type (struct _xidregistry * xidregistry , PyTypeObject * cls )
2518
2552
{
2519
2553
struct _xidregitem * cur = xidregistry -> head ;
2520
2554
while (cur != NULL ) {
2521
- PyObject * registered = PyWeakref_GetObject (cur -> cls );
2522
- if (registered == Py_None ) {
2523
- // The weakly ref'ed object was freed.
2524
- cur = _xidregistry_remove_entry (xidregistry , cur );
2525
- }
2526
- else {
2527
- assert (PyType_Check (registered ));
2528
- if (registered == (PyObject * )cls ) {
2529
- return cur ;
2555
+ if (cur -> weakref != NULL ) {
2556
+ // cur is/was a heap type.
2557
+ PyObject * registered = PyWeakref_GetObject (cur -> weakref );
2558
+ assert (registered != NULL );
2559
+ if (registered == Py_None ) {
2560
+ // The weakly ref'ed object was freed.
2561
+ cur = _xidregistry_remove_entry (xidregistry , cur );
2562
+ continue ;
2530
2563
}
2531
- cur = cur -> next ;
2564
+ assert (PyType_Check (registered ));
2565
+ assert (cur -> cls == (PyTypeObject * )registered );
2566
+ assert (cur -> cls -> tp_flags & Py_TPFLAGS_HEAPTYPE );
2567
+ Py_DECREF (registered );
2532
2568
}
2569
+ if (cur -> cls == cls ) {
2570
+ return cur ;
2571
+ }
2572
+ cur = cur -> next ;
2533
2573
}
2534
2574
return NULL ;
2535
2575
}
2536
2576
2577
+ static inline struct _xidregistry *
2578
+ _get_xidregistry (PyInterpreterState * interp , PyTypeObject * cls )
2579
+ {
2580
+ struct _xidregistry * xidregistry = & interp -> runtime -> xidregistry ;
2581
+ if (cls -> tp_flags & Py_TPFLAGS_HEAPTYPE ) {
2582
+ assert (interp -> xidregistry .mutex == xidregistry -> mutex );
2583
+ xidregistry = & interp -> xidregistry ;
2584
+ }
2585
+ return xidregistry ;
2586
+ }
2587
+
2537
2588
static void _register_builtins_for_crossinterpreter_data (struct _xidregistry * xidregistry );
2538
2589
2590
+ static inline void
2591
+ _ensure_builtins_xid (PyInterpreterState * interp , struct _xidregistry * xidregistry )
2592
+ {
2593
+ if (xidregistry != & interp -> xidregistry ) {
2594
+ assert (xidregistry == & interp -> runtime -> xidregistry );
2595
+ if (xidregistry -> head == NULL ) {
2596
+ _register_builtins_for_crossinterpreter_data (xidregistry );
2597
+ }
2598
+ }
2599
+ }
2600
+
2539
2601
int
2540
2602
_PyCrossInterpreterData_RegisterClass (PyTypeObject * cls ,
2541
- crossinterpdatafunc getdata )
2603
+ crossinterpdatafunc getdata )
2542
2604
{
2543
2605
if (!PyType_Check (cls )) {
2544
2606
PyErr_Format (PyExc_ValueError , "only classes may be registered" );
@@ -2549,12 +2611,23 @@ _PyCrossInterpreterData_RegisterClass(PyTypeObject *cls,
2549
2611
return -1 ;
2550
2612
}
2551
2613
2552
- struct _xidregistry * xidregistry = & _PyRuntime .xidregistry ;
2614
+ int res = 0 ;
2615
+ PyInterpreterState * interp = _PyInterpreterState_GET ();
2616
+ struct _xidregistry * xidregistry = _get_xidregistry (interp , cls );
2553
2617
PyThread_acquire_lock (xidregistry -> mutex , WAIT_LOCK );
2554
- if (xidregistry -> head == NULL ) {
2555
- _register_builtins_for_crossinterpreter_data (xidregistry );
2618
+
2619
+ _ensure_builtins_xid (interp , xidregistry );
2620
+
2621
+ struct _xidregitem * matched = _xidregistry_find_type (xidregistry , cls );
2622
+ if (matched != NULL ) {
2623
+ assert (matched -> getdata == getdata );
2624
+ matched -> refcount += 1 ;
2625
+ goto finally ;
2556
2626
}
2557
- int res = _xidregistry_add_type (xidregistry , cls , getdata );
2627
+
2628
+ res = _xidregistry_add_type (xidregistry , cls , getdata );
2629
+
2630
+ finally :
2558
2631
PyThread_release_lock (xidregistry -> mutex );
2559
2632
return res ;
2560
2633
}
@@ -2563,13 +2636,20 @@ int
2563
2636
_PyCrossInterpreterData_UnregisterClass (PyTypeObject * cls )
2564
2637
{
2565
2638
int res = 0 ;
2566
- struct _xidregistry * xidregistry = & _PyRuntime .xidregistry ;
2639
+ PyInterpreterState * interp = _PyInterpreterState_GET ();
2640
+ struct _xidregistry * xidregistry = _get_xidregistry (interp , cls );
2567
2641
PyThread_acquire_lock (xidregistry -> mutex , WAIT_LOCK );
2642
+
2568
2643
struct _xidregitem * matched = _xidregistry_find_type (xidregistry , cls );
2569
2644
if (matched != NULL ) {
2570
- (void )_xidregistry_remove_entry (xidregistry , matched );
2645
+ assert (matched -> refcount > 0 );
2646
+ matched -> refcount -= 1 ;
2647
+ if (matched -> refcount == 0 ) {
2648
+ (void )_xidregistry_remove_entry (xidregistry , matched );
2649
+ }
2571
2650
res = 1 ;
2572
2651
}
2652
+
2573
2653
PyThread_release_lock (xidregistry -> mutex );
2574
2654
return res ;
2575
2655
}
@@ -2582,17 +2662,19 @@ _PyCrossInterpreterData_UnregisterClass(PyTypeObject *cls)
2582
2662
crossinterpdatafunc
2583
2663
_PyCrossInterpreterData_Lookup (PyObject * obj )
2584
2664
{
2585
- struct _xidregistry * xidregistry = & _PyRuntime .xidregistry ;
2586
- PyObject * cls = PyObject_Type (obj );
2665
+ PyTypeObject * cls = Py_TYPE (obj );
2666
+
2667
+ PyInterpreterState * interp = _PyInterpreterState_GET ();
2668
+ struct _xidregistry * xidregistry = _get_xidregistry (interp , cls );
2587
2669
PyThread_acquire_lock (xidregistry -> mutex , WAIT_LOCK );
2588
- if ( xidregistry -> head == NULL ) {
2589
- _register_builtins_for_crossinterpreter_data ( xidregistry );
2590
- }
2591
- struct _xidregitem * matched = _xidregistry_find_type (xidregistry ,
2592
- ( PyTypeObject * ) cls ) ;
2593
- Py_DECREF ( cls );
2670
+
2671
+ _ensure_builtins_xid ( interp , xidregistry );
2672
+
2673
+ struct _xidregitem * matched = _xidregistry_find_type (xidregistry , cls );
2674
+ crossinterpdatafunc func = matched != NULL ? matched -> getdata : NULL ;
2675
+
2594
2676
PyThread_release_lock (xidregistry -> mutex );
2595
- return matched != NULL ? matched -> getdata : NULL ;
2677
+ return func ;
2596
2678
}
2597
2679
2598
2680
/* cross-interpreter data for builtin types */
0 commit comments