diff --git a/Lib/test/test_ctypes/test_win32_com_foreign_func.py b/Lib/test/test_ctypes/test_win32_com_foreign_func.py index 651c9277d59af9..8d217fc17efa02 100644 --- a/Lib/test/test_ctypes/test_win32_com_foreign_func.py +++ b/Lib/test/test_ctypes/test_win32_com_foreign_func.py @@ -9,7 +9,7 @@ raise unittest.SkipTest("Windows-specific test") -from _ctypes import COMError +from _ctypes import COMError, CopyComPointer from ctypes import HRESULT @@ -78,6 +78,19 @@ def is_equal_guid(guid1, guid2): ) +def create_shelllink_persist(typ): + ppst = typ() + # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance + ole32.CoCreateInstance( + byref(CLSID_ShellLink), + None, + CLSCTX_SERVER, + byref(IID_IPersist), + byref(ppst), + ) + return ppst + + class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase): def setUp(self): # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex @@ -88,19 +101,6 @@ def tearDown(self): ole32.CoUninitialize() gc.collect() - @staticmethod - def create_shelllink_persist(typ): - ppst = typ() - # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance - ole32.CoCreateInstance( - byref(CLSID_ShellLink), - None, - CLSCTX_SERVER, - byref(IID_IPersist), - byref(ppst), - ) - return ppst - def test_without_paramflags_and_iid(self): class IUnknown(c_void_p): QueryInterface = proto_query_interface() @@ -110,7 +110,7 @@ class IUnknown(c_void_p): class IPersist(IUnknown): GetClassID = proto_get_class_id() - ppst = self.create_shelllink_persist(IPersist) + ppst = create_shelllink_persist(IPersist) clsid = GUID() hr_getclsid = ppst.GetClassID(byref(clsid)) @@ -142,7 +142,7 @@ class IUnknown(c_void_p): class IPersist(IUnknown): GetClassID = proto_get_class_id(((OUT, "pClassID"),)) - ppst = self.create_shelllink_persist(IPersist) + ppst = create_shelllink_persist(IPersist) clsid = ppst.GetClassID() self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) @@ -167,7 +167,7 @@ class IUnknown(c_void_p): class IPersist(IUnknown): GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist) - ppst = self.create_shelllink_persist(IPersist) + ppst = create_shelllink_persist(IPersist) clsid = ppst.GetClassID() self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) @@ -184,5 +184,103 @@ class IPersist(IUnknown): self.assertEqual(0, ppst.Release()) +class CopyComPointerTests(unittest.TestCase): + def setUp(self): + ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED) + + class IUnknown(c_void_p): + QueryInterface = proto_query_interface(None, IID_IUnknown) + AddRef = proto_add_ref() + Release = proto_release() + + class IPersist(IUnknown): + GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist) + + self.IUnknown = IUnknown + self.IPersist = IPersist + + def tearDown(self): + ole32.CoUninitialize() + gc.collect() + + def test_both_are_null(self): + src = self.IPersist() + dst = self.IPersist() + + hr = CopyComPointer(src, byref(dst)) + + self.assertEqual(S_OK, hr) + + self.assertIsNone(src.value) + self.assertIsNone(dst.value) + + def test_src_is_nonnull_and_dest_is_null(self): + # The reference count of the COM pointer created by `CoCreateInstance` + # is initially 1. + src = create_shelllink_persist(self.IPersist) + dst = self.IPersist() + + # `CopyComPointer` calls `AddRef` explicitly in the C implementation. + # The refcount of `src` is incremented from 1 to 2 here. + hr = CopyComPointer(src, byref(dst)) + + self.assertEqual(S_OK, hr) + self.assertEqual(src.value, dst.value) + + # This indicates that the refcount was 2 before the `Release` call. + self.assertEqual(1, src.Release()) + + clsid = dst.GetClassID() + self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) + + self.assertEqual(0, dst.Release()) + + def test_src_is_null_and_dest_is_nonnull(self): + src = self.IPersist() + dst_orig = create_shelllink_persist(self.IPersist) + dst = self.IPersist() + CopyComPointer(dst_orig, byref(dst)) + self.assertEqual(1, dst_orig.Release()) + + clsid = dst.GetClassID() + self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) + + # This does NOT affects the refcount of `dst_orig`. + hr = CopyComPointer(src, byref(dst)) + + self.assertEqual(S_OK, hr) + self.assertIsNone(dst.value) + + with self.assertRaises(ValueError): + dst.GetClassID() # NULL COM pointer access + + # This indicates that the refcount was 1 before the `Release` call. + self.assertEqual(0, dst_orig.Release()) + + def test_both_are_nonnull(self): + src = create_shelllink_persist(self.IPersist) + dst_orig = create_shelllink_persist(self.IPersist) + dst = self.IPersist() + CopyComPointer(dst_orig, byref(dst)) + self.assertEqual(1, dst_orig.Release()) + + self.assertEqual(dst.value, dst_orig.value) + self.assertNotEqual(src.value, dst.value) + + hr = CopyComPointer(src, byref(dst)) + + self.assertEqual(S_OK, hr) + self.assertEqual(src.value, dst.value) + self.assertNotEqual(dst.value, dst_orig.value) + + self.assertEqual(1, src.Release()) + + clsid = dst.GetClassID() + self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) + + self.assertEqual(0, dst.Release()) + self.assertEqual(0, dst_orig.Release()) + + if __name__ == '__main__': unittest.main()