99raise unittest .SkipTest ("Windows-specific test" )
1010
1111
12- from _ctypes import COMError
12+ from _ctypes import COMError , CopyComPointer
1313from ctypes import HRESULT
1414
1515
@@ -78,6 +78,19 @@ def is_equal_guid(guid1, guid2):
7878)
7979
8080
81+ def create_shelllink_persist (typ ):
82+ ppst = typ ()
83+ # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
84+ ole32 .CoCreateInstance (
85+ byref (CLSID_ShellLink ),
86+ None ,
87+ CLSCTX_SERVER ,
88+ byref (IID_IPersist ),
89+ byref (ppst ),
90+ )
91+ return ppst
92+
93+
8194class ForeignFunctionsThatWillCallComMethodsTests (unittest .TestCase ):
8295def setUp (self ):
8396# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
@@ -88,19 +101,6 @@ def tearDown(self):
88101ole32 .CoUninitialize ()
89102gc .collect ()
90103
91- @staticmethod
92- def create_shelllink_persist (typ ):
93- ppst = typ ()
94- # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
95- ole32 .CoCreateInstance (
96- byref (CLSID_ShellLink ),
97- None ,
98- CLSCTX_SERVER ,
99- byref (IID_IPersist ),
100- byref (ppst ),
101- )
102- return ppst
103-
104104def test_without_paramflags_and_iid (self ):
105105class IUnknown (c_void_p ):
106106QueryInterface = proto_query_interface ()
@@ -110,7 +110,7 @@ class IUnknown(c_void_p):
110110class IPersist (IUnknown ):
111111GetClassID = proto_get_class_id ()
112112
113- ppst = self . create_shelllink_persist (IPersist )
113+ ppst = create_shelllink_persist (IPersist )
114114
115115clsid = GUID ()
116116hr_getclsid = ppst .GetClassID (byref (clsid ))
@@ -142,7 +142,7 @@ class IUnknown(c_void_p):
142142class IPersist (IUnknown ):
143143GetClassID = proto_get_class_id (((OUT ,"pClassID" ),))
144144
145- ppst = self . create_shelllink_persist (IPersist )
145+ ppst = create_shelllink_persist (IPersist )
146146
147147clsid = ppst .GetClassID ()
148148self .assertEqual (TRUE ,is_equal_guid (CLSID_ShellLink ,clsid ))
@@ -167,7 +167,7 @@ class IUnknown(c_void_p):
167167class IPersist (IUnknown ):
168168GetClassID = proto_get_class_id (((OUT ,"pClassID" ),),IID_IPersist )
169169
170- ppst = self . create_shelllink_persist (IPersist )
170+ ppst = create_shelllink_persist (IPersist )
171171
172172clsid = ppst .GetClassID ()
173173self .assertEqual (TRUE ,is_equal_guid (CLSID_ShellLink ,clsid ))
@@ -184,5 +184,103 @@ class IPersist(IUnknown):
184184self .assertEqual (0 ,ppst .Release ())
185185
186186
187+ class CopyComPointerTests (unittest .TestCase ):
188+ def setUp (self ):
189+ ole32 .CoInitializeEx (None ,COINIT_APARTMENTTHREADED )
190+
191+ class IUnknown (c_void_p ):
192+ QueryInterface = proto_query_interface (None ,IID_IUnknown )
193+ AddRef = proto_add_ref ()
194+ Release = proto_release ()
195+
196+ class IPersist (IUnknown ):
197+ GetClassID = proto_get_class_id (((OUT ,"pClassID" ),),IID_IPersist )
198+
199+ self .IUnknown = IUnknown
200+ self .IPersist = IPersist
201+
202+ def tearDown (self ):
203+ ole32 .CoUninitialize ()
204+ gc .collect ()
205+
206+ def test_both_are_null (self ):
207+ src = self .IPersist ()
208+ dst = self .IPersist ()
209+
210+ hr = CopyComPointer (src ,byref (dst ))
211+
212+ self .assertEqual (S_OK ,hr )
213+
214+ self .assertIsNone (src .value )
215+ self .assertIsNone (dst .value )
216+
217+ def test_src_is_nonnull_and_dest_is_null (self ):
218+ # The reference count of the COM pointer created by `CoCreateInstance`
219+ # is initially 1.
220+ src = create_shelllink_persist (self .IPersist )
221+ dst = self .IPersist ()
222+
223+ # `CopyComPointer` calls `AddRef` explicitly in the C implementation.
224+ # The refcount of `src` is incremented from 1 to 2 here.
225+ hr = CopyComPointer (src ,byref (dst ))
226+
227+ self .assertEqual (S_OK ,hr )
228+ self .assertEqual (src .value ,dst .value )
229+
230+ # This indicates that the refcount was 2 before the `Release` call.
231+ self .assertEqual (1 ,src .Release ())
232+
233+ clsid = dst .GetClassID ()
234+ self .assertEqual (TRUE ,is_equal_guid (CLSID_ShellLink ,clsid ))
235+
236+ self .assertEqual (0 ,dst .Release ())
237+
238+ def test_src_is_null_and_dest_is_nonnull (self ):
239+ src = self .IPersist ()
240+ dst_orig = create_shelllink_persist (self .IPersist )
241+ dst = self .IPersist ()
242+ CopyComPointer (dst_orig ,byref (dst ))
243+ self .assertEqual (1 ,dst_orig .Release ())
244+
245+ clsid = dst .GetClassID ()
246+ self .assertEqual (TRUE ,is_equal_guid (CLSID_ShellLink ,clsid ))
247+
248+ # This does NOT affects the refcount of `dst_orig`.
249+ hr = CopyComPointer (src ,byref (dst ))
250+
251+ self .assertEqual (S_OK ,hr )
252+ self .assertIsNone (dst .value )
253+
254+ with self .assertRaises (ValueError ):
255+ dst .GetClassID ()# NULL COM pointer access
256+
257+ # This indicates that the refcount was 1 before the `Release` call.
258+ self .assertEqual (0 ,dst_orig .Release ())
259+
260+ def test_both_are_nonnull (self ):
261+ src = create_shelllink_persist (self .IPersist )
262+ dst_orig = create_shelllink_persist (self .IPersist )
263+ dst = self .IPersist ()
264+ CopyComPointer (dst_orig ,byref (dst ))
265+ self .assertEqual (1 ,dst_orig .Release ())
266+
267+ self .assertEqual (dst .value ,dst_orig .value )
268+ self .assertNotEqual (src .value ,dst .value )
269+
270+ hr = CopyComPointer (src ,byref (dst ))
271+
272+ self .assertEqual (S_OK ,hr )
273+ self .assertEqual (src .value ,dst .value )
274+ self .assertNotEqual (dst .value ,dst_orig .value )
275+
276+ self .assertEqual (1 ,src .Release ())
277+
278+ clsid = dst .GetClassID ()
279+ self .assertEqual (TRUE ,is_equal_guid (CLSID_ShellLink ,clsid ))
280+
281+ self .assertEqual (0 ,dst .Release ())
282+ self .assertEqual (0 ,dst_orig .Release ())
283+
284+
187285if __name__ == '__main__' :
188286unittest .main ()