Commit 84d92db

Enhanced gemma-3 output. Fixed problem with short output generation

1 parent f2e2ec1
1 file changed: +22 -6 lines changed

llm/transformers_gemma3.py

Lines changed: 22 additions & 6 deletions
@@ -17,7 +17,8 @@ class Gemma3(BaseLM):
     """
 
     def __init__(self, model_name="google/gemma-3-1b-it", temp=0.1, device='cuda', use_bf16=False,
-                 max_new_tokens=None, api_token=None, **kwargs):
+                 max_new_tokens=8192, api_token=None, **kwargs):
+        assert (isinstance(max_new_tokens, int) and max_new_tokens is not None)
         super(Gemma3, self).__init__(name=model_name, support_batching=True, **kwargs)
         self.__device = device
         self.__model = Gemma3ForCausalLM.from_pretrained(
@@ -27,6 +28,20 @@ def __init__(self, model_name="google/gemma-3-1b-it", temp=0.1, device='cuda', use_bf16=False,
         self.__tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
         self.__temp = temp
 
+    @staticmethod
+    def __handle_response(response, prompt):
+
+        # We attempt to crop the mentioned prompt.
+        if prompt not in response:
+            return response
+        response = response[response.index(prompt) + len(prompt):]
+
+        # We attempt to keep only the first response turn from the model.
+        response_turns = response.split("\nmodel\n")
+        if len(response_turns) <= 1:
+            return response
+        return response_turns[1]
+
     def ask(self, batch):
 
         messages = [
@@ -39,16 +54,17 @@ def ask(self, batch):
 
         inputs = self.__tokenizer.apply_chat_template(
             messages,
-            add_generation_prompt=False,
+            add_generation_prompt=True,
             tokenize=True,
             return_dict=True,
-            return_tensors="pt",
-            padding=True,
+            return_tensors="pt",
+            padding=True,
             truncation=True)
         inputs.to(self.__device)
 
         with torch.inference_mode():
             outputs = self.__model.generate(**inputs, max_new_tokens=self.__max_new_tokens,
                                             temperature=self.__temp, do_sample=True)
-
-        return self.__tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        return [self.__handle_response(response=r, prompt=batch[i])
+                for i, r in enumerate(self.__tokenizer.batch_decode(outputs, skip_special_tokens=True))]
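
For illustration, a minimal standalone sketch of the post-processing this commit adds. The module-level name handle_response and the sample transcript are hypothetical; in the commit the logic lives in the Gemma3.__handle_response static method above. It assumes the decoded generation still echoes the prompt and marks the reply with a "model" turn, which is the shape batch_decode yields here because generate returns the input tokens together with the newly generated ones.

# Hypothetical standalone version of Gemma3.__handle_response, for illustration.
def handle_response(response, prompt):
    # Crop everything up to and including the echoed prompt, if present.
    if prompt not in response:
        return response
    response = response[response.index(prompt) + len(prompt):]

    # Keep only the first "model" turn of the decoded chat transcript;
    # responses without a "model" marker are returned unchanged.
    response_turns = response.split("\nmodel\n")
    if len(response_turns) <= 1:
        return response
    return response_turns[1]

# Example on a hypothetical decoded transcript: prints "Paris."
decoded = "user\nWhat is the capital of France?\nmodel\nParis.\nmodel\n"
print(handle_response(decoded, "What is the capital of France?"))

Together with add_generation_prompt=True, the template now ends right where the model's turn begins, so the echoed prompt and the turn marker are exactly what this helper strips from each decoded batch entry.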

