Note

Go to the end to download the full example code.

Getting started with XGBoost

This is a simple example of using the native XGBoost interface. There are other interfaces in the Python package, such as the scikit-learn interface and the Dask interface.

See Python Package Introduction and XGBoost Tutorials for other references.

"""Train a small XGBoost model on the agaricus demo data, then save and reload it.

Demonstrates the native ``xgb.train`` interface, text model dumps, DMatrix
binary serialization, and round-tripping a Booster both through
``save_model`` / ``Booster(model_file=...)`` and through ``pickle``.
"""

import os
import pickle

import numpy as np
from sklearn.datasets import load_svmlight_file

import xgboost as xgb

# Make sure the demo knows where to load the data.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
XGBOOST_ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
DEMO_DIR = os.path.join(XGBOOST_ROOT_DIR, "demo")

# X is a scipy CSR matrix; DMatrix also accepts other input types
# (see the XGBoost documentation for the full list).
X, y = load_svmlight_file(os.path.join(DEMO_DIR, "data", "agaricus.txt.train"))
dtrain = xgb.DMatrix(X, y)
# validation set
X_test, y_test = load_svmlight_file(os.path.join(DEMO_DIR, "data", "agaricus.txt.test"))
dtest = xgb.DMatrix(X_test, y_test)

# Specify parameters via a dict; parameter names are the same as in the C++ version.
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}

# Evaluation sets whose metrics are reported while training.
watchlist = [(dtest, "eval"), (dtrain, "train")]
# number of boosting rounds
num_round = 2
bst = xgb.train(param, dtrain, num_boost_round=num_round, evals=watchlist)

# run prediction
preds = bst.predict(dtest)
labels = dtest.get_label()
# Error rate: threshold the predictions at 0.5 and count mismatches against
# the labels. Vectorized with NumPy instead of a Python index loop.
error = np.mean((preds > 0.5).astype(int) != labels)
print("error=%f" % error)

bst.save_model("model-0.json")
# dump model
bst.dump_model("dump.raw.txt")
# dump model with feature map
bst.dump_model("dump.nice.txt", os.path.join(DEMO_DIR, "data", "featmap.txt"))

# save dmatrix into binary buffer
dtest.save_binary("dtest.dmatrix")
# save model
bst.save_model("model-1.json")
# load model and data in
bst2 = xgb.Booster(model_file="model-1.json")
dtest2 = xgb.DMatrix("dtest.dmatrix")
preds2 = bst2.predict(dtest2)
# Assert the reloaded model reproduces the original predictions exactly.
# np.testing is used instead of a bare `assert` so the check survives `python -O`.
np.testing.assert_array_equal(preds2, preds)

# alternatively, you can pickle the booster
# NOTE: only unpickle data you trust -- pickle.loads can execute arbitrary code.
pks = pickle.dumps(bst2)
# load model and data in
bst3 = pickle.loads(pks)
preds3 = bst3.predict(dtest2)
# Assert the unpickled model also reproduces the original predictions.
np.testing.assert_array_equal(preds3, preds)

Gallery generated by Sphinx-Gallery