44import time
55import queue
66import sys
7+ import json
78
89import datasets
910from InstructorEmbedding import INSTRUCTOR
4142TrainingArguments ,
4243Trainer ,
4344)
44- from threading import Thread
45+ import threading
4546from typing import Optional
4647
4748__cache_transformer_by_model_id = {}
6465}
6566
6667
class WorkerThreads:
    """Registry mapping worker-thread native ids to status payloads.

    Tracks background generation threads: `streaming_worker` registers
    its native thread id with a JSON status blob on startup and removes
    the entry when generation finishes, so the registry always reflects
    currently-live workers.

    NOTE: the parameter name `id` shadows the builtin but is kept for
    backward compatibility with existing keyword callers.
    """

    def __init__(self):
        # Maps thread native id -> status value (JSON string or error text).
        self.worker_threads = {}

    def delete_thread(self, id):
        """Remove the entry for *id*; raises KeyError if absent."""
        del self.worker_threads[id]

    def update_thread(self, id, value):
        """Create or overwrite the entry for *id* with *value*."""
        self.worker_threads[id] = value

    def get_thread(self, id):
        """Return the value stored for *id*, or None when unknown."""
        # dict.get performs a single lookup instead of `in` + subscript.
        return self.worker_threads.get(id)


# Module-level singleton shared by streaming pipelines and their workers.
worker_threads = WorkerThreads()
86+
87+
class PgMLException(Exception):
    """Base exception type raised by this transformers integration module."""

    pass
6990
@@ -107,6 +128,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
107128self .token_cache = []
108129self .text_index_cache = []
109130
    def set_worker_thread_id(self, id):
        """Record the native id of the background thread feeding this streamer."""
        self.worker_thread_id = id
133+
    def get_worker_thread_id(self):
        """Return the native id recorded by set_worker_thread_id.

        NOTE(review): raises AttributeError if called before
        set_worker_thread_id — presumably the caller always sets it
        immediately after thread.start(); verify against callers.
        """
        return self.worker_thread_id
136+
110137def put (self ,values ):
111138if self .skip_prompt and self .next_tokens_are_prompt :
112139self .next_tokens_are_prompt = False
@@ -151,6 +178,22 @@ def __next__(self):
151178return value
152179
153180
def streaming_worker(worker_threads, model, **kwargs):
    """Run `model.generate(**kwargs)` on a worker thread, tracking liveness.

    Registers this thread's native id in *worker_threads* with a JSON
    payload naming the model, runs generation, and always removes the
    registry entry on exit so the registry only contains live threads.
    Generation errors are reported to stderr rather than raised, since
    no caller is waiting on this thread to catch them.

    Args:
        worker_threads: registry exposing update_thread/delete_thread.
        model: object with `name_or_path` and a `generate(**kwargs)` method.
        **kwargs: forwarded verbatim to `model.generate`.
    """
    thread_id = threading.get_native_id()
    try:
        worker_threads.update_thread(
            thread_id, json.dumps({"model": model.name_or_path})
        )
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
    # not swallowed while recording status.
    except Exception:
        # name_or_path may be missing or non-serializable; record a
        # placeholder instead of killing the worker before it starts.
        worker_threads.update_thread(thread_id, "Error setting data")
    try:
        model.generate(**kwargs)
    except BaseException as error:
        # Deliberately broad: this is the thread's outermost boundary and
        # the error can only be surfaced via stderr.
        print(f"Error in streaming_worker: {error}", file=sys.stderr)
    finally:
        worker_threads.delete_thread(thread_id)
196+
154197class GGMLPipeline (object ):
155198def __init__ (self ,model_name ,** task ):
156199import ctransformers
@@ -187,7 +230,7 @@ def do_work():
187230self .q .put (x )
188231self .done = True
189232
190- thread = Thread (target = do_work )
233+ thread = threading . Thread (target = do_work )
191234thread .start ()
192235
193236def __iter__ (self ):
@@ -285,7 +328,13 @@ def stream(self, input, timeout=None, **kwargs):
285328input ,add_generation_prompt = True ,tokenize = False
286329 )
287330input = self .tokenizer (input ,return_tensors = "pt" ).to (self .model .device )
288- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
331+ generation_kwargs = dict (
332+ input ,
333+ worker_threads = worker_threads ,
334+ model = self .model ,
335+ streamer = streamer ,
336+ ** kwargs ,
337+ )
289338else :
290339streamer = TextIteratorStreamer (
291340self .tokenizer ,
@@ -294,9 +343,17 @@ def stream(self, input, timeout=None, **kwargs):
294343input = self .tokenizer (input ,return_tensors = "pt" ,padding = True ).to (
295344self .model .device
296345 )
297- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
298- thread = Thread (target = self .model .generate ,kwargs = generation_kwargs )
346+ generation_kwargs = dict (
347+ input ,
348+ worker_threads = worker_threads ,
349+ model = self .model ,
350+ streamer = streamer ,
351+ ** kwargs ,
352+ )
353+ # thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
354+ thread = threading .Thread (target = streaming_worker ,kwargs = generation_kwargs )
299355thread .start ()
356+ streamer .set_worker_thread_id (thread .native_id )
300357return streamer
301358
302359def __call__ (self ,inputs ,** kwargs ):