Note

Go to the end to download the full example code.

Collection of examples for using the scikit-learn interface

For an introduction to XGBoost’s scikit-learn estimator interface, see Using the Scikit-Learn Estimator Interface.

Created on 1 Apr 2015

@author: Jamie Hall

# Examples of XGBoost's scikit-learn estimator interface: binary and
# multiclass classification, regression, grid-search parameter tuning,
# pickling fitted models, and early stopping.
import pickle
from urllib.error import HTTPError, URLError

import numpy as np
from sklearn.datasets import (
    fetch_california_housing,
    load_digits,
    load_iris,
    make_regression,
)
from sklearn.metrics import confusion_matrix, mean_squared_error
from sklearn.model_selection import GridSearchCV, KFold, train_test_split

import xgboost as xgb

# One seeded RandomState is shared by every KFold below, so the whole script
# is reproducible from a single seed.
rng = np.random.RandomState(31337)

print("Zeros and Ones from the Digits dataset: binary classification")
digits = load_digits(n_class=2)
y = digits["target"]
X = digits["data"]
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
    xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
    predictions = xgb_model.predict(X[test_index])
    actuals = y[test_index]
    print(confusion_matrix(actuals, predictions))

print("Iris: multiclass classification")
iris = load_iris()
y = iris["target"]
X = iris["data"]
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
    xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
    predictions = xgb_model.predict(X[test_index])
    actuals = y[test_index]
    print(confusion_matrix(actuals, predictions))

print("California Housing: regression")
try:
    X, y = fetch_california_housing(return_X_y=True)
except (HTTPError, URLError):
    # Use a synthetic dataset instead if we couldn't download the real one.
    # HTTPError covers server-side failures; URLError (its base class) also
    # covers being offline or DNS failures, which HTTPError alone misses.
    X, y = make_regression(n_samples=20640, n_features=8, random_state=1234)
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
    xgb_model = xgb.XGBRegressor(n_jobs=1).fit(X[train_index], y[train_index])
    predictions = xgb_model.predict(X[test_index])
    actuals = y[test_index]
    print(mean_squared_error(actuals, predictions))

print("Parameter optimization")
xgb_model = xgb.XGBRegressor(n_jobs=1)
# Exhaustive search over 2 x 2 = 4 parameter combinations, each scored with
# 3-fold cross-validation on the regression data loaded above.
clf = GridSearchCV(
    xgb_model,
    {"max_depth": [2, 4], "n_estimators": [50, 100]},
    verbose=1,
    n_jobs=1,
    cv=3,
)
clf.fit(X, y)
print(clf.best_score_)
print(clf.best_params_)

# The sklearn API models are picklable.
print("Pickling sklearn API models")
# Pickle needs binary mode. The ``with`` blocks guarantee the file is flushed
# and closed before it is reopened for reading. Only unpickle files you trust:
# here it is the file this script just wrote.
with open("best_calif.pkl", "wb") as fd:
    pickle.dump(clf, fd)
with open("best_calif.pkl", "rb") as fd:
    clf2 = pickle.load(fd)
print(np.allclose(clf.predict(X), clf2.predict(X)))

# Early-stopping: XGBoost monitors AUC on ``eval_set`` and stops adding trees
# once it has not improved for ``early_stopping_rounds`` (10) rounds.
X = digits["data"]
y = digits["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = xgb.XGBClassifier(n_jobs=1, early_stopping_rounds=10, eval_metric="auc")
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])

Gallery generated by Sphinx-Gallery