Note
Go to the end to download the full example code.
Demo for using and defining callback functions
Added in version 1.3.0.
import argparse
import os
import tempfile
from typing import Dict

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import xgboost as xgb


class Plotting(xgb.callback.TrainingCallback):
    """Live-plot evaluation metrics while a model trains.

    For demonstration only: redrawing the figure with matplotlib on every
    boosting round is quite slow.
    """

    def __init__(self, rounds: int) -> None:
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        # One Line2D per "<dataset>-<metric>" series, created lazily on the
        # first callback invocation.
        self.lines: Dict[str, plt.Line2D] = {}
        self.fig.show()
        self.x = np.linspace(0, self.rounds, self.rounds)
        plt.ion()

    def _get_key(self, data: str, metric: str) -> str:
        return f"{data}-{metric}"

    def after_iteration(
        self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
    ) -> bool:
        """Refresh the plot with the metrics accumulated so far."""
        first_call = not self.lines
        for dataset_name, metrics in evals_log.items():
            for metric_name, history in metrics.items():
                key = self._get_key(dataset_name, metric_name)
                # Pad the recorded history with zeros so the y-series always
                # matches the fixed-length x-axis.
                padded = history + [0] * (self.rounds - len(history))
                if first_call:
                    (self.lines[key],) = self.ax.plot(self.x, padded, label=key)
                    self.ax.legend()
                else:
                    # https://pythonspot.com/matplotlib-update-plot/
                    self.lines[key].set_ydata(padded)
        if not first_call:
            self.fig.canvas.draw()
        # Returning False tells xgboost that training should continue.
        return False


def custom_callback() -> None:
    """Demo for defining a custom callback function that plots evaluation
    results during training."""
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    num_boost_round = 100
    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            "objective": "binary:logistic",
            "eval_metric": ["error", "rmse"],
            "tree_method": "hist",
            "device": "cuda",
        },
        D_train,
        evals=[(D_train, "Train"), (D_valid, "Valid")],
        num_boost_round=num_boost_round,
        callbacks=[plotting],
    )


def check_point_callback() -> None:
    """Demo for using the checkpoint callback.

    Custom logic for handling output is usually required and users are
    encouraged to define their own callback for checkpointing operations.  The
    builtin one can be used as a starting point.
    """
    # Only for demo, set a larger value (like 100) in practice as checkpointing
    # is quite slow.
    rounds = 2

    def check(as_pickle: bool) -> None:
        # Verify that a checkpoint file exists for every interval boundary
        # (the round-0 model is never written).
        for i in range(0, 10, rounds):
            if i == 0:
                continue
            if as_pickle:
                path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
            else:
                path = os.path.join(
                    tmpdir,
                    f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
                )
            assert os.path.exists(path)

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)

    # Check point to a temporary directory for demo
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use callback class from xgboost.callback
        # Feel free to subclass/customize it to suit your need.
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(False)

        # This version of checkpoint saves everything including parameters and
        # model. See: doc/tutorials/saving_model.rst
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, as_pickle=True, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--plot", default=1, type=int)
    args = parser.parse_args()

    check_point_callback()

    if args.plot:
        custom_callback()