 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from bulk_chain.core.llm_base import BaseLM
 
 
@@ -9,23 +9,31 @@ def __init__(self, model_name, temp=0.1, device='cpu',
                  max_new_tokens=None, token=None, use_bf16=False, **kwargs):
         super(Qwen2, self).__init__(name=model_name, **kwargs)
 
-        self.__device = device
         self.__max_new_tokens = max_new_tokens
-        self.__model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.bfloat16 if use_bf16 else "auto", token=token)
-        self.__model.to(device)
-        self.__tokenizer = AutoTokenizer.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_name, torch_dtype=torch.bfloat16 if use_bf16 else "auto", token=token)
+        model.to(device)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, torch_dtype=torch.bfloat16 if use_bf16 else "auto", token=token, padding_side="left")
+
         self.__temp = temp
 
-    def ask(self, prompt):
-        messages = [{"role": "user", "content": prompt}]
-        text = self.__tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = self.__tokenizer([text], return_tensors="pt")
-        inputs.to(self.__device)
-        outputs = self.__model.generate(**inputs, max_new_tokens=self.__max_new_tokens,
-                                        temperature=self.__temp, do_sample=True)
-        generated_ids = [
-            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, outputs)
-        ]
-        return self.__tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        self.__pipe = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+        )
+
+    def ask(self, batch):
+
+        messages = [[{"role": "user", "content": prompt}] for prompt in batch]
+
+        generation_args = {
+            "max_new_tokens": self.__max_new_tokens,
+            "return_full_text": False,
+            "temperature": self.__temp,
+            "do_sample": True,
+        }
+
+        output = self.__pipe(messages, **generation_args)
+        return [response[0]["generated_text"] for response in output]
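
For reference, a minimal usage sketch of the updated wrapper (not part of the diff). The checkpoint name and device below are placeholders; the constructor arguments simply mirror the signature shown above, and ask() now takes a list of prompts and returns one string per prompt:

    # Hypothetical usage of the batched Qwen2 wrapper from this diff.
    llm = Qwen2(model_name="Qwen/Qwen2-0.5B-Instruct",  # placeholder checkpoint
                device="cuda",                          # or "cpu"
                temp=0.1,
                max_new_tokens=128)
    # The pipeline-based ask() handles the whole batch in one call.
    answers = llm.ask(["What is bulk_chain?", "Name three prime numbers."])
    for answer in answers:
        print(answer)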