Note
Go to the end to download the full example code.
Quantile Regression
Added in version 2.0.0.
The script is inspired by this awesome example in sklearn: https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html
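For reference, the `reg:quantileerror` objective used below minimizes the quantile (pinball) loss for each requested level $\alpha$:

$$
\ell_\alpha(y, \hat{y}) =
\begin{cases}
\alpha \, (y - \hat{y}) & \text{if } y \ge \hat{y}, \\
(1 - \alpha) \, (\hat{y} - y) & \text{otherwise.}
\end{cases}
$$

The asymmetric penalty on over- and under-prediction is what pushes the fitted value toward the $\alpha$-quantile of $y$ given $x$.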
Note
The feature is only supported by the Python, R, and C packages. In addition, quantile crossing (a lower predicted quantile exceeding a higher one, e.g. the 0.05 estimate landing above the 0.95 estimate) can happen due to a limitation in the algorithm.
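Because nothing constrains the per-quantile outputs to be ordered, crossings can occasionally appear in the predictions. A simple post-hoc workaround, sketched below (this is not an XGBoost feature), is the monotone rearrangement: sort each row of predictions along the quantile axis.

import numpy as np


def rearrange_quantiles(scores: np.ndarray) -> np.ndarray:
    """Remove quantile crossing by sorting along the quantile axis.

    ``scores`` has shape (n_samples, n_quantiles) with columns ordered by
    increasing alpha, matching the output of ``booster.inplace_predict``
    in the demo below. Sorting each row guarantees monotone quantiles.
    """
    return np.sort(scores, axis=1)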
import argparse
from typing import Dict

import numpy as np
from sklearn.model_selection import train_test_split

import xgboost as xgb


def f(x: np.ndarray) -> np.ndarray:
    """The function to predict."""
    return x * np.sin(x)


def quantile_loss(args: argparse.Namespace) -> None:
    """Train a quantile regression model."""
    rng = np.random.RandomState(1994)
    # Generate a synthetic dataset for the demo; the generating process is from the
    # sklearn example.
    X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T
    expected_y = f(X).ravel()
    sigma = 0.5 + X.ravel() / 10.0
    noise = rng.lognormal(sigma=sigma) - np.exp(sigma**2.0 / 2.0)
    y = expected_y + noise

    # Train on the 0.05, 0.5, and 0.95 quantiles. The model is similar to multi-class
    # and multi-target models.
    alpha = np.array([0.05, 0.5, 0.95])
    evals_result: Dict[str, Dict] = {}

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
    # We will be using the `hist` tree method; quantile DMatrix can be used to preserve
    # memory (which has nothing to do with quantile regression itself, see its document
    # for details).
    # Do not use the `exact` tree method for quantile regression, otherwise the
    # performance might drop.
    Xy = xgb.QuantileDMatrix(X_train, y_train)
    # use Xy as a reference
    Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy)

    booster = xgb.train(
        {
            # Use the quantile objective function.
            "objective": "reg:quantileerror",
            "tree_method": "hist",
            "quantile_alpha": alpha,
            # Let's try not to overfit.
            "learning_rate": 0.04,
            "max_depth": 5,
        },
        Xy,
        num_boost_round=32,
        early_stopping_rounds=2,
        # The evaluation result is a weighted average across multiple quantiles.
        evals=[(Xy, "Train"), (Xy_test, "Test")],
        evals_result=evals_result,
    )
    xx = np.atleast_2d(np.linspace(0, 10, 1000)).T
    scores = booster.inplace_predict(xx)
    # dim 1 is the quantiles
    assert scores.shape[0] == xx.shape[0]
    assert scores.shape[1] == alpha.shape[0]

    y_lower = scores[:, 0]  # alpha=0.05
    y_med = scores[:, 1]  # alpha=0.5, median
    y_upper = scores[:, 2]  # alpha=0.95

    # Train an MSE model for comparison
    booster = xgb.train(
        {
            "objective": "reg:squarederror",
            "tree_method": "hist",
            # Let's try not to overfit.
            "learning_rate": 0.04,
            "max_depth": 5,
        },
        Xy,
        num_boost_round=32,
        early_stopping_rounds=2,
        evals=[(Xy, "Train"), (Xy_test, "Test")],
        evals_result=evals_result,
    )
    xx = np.atleast_2d(np.linspace(0, 10, 1000)).T
    y_pred = booster.inplace_predict(xx)

    if args.plot:
        from matplotlib import pyplot as plt

        fig = plt.figure(figsize=(10, 10))
        plt.plot(xx, f(xx), "g:", linewidth=3, label=r"$f(x) = x\,\sin(x)$")
        plt.plot(X_test, y_test, "b.", markersize=10, label="Test observations")
        plt.plot(xx, y_med, "r-", label="Predicted median")
        plt.plot(xx, y_pred, "m-", label="Predicted mean")
        plt.plot(xx, y_upper, "k-")
        plt.plot(xx, y_lower, "k-")
        plt.fill_between(
            xx.ravel(), y_lower, y_upper, alpha=0.4, label="Predicted 90% interval"
        )
        plt.xlabel("$x$")
        plt.ylabel("$f(x)$")
        plt.ylim(-10, 25)
        plt.legend(loc="upper left")
        plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plot",
        action="store_true",
        help="Specify it to enable plotting the outputs.",
    )
    args = parser.parse_args()
    quantile_loss(args)
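To quantify how well each level is estimated, score each column of the predictions with the pinball loss. A minimal sketch using scikit-learn's `mean_pinball_loss`, assuming the quantile `booster`, `X_test`, `y_test`, and `alpha` from the script above (the script later reuses the name `booster` for the MSE model, so run this right after the quantile model is trained):

from sklearn.metrics import mean_pinball_loss

# Predict all three quantiles on the held-out split: shape (n_samples, len(alpha)).
scores_test = booster.inplace_predict(X_test)
for i, a in enumerate(alpha):
    # Score the i-th column against its own quantile level.
    loss = mean_pinball_loss(y_test, scores_test[:, i], alpha=a)
    print(f"alpha={a:.2f} pinball loss={loss:.4f}")

Smaller is better; scoring a constant prediction at the empirical quantile of `y_train` gives a quick sanity baseline. To reproduce the figure, run the script with the `--plot` flag.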