3
3
import shutil
4
4
import time
5
5
import queue
6
+ import sys
7
+ import json
6
8
7
9
import datasets
8
10
from InstructorEmbedding import INSTRUCTOR
40
42
TrainingArguments ,
41
43
Trainer ,
42
44
)
43
- from threading import Thread
45
+ import threading
44
46
45
47
__cache_transformer_by_model_id = {}
46
48
__cache_sentence_transformer_by_name = {}
62
64
}
63
65
64
66
67
class WorkerThreads:
    """Registry mapping native thread ids to status payloads.

    Tracks background generation threads so their state (a small JSON
    status string) can be looked up — and removed — by native thread id.
    """

    def __init__(self):
        # Maps a native thread id -> arbitrary status value (JSON string).
        self.worker_threads = {}

    def delete_thread(self, id):
        """Remove the entry for *id*; a missing id is silently ignored."""
        # pop() with a default instead of `del` so a double-delete (e.g. a
        # worker's `finally` block racing another cleanup path) cannot
        # raise KeyError and mask the real error.
        self.worker_threads.pop(id, None)

    def update_thread(self, id, value):
        """Create or overwrite the entry for *id* with *value*."""
        self.worker_threads[id] = value

    def get_thread(self, id):
        """Return the stored value for *id*, or None when absent."""
        # Single-lookup dict.get replaces the membership-test-then-index
        # pattern (two lookups, race-prone).
        return self.worker_threads.get(id)


# Module-level singleton shared by all streaming workers in this module.
worker_threads = WorkerThreads()
65
87
class PgMLException(Exception):
    """Module-specific exception raised by the PgML transformer plumbing."""

    pass
67
89
@@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
105
127
self .token_cache = []
106
128
self .text_index_cache = []
107
129
130
def set_worker_thread_id(self, id):
    """Record the native id of the worker thread feeding this streamer."""
    self.worker_thread_id = id
132
+
133
def get_worker_thread_id(self):
    """Return the native worker-thread id previously recorded on this streamer."""
    return self.worker_thread_id
135
+
108
136
def put (self ,values ):
109
137
if self .skip_prompt and self .next_tokens_are_prompt :
110
138
self .next_tokens_are_prompt = False
@@ -149,6 +177,22 @@ def __next__(self):
149
177
return value
150
178
151
179
180
def streaming_worker(worker_threads, model, **kwargs):
    """Run ``model.generate`` in a background thread, tracked in *worker_threads*.

    Registers this thread's native id in the shared registry (with a small
    JSON status payload naming the model), runs generation, and always
    unregisters the id when generation finishes or fails.

    Args:
        worker_threads: Registry with ``update_thread(id, value)`` and
            ``delete_thread(id)`` methods (see ``WorkerThreads``).
        model: Object exposing ``name_or_path`` and ``generate(**kwargs)``.
        **kwargs: Forwarded verbatim to ``model.generate``.
    """
    thread_id = threading.get_native_id()
    try:
        worker_threads.update_thread(
            thread_id, json.dumps({"model": model.name_or_path})
        )
    except Exception:
        # Status metadata is best-effort only — never let it kill generation.
        # (Narrowed from a bare `except:` which also swallowed SystemExit
        # and KeyboardInterrupt.)
        worker_threads.update_thread(thread_id, "Error setting data")
    try:
        model.generate(**kwargs)
    except BaseException as error:
        # BaseException on purpose: surface any failure from the worker
        # thread in the server log instead of dying silently.
        print(f"Error in streaming_worker: {error}", file=sys.stderr)
    finally:
        worker_threads.delete_thread(thread_id)
194
+
195
+
152
196
class GGMLPipeline (object ):
153
197
def __init__ (self ,model_name ,** task ):
154
198
import ctransformers
@@ -185,7 +229,7 @@ def do_work():
185
229
self .q .put (x )
186
230
self .done = True
187
231
188
- thread = Thread (target = do_work )
232
+ thread = threading . Thread (target = do_work )
189
233
thread .start ()
190
234
191
235
def __iter__ (self ):
@@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
283
327
input ,add_generation_prompt = True ,tokenize = False
284
328
)
285
329
input = self .tokenizer (input ,return_tensors = "pt" ).to (self .model .device )
286
- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
330
+ generation_kwargs = dict (
331
+ input ,
332
+ worker_threads = worker_threads ,
333
+ model = self .model ,
334
+ streamer = streamer ,
335
+ ** kwargs ,
336
+ )
287
337
else :
288
338
streamer = TextIteratorStreamer (
289
339
self .tokenizer ,
@@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
292
342
input = self .tokenizer (input ,return_tensors = "pt" ,padding = True ).to (
293
343
self .model .device
294
344
)
295
- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
296
- thread = Thread (target = self .model .generate ,kwargs = generation_kwargs )
345
+ generation_kwargs = dict (
346
+ input ,
347
+ worker_threads = worker_threads ,
348
+ model = self .model ,
349
+ streamer = streamer ,
350
+ ** kwargs ,
351
+ )
352
+ # thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
353
+ thread = threading .Thread (target = streaming_worker ,kwargs = generation_kwargs )
297
354
thread .start ()
355
+ streamer .set_worker_thread_id (thread .native_id )
298
356
return streamer
299
357
300
358
def __call__ (self ,inputs ,** kwargs ):