Note
Go to the end to download the full example code.
Demo for training continuation
import os
import pickle
import tempfile

from sklearn.datasets import load_breast_cancer

import xgboost


def training_continuation(tmpdir: str, use_pickle: bool) -> None:
    """Basic training continuation.

    Trains 128 rounds in one go, then reproduces the same total round
    count in two sessions (32 rounds + 96 rounds) by checkpointing the
    intermediate model and resuming from it via ``xgb_model``.
    """
    # Baseline: train 128 iterations in a single session.
    X, y = load_breast_cancer(return_X_y=True)
    clf = xgboost.XGBClassifier(n_estimators=128, eval_metric="logloss")
    clf.fit(X, y, eval_set=[(X, y)])
    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())

    # Same 128 iterations split across 2 sessions: the first runs 32
    # iterations, the second runs the remaining 96.
    clf = xgboost.XGBClassifier(n_estimators=32, eval_metric="logloss")
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.get_booster().num_boosted_rounds() == 32

    # Persist and reload the partially trained model — this stands in for
    # a checkpoint. Either pickle the sklearn wrapper or use the native
    # JSON serialization.
    if use_pickle:
        path = os.path.join(tmpdir, "model-first-32.pkl")
        with open(path, "wb") as fd:
            pickle.dump(clf, fd)
        with open(path, "rb") as fd:
            loaded = pickle.load(fd)
    else:
        path = os.path.join(tmpdir, "model-first-32.json")
        clf.save_model(path)
        loaded = xgboost.XGBClassifier()
        loaded.load_model(path)

    # Resume training from the checkpoint for the remaining iterations.
    clf = xgboost.XGBClassifier(n_estimators=128 - 32, eval_metric="logloss")
    clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)

    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    # The two-session run ends up with the same total number of rounds.
    assert clf.get_booster().num_boosted_rounds() == 128


def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
    """Training continuation with early stopping.

    Runs one early-stopped session as a reference, then reproduces the
    same best iteration across two sessions (128 rounds, then continue
    until early stop), checkpointing between them.
    """
    early_stopping_rounds = 5
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    n_estimators = 512

    # Reference: one session with early stopping; remember where it stopped.
    X, y = load_breast_cancer(return_X_y=True)
    clf = xgboost.XGBClassifier(
        n_estimators=n_estimators, eval_metric="logloss", callbacks=[early_stop]
    )
    clf.fit(X, y, eval_set=[(X, y)])
    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    best = clf.best_iteration

    # Now split the 512 iterations across 2 sessions: the first runs 128
    # iterations, the second continues until early stop triggers.
    clf = xgboost.XGBClassifier(
        n_estimators=128, eval_metric="logloss", callbacks=[early_stop]
    )
    # Reinitialize the early stop callback so state from the reference
    # session doesn't carry over.
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    clf.set_params(callbacks=[early_stop])
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.get_booster().num_boosted_rounds() == 128

    # Persist and reload the partially trained model — this stands in for
    # a checkpoint.
    if use_pickle:
        path = os.path.join(tmpdir, "model-first-128.pkl")
        with open(path, "wb") as fd:
            pickle.dump(clf, fd)
        with open(path, "rb") as fd:
            loaded = pickle.load(fd)
    else:
        path = os.path.join(tmpdir, "model-first-128.json")
        clf.save_model(path)
        loaded = xgboost.XGBClassifier()
        loaded.load_model(path)

    # Fresh callback for the continuation session as well.
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    clf = xgboost.XGBClassifier(
        n_estimators=n_estimators - 128,
        eval_metric="logloss",
        callbacks=[early_stop],
    )
    clf.fit(
        X,
        y,
        eval_set=[(X, y)],
        xgb_model=loaded,
    )

    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    # The continued run should stop at the same best iteration as the
    # single-session reference.
    assert clf.best_iteration == best


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        training_continuation_early_stop(tmpdir, False)
        training_continuation_early_stop(tmpdir, True)
        training_continuation(tmpdir, True)
        training_continuation(tmpdir, False)