|
41 | 41 |     PegasusTokenizer,

42 | 42 |     TrainingArguments,

43 | 43 |     Trainer,

| 44 | +    GPTQConfig |
44 | 45 | )
|
45 | 46 | import threading
|
46 | 47 |
|
@@ -279,7 +280,13 @@ def __init__(self, model_name, **kwargs):
|
279 | 280 |         elif self.task == "summarization" or self.task == "translation":

280 | 281 |             self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)

281 | 282 |         elif self.task == "text-generation" or self.task == "conversational":

282 |
| -            self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) |
| 283 | +            # See: https://huggingface.co/docs/transformers/main/quantization |
| 284 | +            if "quantization_config" in kwargs: |
| 285 | +                quantization_config = kwargs.pop("quantization_config") |
| 286 | +                quantization_config = GPTQConfig(**quantization_config) |
| 287 | +                self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs) |
| 288 | +            else: |
| 289 | +                self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) |
283 | 290 |         else:

284 | 291 |             raise PgMLException(f"Unhandled task: {self.task}")
|
285 | 292 |
|
@@ -409,10 +416,13 @@ def create_pipeline(task):
|
409 | 416 |     else:

410 | 417 |         try:

411 | 418 |             pipe = StandardPipeline(model_name, **task)

412 |
| -        except TypeError: |
413 |
| -            # some models fail when given "device" kwargs, remove and try again |
414 |
| -            task.pop("device") |
415 |
| -            pipe = StandardPipeline(model_name, **task) |
| 419 | +        except TypeError as error: |
| 420 | +            if "device" in task: |
| 421 | +                # some models fail when given "device" kwargs, remove and try again |
| 422 | +                task.pop("device") |
| 423 | +                pipe = StandardPipeline(model_name, **task) |
| 424 | +            else: |
| 425 | +                raise error |
416 | 426 |     return pipe
|
417 | 427 |
|
418 | 428 |
|
|