@@ -219,5 +219,81 @@ def test_function():
219219self .assertEqual (m .call_count ,LOOPS * THREADS )
220220
221221
222+ def test_call_args_thread_safe (self ):
223+ m = ThreadingMock ()
224+ LOOPS = 100
225+ THREADS = 10
226+ def test_function (thread_id ):
227+ for i in range (LOOPS ):
228+ m (thread_id ,i )
229+
230+ oldswitchinterval = sys .getswitchinterval ()
231+ setswitchinterval (1e-6 )
232+ try :
233+ threads = [
234+ threading .Thread (target = test_function ,args = (thread_id ,))
235+ for thread_id in range (THREADS )
236+ ]
237+ with threading_helper .start_threads (threads ):
238+ pass
239+ finally :
240+ sys .setswitchinterval (oldswitchinterval )
241+ expected_calls = {
242+ (thread_id ,i )
243+ for thread_id in range (THREADS )
244+ for i in range (LOOPS )
245+ }
246+ self .assertSetEqual ({call .args for call in m .call_args_list },expected_calls )
247+
248+ def test_method_calls_thread_safe (self ):
249+ m = ThreadingMock ()
250+ LOOPS = 100
251+ THREADS = 10
252+ def test_function (thread_id ):
253+ for i in range (LOOPS ):
254+ getattr (m ,f"method_{ thread_id } " )(i )
255+
256+ oldswitchinterval = sys .getswitchinterval ()
257+ setswitchinterval (1e-6 )
258+ try :
259+ threads = [
260+ threading .Thread (target = test_function ,args = (thread_id ,))
261+ for thread_id in range (THREADS )
262+ ]
263+ with threading_helper .start_threads (threads ):
264+ pass
265+ finally :
266+ sys .setswitchinterval (oldswitchinterval )
267+ for thread_id in range (THREADS ):
268+ self .assertEqual (getattr (m ,f"method_{ thread_id } " ).call_count ,LOOPS )
269+ self .assertEqual ({call .args for call in getattr (m ,f"method_{ thread_id } " ).call_args_list },
270+ {(i ,)for i in range (LOOPS )})
271+
272+ def test_mock_calls_thread_safe (self ):
273+ m = ThreadingMock ()
274+ LOOPS = 100
275+ THREADS = 10
276+ def test_function (thread_id ):
277+ for i in range (LOOPS ):
278+ m (thread_id ,i )
279+
280+ oldswitchinterval = sys .getswitchinterval ()
281+ setswitchinterval (1e-6 )
282+ try :
283+ threads = [
284+ threading .Thread (target = test_function ,args = (thread_id ,))
285+ for thread_id in range (THREADS )
286+ ]
287+ with threading_helper .start_threads (threads ):
288+ pass
289+ finally :
290+ sys .setswitchinterval (oldswitchinterval )
291+ expected_calls = {
292+ (thread_id ,i )
293+ for thread_id in range (THREADS )
294+ for i in range (LOOPS )
295+ }
296+ self .assertSetEqual ({call .args for call in m .mock_calls },expected_calls )
297+
222298if __name__ == "__main__" :
223299unittest .main ()