@@ -990,7 +990,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
         logs["step"] = state.global_step
         logs["max_steps"] = state.max_steps
         logs["timestamp"] = str(datetime.now())
-        print_info(json.dumps(logs))
+        print_info(json.dumps(logs, indent=4))
         insert_logs(self.project_id, self.model_id, json.dumps(logs))


@@ -1248,7 +1248,6 @@ def evaluate(self):

         if "eval_accuracy" in metrics.keys():
             metrics["accuracy"] = metrics.pop("eval_accuracy")
-

         # Drop all the keys that are not floats or ints to be compatible for pgml-extension metrics typechecks
         metrics = {
@@ -1259,6 +1258,7 @@

         return metrics

+
 class FineTuningTextPairClassification(FineTuningTextClassification):
     def __init__(
         self,
@@ -1286,7 +1286,7 @@ def __init__(
         super().__init__(
             project_id, model_id, train_dataset, test_dataset, path, hyperparameters
         )
-        
+
     def tokenize_function(self, example):
         """
         Tokenizes the input text using the tokenizer specified in the class.
@@ -1299,13 +1299,20 @@ def tokenize_function(self, example):

         """
         if self.tokenizer_args:
-            tokenized_example = self.tokenizer(example["text1"], example["text2"], **self.tokenizer_args)
+            tokenized_example = self.tokenizer(
+                example["text1"], example["text2"], **self.tokenizer_args
+            )
         else:
             tokenized_example = self.tokenizer(
-                example["text1"], example["text2"], padding=True, truncation=True, return_tensors="pt"
+                example["text1"],
+                example["text2"],
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
             )
         return tokenized_example

+
 class FineTuningConversation(FineTuningBase):
     def __init__(
         self,
@@ -1432,7 +1439,7 @@ def formatting_prompts_func(example):
            callbacks=[PGMLCallback(self.project_id, self.model_id)],
         )
         print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
-        
+
         # Train
         self.trainer.train()
