+import torch
+from transformers import pipeline
+from bulk_chain.core.llm_base import BaseLM
+
+
+class Llama32(BaseLM):
+    """ This code has been tested under transformers==4.47.0.
+    This is an experimental version of the LLaMA-3.2 wrapper
+    that supports batching mode.
+    """
+
+    def __init__(self, model_name, api_token=None, temp=0.1, device='cpu',
+                 max_new_tokens=32768, use_bf16=False, **kwargs):
+        super(Llama32, self).__init__(name=model_name, support_batching=True, **kwargs)
+        self.__max_new_tokens = max_new_tokens
+        # Build the text-generation pipeline once; pad_token_id=128001 is the
+        # LLaMA-3 <|end_of_text|> token, used here as the padding token for batching.
+        self.__pipe = pipeline("text-generation",
+                               model=model_name,
+                               torch_dtype=torch.bfloat16 if use_bf16 else "auto",
+                               device_map=device,
+                               temperature=temp,
+                               pad_token_id=128001,
+                               token=api_token)
+
+    def ask(self, batch):
+        # Wrap every prompt into its own single-message chat, so the pipeline
+        # treats the batch as independent conversations rather than one dialogue.
+        chats = [[{"role": "user", "content": p}] for p in batch]
+        outputs = self.__pipe(chats,
+                              max_new_tokens=self.__max_new_tokens,
+                              batch_size=len(chats))
+        # Each output holds the completed chat; the last message is the assistant reply.
+        return [out[0]["generated_text"][-1]["content"] for out in outputs]
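
For reference, a minimal usage sketch of the adapter above. The checkpoint name, token, and prompts are illustrative placeholders only; the sketch relies solely on the constructor and `ask()` defined in the patch.

```python
# Hypothetical example: checkpoint, token, and prompts are placeholders.
llm = Llama32(model_name="meta-llama/Llama-3.2-3B-Instruct",
              api_token="hf_...",   # Hugging Face token, only needed for gated checkpoints
              device="cpu",
              max_new_tokens=128)

# ask() takes a batch of prompts and returns one reply per prompt.
replies = llm.ask(["What is 2 + 2?", "Name the capital of France."])
for reply in replies:
    print(reply)
```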