diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index c5191e5fb939..7f0c08223630 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -283,6 +283,19 @@ def emit_line() -> None: if emitter.capi_version < (3, 12): fields["tp_dictoffset"] = base_size fields["tp_weaklistoffset"] = weak_offset + elif cl.supports_weakref: + # __weakref__ lives right after the struct + # TODO: It should get a member in the struct instead of doing this nonsense. + emitter.emit_lines( + f"PyMemberDef {members_name}[] = {{", + f'{{"__weakref__", T_OBJECT_EX, {base_size}, 0, NULL}},', + "{0}", + "};", + ) + if emitter.capi_version < (3, 12): + # versions >= 3.12 set Py_TPFLAGS_MANAGED_WEAKREF flag instead + # https://docs.python.org/3.12/extending/newtypes.html#weak-reference-support + fields["tp_weaklistoffset"] = base_size else: fields["tp_basicsize"] = base_size @@ -343,6 +356,9 @@ def emit_line() -> None: fields["tp_call"] = "PyVectorcall_Call" if has_managed_dict(cl, emitter): flags.append("Py_TPFLAGS_MANAGED_DICT") + if cl.supports_weakref and emitter.capi_version >= (3, 12): + flags.append("Py_TPFLAGS_MANAGED_WEAKREF") + fields["tp_flags"] = " | ".join(flags) emitter.emit_line(f"static PyTypeObject {emitter.type_struct_name(cl)}_template_ = {{") @@ -782,6 +798,13 @@ def generate_dealloc_for_class( emitter.emit_line("static void") emitter.emit_line(f"{dealloc_func_name}({cl.struct_name(emitter.names)} *self)") emitter.emit_line("{") + if cl.supports_weakref: + if emitter.capi_version < (3, 12): + emitter.emit_line("if (self->weakreflist != NULL) {") + emitter.emit_line("PyObject_ClearWeakRefs((PyObject *) self);") + emitter.emit_line("}") + else: + emitter.emit_line("PyObject_ClearWeakRefs((PyObject *) self);") if has_tp_finalize: emitter.emit_line("if (!PyObject_GC_IsFinalized((PyObject *)self)) {") emitter.emit_line("Py_TYPE(self)->tp_finalize((PyObject *)self);") diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index c88b9b0c7afc..4e86d5ce0eb2 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -109,6 +109,8 @@ def __init__( self.inherits_python = False # Do instances of this class have __dict__? self.has_dict = False + # Do instances of this class have __weakref__? + self.supports_weakref = False # Do we allow interpreted subclasses? Derived from a mypyc_attr. self.allow_interpreted_subclasses = False # Does this class need getseters to be generated for its attributes? (getseters are also @@ -362,6 +364,7 @@ def serialize(self) -> JsonDict: "is_final_class": self.is_final_class, "inherits_python": self.inherits_python, "has_dict": self.has_dict, + "supports_weakref": self.supports_weakref, "allow_interpreted_subclasses": self.allow_interpreted_subclasses, "needs_getseters": self.needs_getseters, "_serializable": self._serializable, @@ -419,6 +422,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR: ir.is_final_class = data["is_final_class"] ir.inherits_python = data["inherits_python"] ir.has_dict = data["has_dict"] + ir.supports_weakref = data["supports_weakref"] ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"] ir.needs_getseters = data["needs_getseters"] ir._serializable = data["_serializable"] diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 98ff348d8c30..8272e246ec5b 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -280,6 +280,9 @@ def prepare_class_def( if attrs.get("serializable") is True: # Supports copy.copy and pickle (including subclasses) ir._serializable = True + if attrs.get("supports_weakref") is True: + # Has a tp_weakrefoffset slot allowing the creation of weak references (including subclasses) + ir.supports_weakref = True # Check for subclassing from builtin types for cls in info.mro: diff --git a/mypyc/irbuild/vtable.py b/mypyc/irbuild/vtable.py index 2d4f7261e4ca..766b4086c594 100644 --- a/mypyc/irbuild/vtable.py +++ b/mypyc/irbuild/vtable.py @@ -15,6 +15,8 @@ def compute_vtable(cls: ClassIR) -> None: if not cls.is_generated: cls.has_dict = any(x.inherits_python for x in cls.mro) + # TODO: define more weakref triggers + cls.supports_weakref = cls.supports_weakref or cls.has_dict for t in cls.mro[1:]: # Make sure all ancestors are processed first diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index 5e7ecb70f55d..0328fe790987 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -364,11 +364,3 @@ def load_address_op(name: str, type: RType, src: str) -> LoadAddressDescription: # Import various modules that set up global state. -import mypyc.primitives.bytes_ops -import mypyc.primitives.dict_ops -import mypyc.primitives.float_ops -import mypyc.primitives.int_ops -import mypyc.primitives.list_ops -import mypyc.primitives.misc_ops -import mypyc.primitives.str_ops -import mypyc.primitives.tuple_ops # noqa: F401 diff --git a/mypyc/primitives/weakref_ops.py b/mypyc/primitives/weakref_ops.py new file mode 100644 index 000000000000..dede42fcd247 --- /dev/null +++ b/mypyc/primitives/weakref_ops.py @@ -0,0 +1,38 @@ +from mypyc.ir.rtypes import object_rprimitive +from mypyc.primitives.registry import function_op + +# Weakref operations + +""" +py_new_weak_ref_op = function_op( + name="weakref.weakref", + arg_types=[object_rprimitive], + # TODO: how do I pass NULL as the 2nd arg? + #extra_int_constants=[], + result_type=object_rprimitive, + c_function_name="PyWeakref_NewRef", +) +""" + +py_new_weak_ref_with_callback_op = function_op( + name="weakref.weakref", + arg_types=[object_rprimitive, object_rprimitive], + result_type=object_rprimitive, + c_function_name="PyWeakref_NewRef", +) + +""" +py_new_weak_proxy_op = function_op( + name="weakref.proxy", + arg_types=[object_rprimitive], + result_type=object_rprimitive, + c_function_name="PyWeakref_NewProxy", +) +""" + +py_new_weak_proxy_with_callback_op = function_op( + name="weakref.proxy", + arg_types=[object_rprimitive, object_rprimitive], + result_type=object_rprimitive, + c_function_name="PyWeakref_NewProxy", +) diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 9d564a552a05..f1e38696b990 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1381,3 +1381,45 @@ class M(type): # E: Inheriting from most builtin types is unimplemented @mypyc_attr(native_class=True) class A(metaclass=M): # E: Class is marked as native_class=True but it can't be a native class. Classes with a metaclass other than ABCMeta, TypingMeta or GenericMeta can't be native classes. pass + +[case testMypycAttrSupportsWeakref] +import weakref +from mypy_extensions import mypyc_attr + +@mypyc_attr(supports_weakref=True) +class WeakrefClass: + pass + +obj = WeakrefClass() +ref = weakref.ref(obj) +assert ref() is obj + +[case testMypycAttrSupportsWeakrefInheritance] +import weakref +from mypy_extensions import mypyc_attr + +@mypyc_attr(supports_weakref=True) +class WeakrefClass: + pass + +class WeakrefInheritor(WeakrefClass): + pass + +obj = WeakrefInheritor() +ref = weakref.ref(obj) +assert ref() is obj + +[case testMypycAttrSupportsWeakrefSubclass] +import weakref +from mypy_extensions import mypyc_attr + +class NativeClass: + pass + +@mypyc_attr(supports_weakref=True) +class WeakrefSubclass(NativeClass): + pass + +obj = WeakrefSubclass() +ref = weakref.ref(obj) +assert ref() is obj