Note

Go to the end to download the full example code.

Demo for defining a custom regression objective and metric

Demo for defining a customized metric and objective. Notice that, for simplicity, weight is not used in the following example. In this script, we implement the Squared Log Error (SLE) objective and the RMSLE metric as customized functions, then compare them with the native implementation in XGBoost.

See Custom Objective and Evaluation Metric for a step-by-step walkthrough, with other details.

The SLE objective reduces the impact of outliers in the training dataset, hence here we also compare its performance with the standard squared error.

import argparse
from time import time
from typing import Dict, List, Tuple

import numpy as np
from matplotlib import pyplot as plt

import xgboost as xgb

# Shape of generated data.
kRows = 4096
kCols = 16

kOutlier = 10000  # exclusive upper bound of the random offset added to an outlier
kNumberOfOutliers = 64

kRatio = 0.7  # fraction of rows used for training; the rest is the test set
kSeed = 1994

kBoostRound = 20

# Lower bound applied to predictions before taking log1p, so that
# log1p(predt) is finite and (predt + 1) is strictly positive.
_kMinPredt = -1 + 1e-6

np.random.seed(seed=kSeed)


def generate_data() -> Tuple[xgb.DMatrix, xgb.DMatrix]:
    """Generate data containing outliers.

    Features and labels are drawn from a standard normal distribution; the
    labels are then shifted to be non-negative and ``kNumberOfOutliers``
    randomly chosen labels are inflated by a random integer offset.

    Returns
    -------
    (dtrain, dtest)
        The first ``kRatio`` portion of the rows as the training matrix and
        the remaining rows as the test matrix.
    """
    x = np.random.randn(kRows, kCols)
    y = np.random.randn(kRows)
    # Shift the labels so that the smallest one is exactly 0.
    y += np.abs(np.min(y))

    # Create outliers.  ``randint`` excludes its upper bound, so ``len(y)``
    # (not ``len(y) - 1``) is required for the last sample to be eligible.
    for _ in range(kNumberOfOutliers):
        ind = np.random.randint(0, len(y))
        y[ind] += np.random.randint(0, kOutlier)

    train_portion = int(kRows * kRatio)

    # rmsle requires all label be greater than -1.
    assert np.all(y > -1.0)

    train_x: np.ndarray = x[:train_portion]
    train_y: np.ndarray = y[:train_portion]
    dtrain = xgb.DMatrix(train_x, label=train_y)

    test_x = x[train_portion:]
    test_y = y[train_portion:]
    dtest = xgb.DMatrix(test_x, label=test_y)
    return dtrain, dtest


def native_rmse(
    dtrain: xgb.DMatrix, dtest: xgb.DMatrix
) -> Dict[str, Dict[str, List[float]]]:
    """Train using native implementation of Root Mean Squared Loss.

    Returns the evaluation history recorded by ``xgb.train`` for the
    ``'dtrain'`` and ``'dtest'`` evaluation sets.
    """
    print('Squared Error')
    squared_error = {
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',
        'tree_method': 'hist',
        'seed': kSeed,
    }
    start = time()
    results: Dict[str, Dict[str, List[float]]] = {}
    xgb.train(
        squared_error,
        dtrain=dtrain,
        num_boost_round=kBoostRound,
        evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
        evals_result=results,
    )
    print('Finished Squared Error in:', time() - start, '\n')
    return results


def native_rmsle(
    dtrain: xgb.DMatrix, dtest: xgb.DMatrix
) -> Dict[str, Dict[str, List[float]]]:
    """Train using native implementation of Squared Log Error.

    Returns the evaluation history recorded by ``xgb.train`` for the
    ``'dtrain'`` and ``'dtest'`` evaluation sets.
    """
    print('Squared Log Error')
    results: Dict[str, Dict[str, List[float]]] = {}
    squared_log_error = {
        'objective': 'reg:squaredlogerror',
        'eval_metric': 'rmsle',
        'tree_method': 'hist',
        'seed': kSeed,
    }
    start = time()
    xgb.train(
        squared_log_error,
        dtrain=dtrain,
        num_boost_round=kBoostRound,
        evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
        evals_result=results,
    )
    print('Finished Squared Log Error in:', time() - start)
    return results


