Gradient Boosted Decision Trees

  • Gradient boosting creates a strong predictive model by iteratively combining multiple weak models, typically decision trees.

  • In each iteration, a new weak model is trained to predict the errors of the current strong model, and then added to the strong model to improve its accuracy.

  • Shrinkage, similar to learning rate in neural networks, is used to control the learning speed and prevent overfitting by scaling the contribution of each weak model.

  • Gradient boosted trees are a specific implementation of gradient boosting that utilizes decision trees as the weak learners.

  • TensorFlow Decision Forests provides a practical implementation through tfdf.keras.GradientBoostedTreesModel, streamlining the model building process (see the sketch after this list).
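
As a minimal, hypothetical sketch (the DataFrame train_df, its "label" column, and the file path are assumptions for illustration, not from the course), training such a model might look like this:

import pandas as pd
import tensorflow_decision_forests as tfdf

# Hypothetical training data: a DataFrame with feature columns and a
# numerical label column named "label".
train_df = pd.read_csv("train.csv")

# Convert the DataFrame to a TensorFlow dataset and train the model.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(
    train_df, label="label", task=tfdf.keras.Task.REGRESSION)
model = tfdf.keras.GradientBoostedTreesModel(task=tfdf.keras.Task.REGRESSION)
model.fit(train_ds)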

Like bagging and boosting, gradient boosting is a methodology applied on top of another machine learning algorithm. Informally, gradient boosting involves two types of models:

  • a "weak" machine learning model, which is typically a decision tree.
  • a "strong" machine learning model, which is composed of multiple weakmodels.

In gradient boosting, at each step, a new weak model is trained to predict the "error" of the current strong model (which is called the pseudo response). We will detail "error" later. For now, assume "error" is the difference between the prediction and a regressive label. The weak model (that is, the "error") is then added to the strong model with a negative sign to reduce the error of the strong model.
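For example, under a squared-error loss (one common choice; other losses are handled by the general algorithm described later), this "error" is exactly the gradient of the loss with respect to the strong model's prediction:

\[L(y, F(x)) = \frac{1}{2}\left(F(x) - y\right)^2 \quad\Rightarrow\quad \frac{\partial L}{\partial F(x)} = F(x) - y\]

This gradient is what each weak model learns to approximate, which is where the name gradient boosting comes from.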

Gradient boosting is iterative. Each iteration invokes the following formula:

\[F_{i+1} = F_i - f_i\]

where:

  • $F_i$ is the strong model at step $i$.
  • $f_i$ is the weak model at step $i$.

This operation repeats until a stopping criterion is met, such as a maximum number of iterations or if the (strong) model begins to overfit as measured on a separate validation dataset.
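Because the strong model is typically initialized to a constant (in the example below, $F_0 = 0$), unrolling this recursion shows that the strong model after $n$ iterations is simply the negated sum of the weak models:

\[F_n = F_0 - \sum_{i=0}^{n-1} f_i\]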

Let's illustrate gradient boosting on a simple regression dataset where:

  • The objective is to predict $y$ from $x$.
  • The strong model is initialized to be a zero constant: $F_0(x) = 0$.
Note: The following code is for educational purposes only. In practice, you will simply call tfdf.keras.GradientBoostedTreesModel.
# Simplified example of regressive gradient boosting.
import numpy as np
import tensorflow_decision_forests as tfdf

y = ...  # the labels
x = ...  # the features

num_iters = 10  # number of boosting iterations
strong_model = []
strong_predictions = np.zeros_like(y)  # Initially, the strong model is empty.

for i in range(num_iters):

  # Error of the strong model
  error = strong_predictions - y

  # The weak model is a decision tree (see CART chapter)
  # without pruning and a maximum depth of 3.
  weak_model = tfdf.keras.CartModel(
      task=tfdf.keras.Task.REGRESSION,
      validation_ratio=0.0,
      max_depth=3)
  weak_model.fit(x=x, y=error)
  strong_model.append(weak_model)

  weak_predictions = weak_model.predict(x)[:, 0]

  strong_predictions -= weak_predictions
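
The loop above trains the weak models and only tracks the strong model's predictions on the training features. As a small follow-up sketch (an assumed helper, not part of the original example), predictions on new data follow the sign convention of the update formula:

def predict_strong_model(strong_model, x_new):
  # F(x) = -(f_0(x) + f_1(x) + ...), because each weak model's prediction
  # was subtracted from the strong model during training.
  weak_predictions = [m.predict(x_new)[:, 0] for m in strong_model]
  return -np.sum(weak_predictions, axis=0)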

Let's apply this code on the following dataset:

A plot of ground truth for one feature, x, and its label, y. The plot is a series of somewhat damped sine waves.

Figure 25. A synthetic regressive dataset with one numerical feature.
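
The exact data-generating function is not given. As a purely hypothetical stand-in for experimentation, a similar damped-sine dataset could be built as follows:

import numpy as np

# Hypothetical approximation of the dataset in Figure 25: a damped sine wave.
x = np.linspace(0.0, 10.0, num=500)
y = np.sin(5.0 * x) * np.exp(-0.3 * x)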

 

Here are three plots after the first iteration of the gradient boosting algorithm:

Three plots. The first plot shows the prediction of the strong model, which is a straight line of slope 0 and y-intercept 0. The second plot shows the error of the strong model, which is a series of sine waves. The third plot shows the prediction of the weak model, which is a set of square waves.

Figure 26. Three plots after the first iteration.

 

Note the following about the plots in Figure 26:

  • The first plot shows the predictions of the strong model, which is currently always 0.
  • The second plot shows the error, which is the label of the weak model.
  • The third plot shows the weak model.

The first weak model learns a coarse representation of the label and mostly focuses on the left part of the feature space (the part with the most variation, and therefore the most error for the constant strong model).

Following are the same plots for another iteration of the algorithm:

Three plots. The first plot shows the prediction of the strong model, which is an inverse of the plot of the prediction of the weak model from the previous figure. The second plot shows the error of the strong model, which is a noisy set of sine waves. The third plot shows the prediction of the weak model, which is a couple of square waves.

Figure 27. Three plots after the second iteration.

 

Note the following about the plots in Figure 27:

  • The strong model now contains the prediction of the weak model of the previous iteration.
  • The new error of the strong model is a bit smaller.
  • The new prediction of the weak model now focuses on the right part of the feature space.

We run the algorithm for 8 more iterations:

The plots show the strong model gradually becoming closer to ground truth while the prediction of the weak model gradually becomes, well, weaker.

Figure 28. Three plots after the third iteration and the tenth iteration.

 

In Figure 28, note that the prediction of the strong model starts to resemble the plot of the dataset.

These figures illustrate the gradient boosting algorithm using decision trees as weak learners. This combination is called gradient boosted (decision) trees.

The preceding plots suggest the essence of gradient boosting. However, this example lacks the following two real-world operations:

  • The shrinkage
  • The optimization of leaf values with one step of Newton's method
Note: In practice, there are multiple variants of the gradient boosting algorithm with other operations.

Shrinkage

The weak model $f_i$ is multiplied by a small value $\nu$ (for example, $\nu = 0.1$) before being added to the strong model $F_i$. This small value is called the shrinkage. In other words, instead of each iteration using the following formula:

\[F_{i+1} = F_i - f_i\]

Each iteration uses the following formula:

\[F_{i+1} = F_i - \nu f_i\]

Shrinkage in gradient boosting is analogous to learning rate in neural networks. Shrinkage controls how fast the strong model is learning, which helps limit overfitting. That is, a shrinkage value closer to 0.0 reduces overfitting more than a shrinkage value closer to 1.0.

In our code above, the shrinkage would be implemented as follows:

shrinkage = 0.1   # 0.1 is a common shrinkage value.
strong_predictions -= shrinkage * weak_predictions
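
If shrinkage is applied during training as above, the prediction sketch from earlier would scale each weak model's contribution by the same factor (again an assumed helper, not part of the original example):

def predict_strong_model(strong_model, x_new, shrinkage=0.1):
  # Each weak model's contribution is scaled by the shrinkage value used
  # during training before being subtracted from the strong model.
  weak_predictions = [m.predict(x_new)[:, 0] for m in strong_model]
  return -shrinkage * np.sum(weak_predictions, axis=0)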
