33import shutil
44import time
55import queue
6+ import sys
7+ import json
68
79import datasets
810from InstructorEmbedding import INSTRUCTOR
4042TrainingArguments ,
4143Trainer ,
4244)
43- from threading import Thread
45+ import threading
4446
# Module-level caches so repeated pipeline construction reuses already-loaded
# models instead of reloading weights from disk on every call.
__cache_transformer_by_model_id = {}
__cache_sentence_transformer_by_name = {}
6264}
6365
6466
class WorkerThreads:
    """Registry mapping a worker thread's native id to a status string.

    Used by ``streaming_worker`` to advertise which model a generation
    thread is running (as a JSON blob) for the lifetime of that thread.

    NOTE(review): a plain dict with no lock — relies on the GIL making
    individual dict operations atomic; confirm whether stronger
    synchronization is needed if access patterns grow.
    """

    def __init__(self):
        # thread native id -> status value (typically a JSON string)
        self.worker_threads = {}

    def delete_thread(self, id):
        """Remove the entry for *id*; raises KeyError if it is absent."""
        del self.worker_threads[id]

    def update_thread(self, id, value):
        """Create or overwrite the entry for *id* with *value*."""
        self.worker_threads[id] = value

    def get_thread(self, id):
        """Return the value stored for *id*, or None when unknown."""
        # dict.get replaces the original membership-test-then-index pair.
        return self.worker_threads.get(id)
82+
83+
# Process-wide registry of active streaming worker threads.
worker_threads = WorkerThreads()
85+
86+
class PgMLException(Exception):
    """Exception type raised for failures in this module."""
6789
@@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
105127self .token_cache = []
106128self .text_index_cache = []
107129
def set_worker_thread_id(self, id: int) -> None:
    """Record the native id of the generation thread feeding this streamer."""
    self.worker_thread_id = id
132+
def get_worker_thread_id(self) -> int:
    """Return the id stored by set_worker_thread_id.

    Raises AttributeError if set_worker_thread_id was never called.
    """
    return self.worker_thread_id
135+
108136def put (self ,values ):
109137if self .skip_prompt and self .next_tokens_are_prompt :
110138self .next_tokens_are_prompt = False
@@ -149,6 +177,22 @@ def __next__(self):
149177return value
150178
151179
def streaming_worker(worker_threads, model, **kwargs):
    """Run ``model.generate(**kwargs)`` on this thread, tracking it in the registry.

    Registers this thread's native id in *worker_threads* (value: a JSON blob
    naming the model), runs generation, and always unregisters in ``finally``
    so finished threads do not leak registry entries.

    Args:
        worker_threads: WorkerThreads registry to record this thread in.
        model: object exposing ``name_or_path`` and ``generate(**kwargs)``.
        **kwargs: forwarded verbatim to ``model.generate``.
    """
    thread_id = threading.get_native_id()
    try:
        worker_threads.update_thread(
            thread_id, json.dumps({"model": model.name_or_path})
        )
    except Exception:
        # Was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt
        # are not silently swallowed while still recording a placeholder.
        worker_threads.update_thread(thread_id, "Error setting data")
    try:
        model.generate(**kwargs)
    except BaseException as error:
        # Background thread: report instead of dying silently; the consumer
        # of the streamer will simply observe the stream ending.
        print(f"Error in streaming_worker: {error}", file=sys.stderr)
    finally:
        worker_threads.delete_thread(thread_id)
195+
152196class GGMLPipeline (object ):
153197def __init__ (self ,model_name ,** task ):
154198import ctransformers
@@ -185,7 +229,7 @@ def do_work():
185229self .q .put (x )
186230self .done = True
187231
188- thread = Thread (target = do_work )
232+ thread = threading . Thread (target = do_work )
189233thread .start ()
190234
191235def __iter__ (self ):
@@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
283327input ,add_generation_prompt = True ,tokenize = False
284328 )
285329input = self .tokenizer (input ,return_tensors = "pt" ).to (self .model .device )
286- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
330+ generation_kwargs = dict (
331+ input ,
332+ worker_threads = worker_threads ,
333+ model = self .model ,
334+ streamer = streamer ,
335+ ** kwargs ,
336+ )
287337else :
288338streamer = TextIteratorStreamer (
289339self .tokenizer ,
@@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
292342input = self .tokenizer (input ,return_tensors = "pt" ,padding = True ).to (
293343self .model .device
294344 )
295- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
296- thread = Thread (target = self .model .generate ,kwargs = generation_kwargs )
345+ generation_kwargs = dict (
346+ input ,
347+ worker_threads = worker_threads ,
348+ model = self .model ,
349+ streamer = streamer ,
350+ ** kwargs ,
351+ )
352+ # thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
353+ thread = threading .Thread (target = streaming_worker ,kwargs = generation_kwargs )
297354thread .start ()
355+ streamer .set_worker_thread_id (thread .native_id )
298356return streamer
299357
300358def __call__ (self ,inputs ,** kwargs ):