4
4
import time
5
5
import queue
6
6
import sys
7
+ import json
7
8
8
9
import datasets
9
10
from InstructorEmbedding import INSTRUCTOR
41
42
TrainingArguments ,
42
43
Trainer ,
43
44
)
44
- from threading import Thread
45
+ import threading
45
46
from typing import Optional
46
47
47
48
# Module-level cache keyed by model id; presumably holds already-constructed
# transformer objects so repeated calls reuse them — confirm against callers.
__cache_transformer_by_model_id = {}
64
65
}
65
66
66
67
68
class WorkerThreads:
    """Registry of metadata for active streaming worker threads.

    Entries are keyed by native thread id and hold a small JSON string
    describing the worker (see ``streaming_worker``). Individual dict
    operations are atomic under the GIL, so concurrent update/delete from
    worker threads is safe without an explicit lock.
    """

    def __init__(self):
        # native thread id -> JSON metadata string for that worker
        self.worker_threads = {}

    def delete_thread(self, id):
        """Remove the entry for *id*; raises KeyError if it is absent."""
        del self.worker_threads[id]

    def update_thread(self, id, value):
        """Create or overwrite the entry for *id* with *value*."""
        self.worker_threads[id] = value

    def get_thread(self, id):
        """Return the entry for *id*, or None if no such thread is registered."""
        # dict.get replaces the original membership-test-then-index (EAFP
        # over LBYL: one lookup instead of two, identical result).
        return self.worker_threads.get(id)
83
+
84
+
85
# Module-level singleton: streaming worker threads register themselves here
# for the duration of a generation (see streaming_worker / stream).
worker_threads = WorkerThreads()
86
+
87
+
67
88
class PgMLException(Exception):
    """Exception type raised for PgML-specific errors."""

    pass
69
90
@@ -107,6 +128,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
107
128
self .token_cache = []
108
129
self .text_index_cache = []
109
130
131
    def set_worker_thread_id(self, id):
        # Record the native id of the thread driving generation for this
        # streamer (assigned by stream() right after the worker thread starts).
        self.worker_thread_id = id
133
+
134
    def get_worker_thread_id(self):
        # Returns the id stored by set_worker_thread_id; raises AttributeError
        # if called before a worker thread has been registered.
        return self.worker_thread_id
136
+
110
137
def put (self ,values ):
111
138
if self .skip_prompt and self .next_tokens_are_prompt :
112
139
self .next_tokens_are_prompt = False
@@ -151,6 +178,22 @@ def __next__(self):
151
178
return value
152
179
153
180
181
def streaming_worker(worker_threads, model, **kwargs):
    """Run ``model.generate`` on a worker thread, registering the thread in
    *worker_threads* for the duration of the generation.

    Intended as a ``threading.Thread`` target. ``kwargs`` are forwarded
    verbatim to ``model.generate`` (inputs, streamer, generation options).
    The registry entry is always removed in ``finally``, even on error.
    """
    thread_id = threading.get_native_id()
    try:
        worker_threads.update_thread(
            thread_id, json.dumps({"model": model.name_or_path})
        )
    except Exception:
        # Metadata is best-effort bookkeeping; narrowed from a bare `except:`
        # so KeyboardInterrupt/SystemExit are no longer swallowed here.
        worker_threads.update_thread(thread_id, "Error setting data")
    try:
        model.generate(**kwargs)
    except BaseException as error:
        # Deliberately broad: an exception on a worker thread would otherwise
        # vanish silently, so log everything to stderr.
        print(f"Error in streaming_worker: {error}", file=sys.stderr)
    finally:
        # update_thread above always registered this id, so deletion is safe.
        worker_threads.delete_thread(thread_id)
195
+
196
+
154
197
class GGMLPipeline (object ):
155
198
def __init__ (self ,model_name ,** task ):
156
199
import ctransformers
@@ -187,7 +230,7 @@ def do_work():
187
230
self .q .put (x )
188
231
self .done = True
189
232
190
- thread = Thread (target = do_work )
233
+ thread = threading . Thread (target = do_work )
191
234
thread .start ()
192
235
193
236
def __iter__ (self ):
@@ -285,7 +328,13 @@ def stream(self, input, timeout=None, **kwargs):
285
328
input ,add_generation_prompt = True ,tokenize = False
286
329
)
287
330
input = self .tokenizer (input ,return_tensors = "pt" ).to (self .model .device )
288
- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
331
+ generation_kwargs = dict (
332
+ input ,
333
+ worker_threads = worker_threads ,
334
+ model = self .model ,
335
+ streamer = streamer ,
336
+ ** kwargs ,
337
+ )
289
338
else :
290
339
streamer = TextIteratorStreamer (
291
340
self .tokenizer ,
@@ -294,9 +343,17 @@ def stream(self, input, timeout=None, **kwargs):
294
343
input = self .tokenizer (input ,return_tensors = "pt" ,padding = True ).to (
295
344
self .model .device
296
345
)
297
- generation_kwargs = dict (input ,streamer = streamer ,** kwargs )
298
- thread = Thread (target = self .model .generate ,kwargs = generation_kwargs )
346
+ generation_kwargs = dict (
347
+ input ,
348
+ worker_threads = worker_threads ,
349
+ model = self .model ,
350
+ streamer = streamer ,
351
+ ** kwargs ,
352
+ )
353
+ # thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
354
+ thread = threading .Thread (target = streaming_worker ,kwargs = generation_kwargs )
299
355
thread .start ()
356
+ streamer .set_worker_thread_id (thread .native_id )
300
357
return streamer
301
358
302
359
def __call__ (self ,inputs ,** kwargs ):