
Commit 9e46629

Update transformers_qwen2.py to pipelines API

1 parent: b4b65e3
1 file changed: llm/transformers_qwen2.py (25 additions, 17 deletions)
@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from bulk_chain.core.llm_base import BaseLM
 
 
@@ -9,23 +9,31 @@ def __init__(self, model_name, temp=0.1, device='cpu',
                  max_new_tokens=None, token=None, use_bf16=False, **kwargs):
         super(Qwen2, self).__init__(name=model_name, **kwargs)
 
-        self.__device = device
         self.__max_new_tokens = max_new_tokens
-        self.__model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.bfloat16 if use_bf16 else "auto", token=token)
-        self.__model.to(device)
-        self.__tokenizer = AutoTokenizer.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_name, torch_dtype=torch.bfloat16 if use_bf16 else "auto", token=token)
+        model.to(device)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, torch_dtype=torch.bfloat16 if use_bf16 else "auto", token=token, padding_side="left")
+
         self.__temp = temp
 
-    def ask(self, prompt):
-        messages = [{"role": "user", "content": prompt}]
-        text = self.__tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = self.__tokenizer([text], return_tensors="pt")
-        inputs.to(self.__device)
-        outputs = self.__model.generate(**inputs, max_new_tokens=self.__max_new_tokens,
-                                        temperature=self.__temp, do_sample=True)
-        generated_ids = [
-            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, outputs)
-        ]
-        return self.__tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        self.__pipe = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+        )
+
+    def ask(self, batch):
+
+        messages = [[{"role": "user", "content": prompt}] for prompt in batch]
+
+        generation_args = {
+            "max_new_tokens": self.__max_new_tokens,
+            "return_full_text": False,
+            "temperature": self.__temp,
+            "do_sample": True,
+        }
+
+        output = self.__pipe(messages, **generation_args)
+        return [response[0]["generated_text"] for response in output]
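
For reference, a minimal usage sketch of the updated batched interface. It assumes Qwen2 is importable from llm/transformers_qwen2.py; the checkpoint name and argument values below are illustrative, not taken from the commit:

# Minimal usage sketch. Assumptions: Qwen2 is importable from
# llm/transformers_qwen2.py, and the checkpoint name is illustrative.
from llm.transformers_qwen2 import Qwen2

llm = Qwen2(model_name="Qwen/Qwen2-0.5B-Instruct",
            temp=0.1,
            device="cpu",
            max_new_tokens=128)

# ask() now takes a batch of prompts and returns one completion per
# prompt, generated through the transformers text-generation pipeline.
responses = llm.ask(["What is the capital of France?",
                     "Write a haiku about autumn."])
for text in responses:
    print(text)

Two details of the diff worth noting: the tokenizer is created with padding_side="left", the standard choice for batched generation with decoder-only models, since right padding would place pad tokens between a prompt and its continuation; and return_full_text=False makes the pipeline return only the newly generated text rather than the prompt plus completion.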
