Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add support for google/pegasus-xsum#1325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
levkk merged 2 commits intomasterfromlevkk-pegasus
Feb 22, 2024
Merged
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletionspgml-extension/src/bindings/transformers/transformers.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -41,7 +41,9 @@
PegasusTokenizer,
TrainingArguments,
Trainer,
GPTQConfig
GPTQConfig,
PegasusForConditionalGeneration,
PegasusTokenizer,
)
import threading

Expand DownExpand Up@@ -254,6 +256,8 @@ def __init__(self, model_name, **kwargs):
if "use_auth_token" in kwargs:
kwargs["token"] = kwargs.pop("use_auth_token")

self.model_name = model_name

if (
"task" in kwargs
and model_name is not None
Expand All@@ -278,29 +282,55 @@ def __init__(self, model_name, **kwargs):
model_name, **kwargs
)
elif self.task == "summarization" or self.task == "translation":
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
if model_name == "google/pegasus-xsum":
# HF auto model doesn't detect GPUs
self.model = PegasusForConditionalGeneration.from_pretrained(
model_name
)
else:
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name, **kwargs
)
elif self.task == "text-generation" or self.task == "conversational":
# See: https://huggingface.co/docs/transformers/main/quantization
if "quantization_config" in kwargs:
quantization_config = kwargs.pop("quantization_config")
quantization_config = GPTQConfig(**quantization_config)
self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, quantization_config=quantization_config, **kwargs
)
else:
self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, **kwargs
)
else:
raise PgMLException(f"Unhandled task: {self.task}")

if model_name == "google/pegasus-xsum":
kwargs.pop("token", None)

if "token" in kwargs:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, token=kwargs["token"]
)
else:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if model_name == "google/pegasus-xsum":
self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
else:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)

pipe_kwargs = {
"model": self.model,
"tokenizer": self.tokenizer,
}

# https://huggingface.co/docs/transformers/en/model_doc/pegasus
if model_name == "google/pegasus-xsum":
pipe_kwargs["device"] = kwargs.get("device", "cpu")

self.pipe = transformers.pipeline(
self.task,
model=self.model,
tokenizer=self.tokenizer,
**pipe_kwargs,
)
else:
self.pipe = transformers.pipeline(**kwargs)
Expand All@@ -320,7 +350,7 @@ def stream(self, input, timeout=None, **kwargs):
self.tokenizer,
timeout=timeout,
skip_prompt=True,
skip_special_tokens=True
skip_special_tokens=True,
)
if "chat_template" in kwargs:
input = self.tokenizer.apply_chat_template(
Expand All@@ -343,9 +373,7 @@ def stream(self, input, timeout=None, **kwargs):
)
else:
streamer = TextIteratorStreamer(
self.tokenizer,
timeout=timeout,
skip_special_tokens=True
self.tokenizer, timeout=timeout, skip_special_tokens=True
)
input = self.tokenizer(input, return_tensors="pt", padding=True).to(
self.model.device
Expand DownExpand Up@@ -496,7 +524,6 @@ def embed(transformer, inputs, kwargs):
return embed_using(model, transformer, inputs, kwargs)



def clear_gpu_cache(memory_usage: None):
if not torch.cuda.is_available():
raise PgMLException(f"No GPU available")
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp