Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbe11967

Browse files
committed
Working simple python thread metrics
1 parentf7401b8 commitbe11967

File tree

1 file changed

+63
-5
lines changed

1 file changed

+63
-5
lines changed

‎pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
importshutil
44
importtime
55
importqueue
6+
importsys
7+
importjson
68

79
importdatasets
810
fromInstructorEmbeddingimportINSTRUCTOR
@@ -40,7 +42,7 @@
4042
TrainingArguments,
4143
Trainer,
4244
)
43-
fromthreadingimportThread
45+
importthreading
4446

4547
__cache_transformer_by_model_id= {}
4648
__cache_sentence_transformer_by_name= {}
@@ -62,6 +64,26 @@
6264
}
6365

6466

67+
classWorkerThreads:
68+
def__init__(self):
69+
self.worker_threads= {}
70+
71+
defdelete_thread(self,id):
72+
delself.worker_threads[id]
73+
74+
defupdate_thread(self,id,value):
75+
self.worker_threads[id]=value
76+
77+
defget_thread(self,id):
78+
ifidinself.worker_threads:
79+
returnself.worker_threads[id]
80+
else:
81+
returnNone
82+
83+
84+
worker_threads=WorkerThreads()
85+
86+
6587
classPgMLException(Exception):
6688
pass
6789

@@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
105127
self.token_cache= []
106128
self.text_index_cache= []
107129

130+
defset_worker_thread_id(self,id):
131+
self.worker_thread_id=id
132+
133+
defget_worker_thread_id(self):
134+
returnself.worker_thread_id
135+
108136
defput(self,values):
109137
ifself.skip_promptandself.next_tokens_are_prompt:
110138
self.next_tokens_are_prompt=False
@@ -149,6 +177,22 @@ def __next__(self):
149177
returnvalue
150178

151179

180+
defstreaming_worker(worker_threads,model,**kwargs):
181+
thread_id=threading.get_native_id()
182+
try:
183+
worker_threads.update_thread(
184+
thread_id,json.dumps({"model":model.name_or_path})
185+
)
186+
except:
187+
worker_threads.update_thread(thread_id,"Error setting data")
188+
try:
189+
model.generate(**kwargs)
190+
exceptBaseExceptionaserror:
191+
print(f"Error in streaming_worker:{error}",file=sys.stderr)
192+
finally:
193+
worker_threads.delete_thread(thread_id)
194+
195+
152196
classGGMLPipeline(object):
153197
def__init__(self,model_name,**task):
154198
importctransformers
@@ -185,7 +229,7 @@ def do_work():
185229
self.q.put(x)
186230
self.done=True
187231

188-
thread=Thread(target=do_work)
232+
thread=threading.Thread(target=do_work)
189233
thread.start()
190234

191235
def__iter__(self):
@@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
283327
input,add_generation_prompt=True,tokenize=False
284328
)
285329
input=self.tokenizer(input,return_tensors="pt").to(self.model.device)
286-
generation_kwargs=dict(input,streamer=streamer,**kwargs)
330+
generation_kwargs=dict(
331+
input,
332+
worker_threads=worker_threads,
333+
model=self.model,
334+
streamer=streamer,
335+
**kwargs,
336+
)
287337
else:
288338
streamer=TextIteratorStreamer(
289339
self.tokenizer,
@@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
292342
input=self.tokenizer(input,return_tensors="pt",padding=True).to(
293343
self.model.device
294344
)
295-
generation_kwargs=dict(input,streamer=streamer,**kwargs)
296-
thread=Thread(target=self.model.generate,kwargs=generation_kwargs)
345+
generation_kwargs=dict(
346+
input,
347+
worker_threads=worker_threads,
348+
model=self.model,
349+
streamer=streamer,
350+
**kwargs,
351+
)
352+
# thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
353+
thread=threading.Thread(target=streaming_worker,kwargs=generation_kwargs)
297354
thread.start()
355+
streamer.set_worker_thread_id(thread.native_id)
298356
returnstreamer
299357

300358
def__call__(self,inputs,**kwargs):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp