Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 4908be4

Browse files
authored
Add support for google/pegasus-xsum (#1325)
1 parent c7494db commit 4908be4

File tree

1 file changed

+39
-12
lines changed

1 file changed

+39
-12
lines changed

‎pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
PegasusTokenizer,
4242
TrainingArguments,
4343
Trainer,
44-
GPTQConfig
44+
GPTQConfig,
45+
PegasusForConditionalGeneration,
46+
PegasusTokenizer,
4547
)
4648
importthreading
4749

@@ -254,6 +256,8 @@ def __init__(self, model_name, **kwargs):
254256
if"use_auth_token"inkwargs:
255257
kwargs["token"]=kwargs.pop("use_auth_token")
256258

259+
self.model_name=model_name
260+
257261
if (
258262
"task"inkwargs
259263
andmodel_nameisnotNone
@@ -278,29 +282,55 @@ def __init__(self, model_name, **kwargs):
278282
model_name,**kwargs
279283
)
280284
elifself.task=="summarization"orself.task=="translation":
281-
self.model=AutoModelForSeq2SeqLM.from_pretrained(model_name,**kwargs)
285+
ifmodel_name=="google/pegasus-xsum":
286+
# HF auto model doesn't detect GPUs
287+
self.model=PegasusForConditionalGeneration.from_pretrained(
288+
model_name
289+
)
290+
else:
291+
self.model=AutoModelForSeq2SeqLM.from_pretrained(
292+
model_name,**kwargs
293+
)
282294
elifself.task=="text-generation"orself.task=="conversational":
283295
# See: https://huggingface.co/docs/transformers/main/quantization
284296
if"quantization_config"inkwargs:
285297
quantization_config=kwargs.pop("quantization_config")
286298
quantization_config=GPTQConfig(**quantization_config)
287-
self.model=AutoModelForCausalLM.from_pretrained(model_name,quantization_config=quantization_config,**kwargs)
299+
self.model=AutoModelForCausalLM.from_pretrained(
300+
model_name,quantization_config=quantization_config,**kwargs
301+
)
288302
else:
289-
self.model=AutoModelForCausalLM.from_pretrained(model_name,**kwargs)
303+
self.model=AutoModelForCausalLM.from_pretrained(
304+
model_name,**kwargs
305+
)
290306
else:
291307
raisePgMLException(f"Unhandled task:{self.task}")
292308

309+
ifmodel_name=="google/pegasus-xsum":
310+
kwargs.pop("token",None)
311+
293312
if"token"inkwargs:
294313
self.tokenizer=AutoTokenizer.from_pretrained(
295314
model_name,token=kwargs["token"]
296315
)
297316
else:
298-
self.tokenizer=AutoTokenizer.from_pretrained(model_name)
317+
ifmodel_name=="google/pegasus-xsum":
318+
self.tokenizer=PegasusTokenizer.from_pretrained(model_name)
319+
else:
320+
self.tokenizer=AutoTokenizer.from_pretrained(model_name)
321+
322+
pipe_kwargs= {
323+
"model":self.model,
324+
"tokenizer":self.tokenizer,
325+
}
326+
327+
# https://huggingface.co/docs/transformers/en/model_doc/pegasus
328+
ifmodel_name=="google/pegasus-xsum":
329+
pipe_kwargs["device"]=kwargs.get("device","cpu")
299330

300331
self.pipe=transformers.pipeline(
301332
self.task,
302-
model=self.model,
303-
tokenizer=self.tokenizer,
333+
**pipe_kwargs,
304334
)
305335
else:
306336
self.pipe=transformers.pipeline(**kwargs)
@@ -320,7 +350,7 @@ def stream(self, input, timeout=None, **kwargs):
320350
self.tokenizer,
321351
timeout=timeout,
322352
skip_prompt=True,
323-
skip_special_tokens=True
353+
skip_special_tokens=True,
324354
)
325355
if"chat_template"inkwargs:
326356
input=self.tokenizer.apply_chat_template(
@@ -343,9 +373,7 @@ def stream(self, input, timeout=None, **kwargs):
343373
)
344374
else:
345375
streamer=TextIteratorStreamer(
346-
self.tokenizer,
347-
timeout=timeout,
348-
skip_special_tokens=True
376+
self.tokenizer,timeout=timeout,skip_special_tokens=True
349377
)
350378
input=self.tokenizer(input,return_tensors="pt",padding=True).to(
351379
self.model.device
@@ -496,7 +524,6 @@ def embed(transformer, inputs, kwargs):
496524
returnembed_using(model,transformer,inputs,kwargs)
497525

498526

499-
500527
defclear_gpu_cache(memory_usage:None):
501528
ifnottorch.cuda.is_available():
502529
raisePgMLException(f"No GPU available")

0 commit comments

Comments (0)

[8]ページ先頭

©2009-2025 Movatter.jp