Note

Go to the end to download the full example code.

Demo for using and defining callback functions

Added in version 1.3.0.

import argparse
import os
import tempfile
from typing import Dict

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import xgboost as xgb


class Plotting(xgb.callback.TrainingCallback):
    """Plot evaluation result during training.  Only for demonstration purpose as it's
    quite slow to draw using matplotlib.
    """

    def __init__(self, rounds: int) -> None:
        # One figure/axes pair reused for the whole training run.
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        # One Line2D per "<dataset>-<metric>" key, created lazily on the
        # first call to `after_iteration`.
        self.lines: Dict[str, plt.Line2D] = {}
        self.fig.show()
        # Fixed x-axis: one point per boosting round.
        self.x = np.linspace(0, self.rounds, self.rounds)
        plt.ion()

    def _get_key(self, data: str, metric: str) -> str:
        """Build the line-lookup key for a dataset/metric pair."""
        return f"{data}-{metric}"

    def after_iteration(
        self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
    ) -> bool:
        """Update the plot with the metrics recorded so far.

        Returns ``False`` so that training always continues.
        """
        if not self.lines:
            # First iteration: create one line per dataset/metric pair.
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    # Pad the history with zeros so its length matches the
                    # fixed x-axis (`self.rounds` points).
                    expanded = log + [0] * (self.rounds - len(log))
                    (self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
                    self.ax.legend()
        else:
            # Subsequent iterations: update the existing lines in place.
            # https://pythonspot.com/matplotlib-update-plot/
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key].set_ydata(expanded)
            self.fig.canvas.draw()
        # False to indicate training should not stop.
        return False


def custom_callback() -> None:
    """Demo for defining a custom callback function that plots evaluation result during
    training."""
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    num_boost_round = 100
    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            "objective": "binary:logistic",
            "eval_metric": ["error", "rmse"],
            "tree_method": "hist",
            "device": "cuda",
        },
        D_train,
        evals=[(D_train, "Train"), (D_valid, "Valid")],
        num_boost_round=num_boost_round,
        callbacks=[plotting],
    )


def check_point_callback() -> None:
    """Demo for using the checkpoint callback. Custom logic for handling output is
    usually required and users are encouraged to define their own callback for
    checkpointing operations. The builtin one can be used as a starting point.
    """
    # Only for demo, set a larger value (like 100) in practice as checkpointing is quite
    # slow.
    rounds = 2

    def check(as_pickle: bool) -> None:
        # Verify that a checkpoint file was written for every interval
        # (the round-0 checkpoint is never written, hence the `continue`).
        for i in range(0, 10, rounds):
            if i == 0:
                continue
            if as_pickle:
                path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
            else:
                path = os.path.join(
                    tmpdir,
                    f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
                )
            assert os.path.exists(path)

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)
    # Check point to a temporary directory for demo
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use callback class from xgboost.callback
        # Feel free to subclass/customize it to suit your need.
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(False)

        # This version of checkpoint saves everything including parameters and
        # model.  See: doc/tutorials/saving_model.rst
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, as_pickle=True, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--plot", default=1, type=int)
    args = parser.parse_args()
    check_point_callback()
    if args.plot:
        custom_callback()

Gallery generated by Sphinx-Gallery