def py_rmsle(dtrain: xgb.DMatrix, dtest: xgb.DMatrix) -> Dict:
    """Train using Python implementation of Squared Log Error.

    The objective and the metric are passed to ``xgb.train`` as Python
    callables; the default evaluation metric is disabled so that only the
    custom ``'PyRMSLE'`` metric is recorded.

    Returns the evaluation history for the ``'dtrain'`` and ``'dtest'``
    evaluation sets.
    """

    def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        """Compute the gradient squared log error."""
        y = dtrain.get_label()
        return (np.log1p(predt) - np.log1p(y)) / (predt + 1)

    def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        """Compute the hessian for squared log error."""
        y = dtrain.get_label()
        return (-np.log1p(predt) + np.log1p(y) + 1) / np.power(predt + 1, 2)

    def squared_log(
        predt: np.ndarray, dtrain: xgb.DMatrix
    ) -> Tuple[np.ndarray, np.ndarray]:
        r"""Squared Log Error objective.

        A simplified version for RMSLE used as objective function.

        :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`

        Returns the per-sample gradient and hessian of the loss with respect
        to the prediction.
        """
        # Clamp into the domain of log1p without modifying the caller's
        # array; this also covers predt == -1, where log1p is -inf.
        predt = np.maximum(predt, _kMinPredt)
        grad = gradient(predt, dtrain)
        hess = hessian(predt, dtrain)
        return grad, hess

    def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
        r"""Root mean squared log error metric.

        :math:`\sqrt{\frac{1}{N}\sum[log(pred + 1) - log(label + 1)]^2}`

        Returns the metric name ``'PyRMSLE'`` together with its value.
        """
        y = dtrain.get_label()
        # Same clamp as in the objective, again on a copy.
        predt = np.maximum(predt, _kMinPredt)
        elements = np.power(np.log1p(y) - np.log1p(predt), 2)
        return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

    results: Dict[str, Dict[str, List[float]]] = {}
    xgb.train(
        {'tree_method': 'hist', 'seed': kSeed, 'disable_default_eval_metric': 1},
        dtrain=dtrain,
        num_boost_round=kBoostRound,
        obj=squared_log,
        custom_metric=rmsle,
        evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
        evals_result=results,
    )

    return results


def plot_history(
    rmse_evals: Dict[str, Dict],
    rmsle_evals: Dict[str, Dict],
    py_rmsle_evals: Dict[str, Dict],
) -> None:
    """Plot train/test evaluation history of the three runs.

    One subplot per run: native RMSE, native RMSLE and the Python RMSLE.
    The figure is only drawn here; showing it is left to the caller.
    """
    fig, axs = plt.subplots(3, 1)
    # Narrow the type for static checkers; a 3x1 grid yields an array of axes.
    assert isinstance(axs, np.ndarray)
    ax0 = axs[0]
    ax1 = axs[1]
    ax2 = axs[2]

    x = np.arange(0, kBoostRound, 1)

    ax0.plot(x, rmse_evals['dtrain']['rmse'], label='train-RMSE')
    ax0.plot(x, rmse_evals['dtest']['rmse'], label='test-RMSE')
    ax0.legend()

    ax1.plot(x, rmsle_evals['dtrain']['rmsle'], label='train-native-RMSLE')
    ax1.plot(x, rmsle_evals['dtest']['rmsle'], label='test-native-RMSLE')
    ax1.legend()

    ax2.plot(x, py_rmsle_evals['dtrain']['PyRMSLE'], label='train-PyRMSLE')
    ax2.plot(x, py_rmsle_evals['dtest']['PyRMSLE'], label='test-PyRMSLE')
    ax2.legend()


def main(args: argparse.Namespace) -> None:
    """Run the three training variants and optionally plot their history.

    Plotting is skipped when ``args.plot`` equals 0.
    """
    dtrain, dtest = generate_data()
    rmse_evals = native_rmse(dtrain, dtest)
    rmsle_evals = native_rmsle(dtrain, dtest)
    py_rmsle_evals = py_rmsle(dtrain, dtest)

    if args.plot != 0:
        plot_history(rmse_evals, rmsle_evals, py_rmsle_evals)
        plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Arguments for custom RMSLE objective function demo.'
    )
    parser.add_argument(
        '--plot',
        type=int,
        default=1,
        help='Set to 0 to disable plotting the evaluation history.',
    )
    args = parser.parse_args()
    main(args)

Gallery generated by Sphinx-Gallery