Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Working simple python thread metrics #1239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
SilasMarvin merged 1 commit into master from silas-python-thread-metrics
Dec 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 63 additions & 5 deletions pgml-extension/src/bindings/transformers/transformers.py
View file
Open in desktop
Original file line number Diff line number Diff line change
Expand Up@@ -3,6 +3,8 @@
import shutil
import time
import queue
import sys
import json

import datasets
from InstructorEmbedding import INSTRUCTOR
Expand DownExpand Up@@ -40,7 +42,7 @@
TrainingArguments,
Trainer,
)
from threadingimportThread
importthreading

__cache_transformer_by_model_id = {}
__cache_sentence_transformer_by_name = {}
Expand All@@ -62,6 +64,26 @@
}


class WorkerThreads:
    """Registry mapping a thread id to a caller-supplied value.

    Entries are written and removed by ``streaming_worker`` from inside the
    worker thread, so lookups must tolerate an entry disappearing at any
    moment.
    """

    def __init__(self):
        # Maps thread id -> whatever value was passed to update_thread().
        self.worker_threads = {}

    def delete_thread(self, id):
        """Remove the entry for ``id``.

        Raises:
            KeyError: if ``id`` is not registered.
        """
        del self.worker_threads[id]

    def update_thread(self, id, value):
        """Insert or overwrite the entry for ``id`` with ``value``."""
        self.worker_threads[id] = value

    def get_thread(self, id):
        """Return the value stored for ``id``, or ``None`` if there is none."""
        # A single dict.get() instead of "check membership, then index":
        # a worker thread may delete its entry between those two steps,
        # which would raise KeyError instead of returning None.
        return self.worker_threads.get(id)


# Module-level registry instance used by the rest of this file.
worker_threads = WorkerThreads()


class PgMLException(Exception):
    """Exception type declared by this module.

    NOTE(review): the raise sites are not visible in this part of the file.
    """

Expand DownExpand Up@@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
self.token_cache = []
self.text_index_cache = []

def set_worker_thread_id(self, id):
    """Store ``id`` on this object as ``worker_thread_id``."""
    self.worker_thread_id = id

def get_worker_thread_id(self):
    """Return the ``worker_thread_id`` attribute stored on this object."""
    return self.worker_thread_id

def put(self, values):
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
Expand DownExpand Up@@ -149,6 +177,22 @@ def __next__(self):
return value


def streaming_worker(worker_threads, model, **kwargs):
    """Run ``model.generate(**kwargs)`` while registered in ``worker_threads``.

    The calling thread's native id is used as the registry key. The stored
    value is a JSON string ``{"model": <model.name_or_path>}``, or the string
    ``"Error setting data"`` if building or storing that value fails.

    Args:
        worker_threads: object providing ``update_thread(id, value)`` and
            ``delete_thread(id)``.
        model: object providing ``name_or_path`` and ``generate(**kwargs)``.
        **kwargs: forwarded unchanged to ``model.generate``.

    Any error raised by ``model.generate`` is printed to stderr and not
    re-raised; the registry entry is removed afterwards in every case.
    """
    thread_id = threading.get_native_id()
    try:
        worker_threads.update_thread(
            thread_id, json.dumps({"model": model.name_or_path})
        )
    except Exception:
        # Registration is best-effort metadata and must not stop generation.
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not silently swallowed here.
        worker_threads.update_thread(thread_id, "Error setting data")
    try:
        model.generate(**kwargs)
    except BaseException as error:
        # Deliberately broad: report whatever stopped generation.
        print(f"Error in streaming_worker: {error}", file=sys.stderr)
    finally:
        worker_threads.delete_thread(thread_id)


class GGMLPipeline(object):
def __init__(self, model_name, **task):
import ctransformers
Expand DownExpand Up@@ -185,7 +229,7 @@ def do_work():
self.q.put(x)
self.done = True

thread = Thread(target=do_work)
thread =threading.Thread(target=do_work)
thread.start()

def __iter__(self):
Expand DownExpand Up@@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
input, add_generation_prompt=True, tokenize=False
)
input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
generation_kwargs = dict(input, streamer=streamer, **kwargs)
generation_kwargs = dict(
input,
worker_threads=worker_threads,
model=self.model,
streamer=streamer,
**kwargs,
)
else:
streamer = TextIteratorStreamer(
self.tokenizer,
Expand All@@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
input = self.tokenizer(input, return_tensors="pt", padding=True).to(
self.model.device
)
generation_kwargs = dict(input, streamer=streamer, **kwargs)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
generation_kwargs = dict(
input,
worker_threads=worker_threads,
model=self.model,
streamer=streamer,
**kwargs,
)
# thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread = threading.Thread(target=streaming_worker, kwargs=generation_kwargs)
thread.start()
streamer.set_worker_thread_id(thread.native_id)
return streamer

def __call__(self, inputs, **kwargs):
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp