@@ -1017,7 +1017,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
10171017logs ["step" ]= state .global_step
10181018logs ["max_steps" ]= state .max_steps
10191019logs ["timestamp" ]= str (datetime .now ())
1020- print_info (json .dumps (logs ))
1020+ print_info (json .dumps (logs , indent = 4 ))
10211021insert_logs (self .project_id ,self .model_id ,json .dumps (logs ))
10221022
10231023
@@ -1275,7 +1275,6 @@ def evaluate(self):
12751275
12761276if "eval_accuracy" in metrics .keys ():
12771277metrics ["accuracy" ]= metrics .pop ("eval_accuracy" )
1278-
12791278
12801279# Drop all the keys that are not floats or ints to be compatible for pgml-extension metrics typechecks
12811280metrics = {
@@ -1286,6 +1285,7 @@ def evaluate(self):
12861285
12871286return metrics
12881287
1288+
12891289class FineTuningTextPairClassification (FineTuningTextClassification ):
12901290def __init__ (
12911291self ,
@@ -1313,7 +1313,7 @@ def __init__(
13131313super ().__init__ (
13141314project_id ,model_id ,train_dataset ,test_dataset ,path ,hyperparameters
13151315 )
1316-
1316+
13171317def tokenize_function (self ,example ):
13181318"""
13191319 Tokenizes the input text using the tokenizer specified in the class.
@@ -1326,13 +1326,20 @@ def tokenize_function(self, example):
13261326
13271327 """
13281328if self .tokenizer_args :
1329- tokenized_example = self .tokenizer (example ["text1" ],example ["text2" ],** self .tokenizer_args )
1329+ tokenized_example = self .tokenizer (
1330+ example ["text1" ],example ["text2" ],** self .tokenizer_args
1331+ )
13301332else :
13311333tokenized_example = self .tokenizer (
1332- example ["text1" ],example ["text2" ],padding = True ,truncation = True ,return_tensors = "pt"
1334+ example ["text1" ],
1335+ example ["text2" ],
1336+ padding = True ,
1337+ truncation = True ,
1338+ return_tensors = "pt" ,
13331339 )
13341340return tokenized_example
13351341
1342+
13361343class FineTuningConversation (FineTuningBase ):
13371344def __init__ (
13381345self ,
@@ -1459,7 +1466,7 @@ def formatting_prompts_func(example):
14591466callbacks = [PGMLCallback (self .project_id ,self .model_id )],
14601467 )
14611468print_info ("Creating Supervised Fine Tuning trainer done. Training ... " )
1462-
1469+
14631470# Train
14641471self .trainer .train ()
14651472