@@ -16,12 +16,11 @@ class Gemma3(BaseLM):
     - Total input context of 128K tokens for the 4B, 12B, and 27B sizes, and 32K tokens for the 1B size
     """

-    def __init__(self, model_name="google/gemma-3-1b-it", temp=0.1, device='cpu',
-                 max_new_tokens=None, api_token=None, use_bf16=False, **kwargs):
-        super(Gemma, self).__init__(name=model_name, support_batching=True, **kwargs)
+    def __init__(self, model_name="google/gemma-3-1b-it", temp=0.1, device='cuda',
+                 max_new_tokens=None, api_token=None, **kwargs):
+        super(Gemma3, self).__init__(name=model_name, support_batching=True, **kwargs)
         self.__device = device
-        self.__model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.bfloat16 if use_bf16 else "auto", token=api_token)
+        self.__model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", token=api_token)
         self.__max_new_tokens = max_new_tokens
         self.__model.to(device)
         self.__tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
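
For reference, a minimal usage sketch of the constructor after this change. The module path `models.gemma3` is an assumption (import `Gemma3` from wherever it lives in this repo); the point is that the default device is now `'cuda'`, so CPU-only hosts must pass `device='cpu'` explicitly:

```python
# Hypothetical usage sketch: the module path below is an assumption.
from models.gemma3 import Gemma3

# Default device is now 'cuda'; override it on CPU-only machines.
lm = Gemma3(model_name="google/gemma-3-1b-it", temp=0.1, device="cpu",
            max_new_tokens=256, api_token=None)
```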