@@ -17,7 +17,8 @@ class Gemma3(BaseLM):
1717 """
1818
1919def __init__ (self ,model_name = "google/gemma-3-1b-it" ,temp = 0.1 ,device = 'cuda' ,use_bf16 = False ,
20- max_new_tokens = None ,api_token = None ,** kwargs ):
20+ max_new_tokens = 8192 ,api_token = None ,** kwargs ):
21+ assert (isinstance (max_new_tokens ,int )and max_new_tokens is not None )
2122super (Gemma3 ,self ).__init__ (name = model_name ,support_batching = True ,** kwargs )
2223self .__device = device
2324self .__model = Gemma3ForCausalLM .from_pretrained (
@@ -27,6 +28,20 @@ def __init__(self, model_name="google/gemma-3-1b-it", temp=0.1, device='cuda', u
         self.__tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
         self.__temp = temp
 
+    @staticmethod
+    def __handle_response(response, prompt):
+
+        # Crop the echoed prompt from the start of the decoded output.
+        if prompt not in response:
+            return response
+        response = response[response.index(prompt) + len(prompt):]
+
+        # Keep only the first "model" turn; guard against a missing separator.
+        response_turns = response.split("\nmodel\n")
+        if len(response_turns) < 2:
+            return response
+        return response_turns[1]
+
     def ask(self, batch):
 
         messages = [
@@ -39,16 +54,17 @@ def ask(self, batch):
 
         inputs = self.__tokenizer.apply_chat_template(
             messages,
-            add_generation_prompt=False,
+            add_generation_prompt=True,
             tokenize=True,
             return_dict=True,
-            return_tensors="pt",
-            padding=True,
+            return_tensors="pt",
+            padding=True,
             truncation=True)
         inputs.to(self.__device)
 
         with torch.inference_mode():
             outputs = self.__model.generate(**inputs, max_new_tokens=self.__max_new_tokens,
                                             temperature=self.__temp, do_sample=True)
-
-        return self.__tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        return [self.__handle_response(response=r, prompt=batch[i])
+                for i, r in enumerate(self.__tokenizer.batch_decode(outputs, skip_special_tokens=True))]
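
For reference, a minimal standalone sketch of the new response handling, assuming the decoded Gemma-3 transcript (with skip_special_tokens=True) keeps the plain-text turn markers, e.g. "user\n{prompt}\nmodel\n{reply}"; the module-level _handle_response helper below only mirrors the private static method for illustration:

# Hypothetical module-level mirror of Gemma3.__handle_response.
# Assumption: decoding with skip_special_tokens=True leaves the turn markers,
# so the transcript looks like "user\n{prompt}\nmodel\n{reply}".
def _handle_response(response: str, prompt: str) -> str:
    # Crop the echoed prompt from the decoded output.
    if prompt not in response:
        return response
    response = response[response.index(prompt) + len(prompt):]

    # Keep only the first "model" turn; fall back to the cropped text
    # when the separator is absent.
    response_turns = response.split("\nmodel\n")
    if len(response_turns) < 2:
        return response
    return response_turns[1]

decoded = "user\nWhat is the capital of France?\nmodel\nParis."
print(_handle_response(decoded, "What is the capital of France?"))  # -> Paris.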