     PegasusTokenizer,
     TrainingArguments,
     Trainer,
-    GPTQConfig
+    GPTQConfig,
+    PegasusForConditionalGeneration,
+    PegasusTokenizer,
 )
 import threading
 
@@ -254,6 +256,9 @@ def __init__(self, model_name, **kwargs):
         if "use_auth_token" in kwargs:
             kwargs["token"] = kwargs.pop("use_auth_token")
 
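+        # Remember the checkpoint name so later branches can special-case it.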
+        self.model_name = model_name
+
         if (
             "task" in kwargs
             and model_name is not None
@@ -278,29 +283,61 @@ def __init__(self, model_name, **kwargs):
                     model_name, **kwargs
                 )
             elif self.task == "summarization" or self.task == "translation":
-                self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
+                if model_name == "google/pegasus-xsum":
+                    # HF auto model doesn't detect GPUs
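+                    # note that kwargs are not forwarded to this call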
+                    self.model = PegasusForConditionalGeneration.from_pretrained(
+                        model_name
+                    )
+                else:
+                    self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             elif self.task == "text-generation" or self.task == "conversational":
                 # See: https://huggingface.co/docs/transformers/main/quantization
                 if "quantization_config" in kwargs:
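+                    # the config arrives as a plain dict (e.g. {"bits": 4})
+                    # and is rebuilt into a GPTQConfig before loading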
                     quantization_config = kwargs.pop("quantization_config")
                     quantization_config = GPTQConfig(**quantization_config)
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, quantization_config=quantization_config, **kwargs
+                    )
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")
 
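+            # Drop any auth token for this public checkpoint so the
+            # pegasus-specific tokenizer branch below is taken.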
+            if model_name == "google/pegasus-xsum":
+                kwargs.pop("token", None)
+
             if "token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     model_name, token=kwargs["token"]
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+                if model_name == "google/pegasus-xsum":
+                    self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
+                else:
+                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            pipe_kwargs = {
+                "model": self.model,
+                "tokenizer": self.tokenizer,
+            }
+
+            # https://huggingface.co/docs/transformers/en/model_doc/pegasus
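+            # forward an explicit device for pegasus, defaulting to CPU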
+            if model_name == "google/pegasus-xsum":
+                pipe_kwargs["device"] = kwargs.get("device", "cpu")
 
             self.pipe = transformers.pipeline(
                 self.task,
-                model=self.model,
-                tokenizer=self.tokenizer,
+                **pipe_kwargs,
             )
         else:
             self.pipe = transformers.pipeline(**kwargs)
@@ -320,7 +357,7 @@ def stream(self, input, timeout=None, **kwargs):
                 self.tokenizer,
                 timeout=timeout,
                 skip_prompt=True,
-                skip_special_tokens=True
+                skip_special_tokens=True,
             )
             if "chat_template" in kwargs:
                 input = self.tokenizer.apply_chat_template(
@@ -343,9 +380,7 @@ def stream(self, input, timeout=None, **kwargs):
             )
         else:
             streamer = TextIteratorStreamer(
-                self.tokenizer,
-                timeout=timeout,
-                skip_special_tokens=True
+                self.tokenizer, timeout=timeout, skip_special_tokens=True
             )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
@@ -496,7 +531,6 @@ def embed(transformer, inputs, kwargs):
     return embed_using(model, transformer, inputs, kwargs)
 
 
-
 def clear_gpu_cache(memory_usage: None):
     if not torch.cuda.is_available():
         raise PgMLException(f"No GPU available")