grf 2.5.0

Introduction to local linear forests

Source: vignettes/llf.Rmd
library(grf)
library(glmnet)
library(ggplot2)

This document aims to show how to use local linear forests (LLF). We begin with the standard use case, walking through parameter choices and method details, and then discuss how to use local linear corrections with larger datasets.

Local Linear Forests: the basics

Random forests are a popular and powerful nonparametric regression method, but can suffer in the presence of strong, smooth effects. Local linear regression is a great method for fitting relatively smooth functions in low dimensions, but quickly deteriorates due to the curse of dimensionality: it relies on Euclidean distance, which loses its locality even in 4 or 5 dimensions. This algorithm leverages the strengths of each method (the data adaptivity of random forests and the smooth fits of local linear regression) to give improved predictions and confidence intervals. For a complete treatment of local linear forests (LLF), see our paper on ArXiv.

Consider a random forest with \(B\) trees predicting at a test point \(x_0\). In each tree \(b\), the test point falls into a leaf \(L_b(x_0)\). A regression forest predicts by averaging all responses in \(L_b(x_0)\), and then averaging those predictions \(\hat{\mu}_b(x_0)\) over all trees. To gain a new perspective on random forests, we can swap the order of summation and view the forest as a kernel or weighting method in high dimensions.

\[\begin{align*}
\hat{\mu}(x_0) &= \frac1B \sum_{b=1}^B \sum_{i=1}^n Y_i \frac{1\{x_i \in L_b(x_0)\}}{|L_b(x_0)|} \\
&= \sum_{i=1}^n Y_i \frac1B \sum_{b=1}^B \frac{1\{x_i \in L_b(x_0)\}}{|L_b(x_0)|} \\
&= \sum_{i=1}^n \alpha_i(x_0) Y_i,
\end{align*}\]

where the forest weight \(\alpha_i(x_0)\) measures how often observation \(i\) falls into the same leaf as the test point \(x_0\), normalized by leaf size:

\[\begin{equation}
\alpha_i(x_0) = \frac1B \sum_{b=1}^B \frac{1\{x_i \in L_b(x_0)\}}{|L_b(x_0)|}
\end{equation}\]
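To make this weighting view concrete, the short sketch below (not part of the original vignette; the simulated data and variable names are arbitrary) extracts the weights \(\alpha_i(x_0)\) with grf's get_forest_weights() and checks that the weighted average of the training responses closely matches the forest prediction at a test point.

# Illustrative sketch: forest weights as a kernel
n.small <- 500
p.small <- 10
X.small <- matrix(runif(n.small * p.small, -1, 1), nrow = n.small)
Y.small <- X.small[, 1] + rnorm(n.small)
small.forest <- regression_forest(X.small, Y.small)

x0 <- matrix(runif(p.small, -1, 1), nrow = 1)
alpha <- get_forest_weights(small.forest, x0)   # sparse 1 x n matrix of weights

as.numeric(alpha %*% Y.small)                   # weighted average of responses
predict(small.forest, x0)$predictions           # should closely match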

Local linear forests take this one step further: now, instead of using the weights to fit a local average at \(x_0\), we use them to fit a local linear regression, with a ridge penalty for regularization. This amounts to solving the minimization problem below, with parameters \(\mu(x)\) for the local average and \(\theta(x)\) for the slope of the local line.

\[\begin{equation}
\begin{pmatrix} \hat{\mu}(x_0) \\ \hat{\theta}(x_0) \end{pmatrix} =
\text{argmin}_{\mu,\theta} \left\{ \sum_{i=1}^n \alpha_i(x_0) \left( Y_i - \mu(x_0) - (x_i - x_0)\theta(x_0) \right)^2 + \lambda ||\theta(x_0)||_2^2 \right\}
\end{equation}\]
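To connect the formula to code, here is a rough sketch (again not from the vignette) that solves this weighted ridge problem directly at a single test point, reusing the small simulated example and weights from the previous sketch. The helper name local_linear_estimate is hypothetical and only illustrates the algebra; it is not grf's internal implementation, which handles honesty, tuning, and sparsity more carefully.

# Hypothetical helper: weighted ridge regression around a single test point x0,
# using the forest weights alpha as the local kernel. The intercept mu is not penalized.
local_linear_estimate <- function(X, Y, x0, alpha, lambda = 0.1) {
  Delta <- cbind(1, sweep(X, 2, x0))     # design: intercept and (x_i - x0)
  J <- diag(c(0, rep(1, ncol(X))))       # penalize the slopes theta, not mu
  M <- t(Delta) %*% (alpha * Delta)      # Delta' A Delta, with A = diag(alpha)
  coefs <- solve(M + lambda * J, t(Delta) %*% (alpha * Y))
  coefs[1]                               # estimated mu(x0)
}

local_linear_estimate(X.small, Y.small, as.numeric(x0),
                      as.numeric(get_forest_weights(small.forest, x0)))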

This enables us to (i) use local linear regression in high dimensions with a meaningful kernel, and (ii) predict with random forests even in the presence of smooth, strong signals.

Toy Example

We start with a simple example illustrating when local linear forests can improve on random forests.

p <- 20
n <- 1000
sigma <- sqrt(20)
mu <- function(x) { log(1 + exp(6 * x)) }

X <- matrix(runif(n * p, -1, 1), nrow = n)
Y <- mu(X[, 1]) + sigma * rnorm(n)
X.test <- matrix(runif(n * p, -1, 1), nrow = n)
ticks <- seq(-1, 1, length = n)
X.test[, 1] <- ticks
truth <- mu(ticks)

forest <- regression_forest(X, Y)
preds.forest <- predict(forest, X.test)$predictions

df <- data.frame(cbind(ticks, truth, preds.forest))
g1 <- ggplot(df, aes(ticks)) +
  geom_point(aes(y = preds.forest, color = "Regression Forest"), show.legend = F, size = 0.6) +
  geom_line(aes(y = truth)) +
  xlab("x") +
  ylab("y") +
  theme_bw()
g1

ll.forest <- ll_regression_forest(X, Y, enable.ll.split = TRUE)
preds.llf <- predict(ll.forest, X.test, linear.correction.variables = 1)$predictions

df.llf <- data.frame(cbind(ticks, truth, preds.llf))
g2 <- ggplot(df.llf, aes(ticks)) +
  geom_point(aes(y = preds.llf, color = "Local Linear Forest"), show.legend = F, size = 0.6) +
  geom_line(aes(y = truth)) +
  xlab("x") +
  ylab("y") +
  theme_bw()
g2

Parameters for LLF Prediction

There are several modifications to discuss. We begin by listing the subset of training parameters specific to local linear forests (users can also consider regression_forest tuning parameters, all of which apply here as well).

LLF Parameters.

| Parameters | Value Options | Default Value | Details |
|---|---|---|---|
| Training parameters | | | |
| enable.ll.split | TRUE/FALSE | FALSE | Optional choice to make forest splits based on ridge residuals as opposed to standard CART splits. |
| ll.split.weight.penalty | TRUE/FALSE | FALSE | If using local linear splits, the user can specify whether to standardize the ridge penalty by covariance (TRUE), or penalize all covariates equally (FALSE). |
| ll.split.lambda | Non-negative double | 0.1 | Ridge penalty for splitting. |
| ll.split.variables | Vector of covariate indexes | 1:p | Variables to use in split regressions. |
| ll.split.cutoff | Integer between 0 and n | Square root of n | If greater than 0, once leaves reach this size, the forest uses regression coefficients from the full dataset for ridge regressions during tree training. If equal to 0, trees run a regression in each leaf. |
| Prediction parameters | | | |
| ll.lambda | Non-negative double | Tuned by default | Ridge penalty for prediction. |
| ll.weight.penalty | TRUE/FALSE | FALSE | Standardize the ridge penalty by covariance (TRUE), or penalize all covariates equally (FALSE). |
| linear.correction.variables | Vector of covariate indexes | 1:p | Subset of indexes for variables to be used in local linear prediction. |
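As a quick orientation (this call is illustrative only and not part of the original vignette; the parameter values are arbitrary), the training parameters above are passed to ll_regression_forest() and the prediction parameters to predict():

# Arbitrary values, showing where each parameter from the table is supplied.
forest.example <- ll_regression_forest(X, Y,
                                       enable.ll.split = TRUE,
                                       ll.split.weight.penalty = FALSE,
                                       ll.split.lambda = 0.1,
                                       ll.split.variables = 1:p,
                                       ll.split.cutoff = 30)
preds.example <- predict(forest.example, X.test,
                         ll.lambda = 0.1,
                         ll.weight.penalty = FALSE,
                         linear.correction.variables = 1:5)$predictions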

Training the Algorithm

n <- 600
p <- 20
sigma <- sqrt(20)

mu <- function(x) {
  10 * sin(pi * x[1] * x[2]) + 20 * ((x[3] - 0.5)**2) + 10 * x[4] + 5 * x[5]
}

X <- matrix(runif(n * p, 0, 1), nrow = n)
Y <- apply(X, FUN = mu, MARGIN = 1) + sigma * rnorm(n)
X.test <- matrix(runif(n * p, 0, 1), nrow = n)
truth <- apply(X.test, FUN = mu, MARGIN = 1)

# regression forest predictions
rforest <- regression_forest(X, Y, honesty = TRUE)
results <- predict(rforest, X.test)
preds <- results$predictions
mean((preds - truth)**2)
#> [1] 8.651616

We can get LLF predictions either from a standard regression forest, by specifying linear correction variables at prediction time, or from an ll_regression_forest object. The parameter linear.correction.variables gives the variables to use for the final local regression step. This can simply be all variables, or a subset.

# llf predictions from regression_forest
results.llf <- predict(rforest, X.test, linear.correction.variables = 1:p)
preds.llf <- results.llf$predictions
mean((preds.llf - truth)**2)
#> [1] 5.792562

# llf predictions from ll_regression_forest
forest <- ll_regression_forest(X, Y, honesty = TRUE)
results.llf <- predict(forest, X.test)
preds.llf <- results.llf$predictions
mean((preds.llf - truth)**2)
#> [1] 5.78837

Weight Penalties

When we perform LLF predictions, we can either do a standard ridge regression (ll.weight.penalty set to FALSE), or scale by the covariance matrix (ll.weight.penalty set to TRUE): \(\hat{\beta}_\text{TRUE} = (X'AX (1 + \lambda))^{-1} X'AY\). This defaults to FALSE.
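Written out as plain matrix algebra (a rough sketch only, with Delta collecting the centered regressors \(x_i - x_0\), A the diagonal matrix of forest weights, and the intercept handling simplified away; this mirrors the formulas above rather than grf's internal code), the two options differ only in how \(\lambda\) enters:

# Illustration only: the two ridge variants on a weighted local design.
beta_unweighted <- function(Delta, A, Y, lambda) {   # ll.weight.penalty = FALSE
  solve(t(Delta) %*% A %*% Delta + lambda * diag(ncol(Delta)), t(Delta) %*% A %*% Y)
}
beta_weighted <- function(Delta, A, Y, lambda) {     # ll.weight.penalty = TRUE
  solve(t(Delta) %*% A %*% Delta * (1 + lambda), t(Delta) %*% A %*% Y)
}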

results.llf.unweighted <- predict(forest, X.test, ll.weight.penalty = FALSE)
preds.llf.unweighted <- results.llf.unweighted$predictions
mean((preds.llf.unweighted - truth)**2)
#> [1] 5.78837

results.llf.weighted <- predict(forest, X.test, ll.weight.penalty = TRUE)
preds.llf.weighted <- results.llf.weighted$predictions
mean((preds.llf.weighted - truth)**2)
#> [1] 5.736676

Residual Splits

We also consider the role of tree training in local linear forests. A standard CART split minimizes prediction error from predicting leaf-wide averages. Instead, we can use residual splits, which minimize the corresponding prediction errors on ridge regression residuals. We might expect this to help in cases where there are strong linear signals from some covariates; we won't waste any forest splits modelling those, but can still discover them in the final regression step. Essentially, this helps us make the most efficient possible splits in the forest, knowing that we have a local regression coming up to predict. This is currently an experimental feature.

forest <- ll_regression_forest(X, Y)
preds.cart.splits <- predict(forest, X.test)

ll.forest <- ll_regression_forest(X, Y, enable.ll.split = TRUE, ll.split.weight.penalty = TRUE)
preds.ll.splits <- predict(ll.forest, X.test)

mse.cart.splits <- mean((preds.cart.splits$predictions - truth)^2)
mse.ll.splits <- mean((preds.ll.splits$predictions - truth)^2)

mse.cart.splits
#> [1] 5.771045
mse.ll.splits
#> [1] 4.703251

To uncover exactly why this works, we can look at plots showing the split frequencies of both forests. In each plot, tiles represent how many splits at a given level of the tree (y-axis) were made on each feature (x-axis).

p <- 5
XX <- matrix(runif(n * p, 0, 1), nrow = n)
YY <- apply(XX, MARGIN = 1, FUN = mu) + sigma * rnorm(n)

forest2 <- regression_forest(XX, YY)
max.depth <- 4
freqs <- split_frequencies(forest2, max.depth = max.depth)
d <- data.frame(freqs)
dm <- data.frame(variable = sort(rep(names(d), nrow(d))),
                 value = as.vector(as.matrix(d)),
                 depth = rep(1:max.depth, p))

# normalize value by sum of value at depth
for (i in 1:max.depth) {
  tot.depth <- sum(dm[dm$depth == i, ]$value)
  dm[dm$depth == i, ]$value <- dm[dm$depth == i, ]$value / tot.depth
}

g <- ggplot(dm, aes(x = variable, y = -depth, fill = value)) +
  geom_tile() +
  xlab("Variable") +
  ylab("Depth") +
  scale_fill_gradient(low = "white", high = "blue", limits = c(0, 1), "Frequency \n") +
  ggtitle("") +
  theme(text = element_text(size = 15))
g

ll.forest2 <- ll_regression_forest(XX, YY, enable.ll.split = TRUE, ll.split.weight.penalty = TRUE)
freqs <- split_frequencies(ll.forest2, max.depth = max.depth)
d <- data.frame(freqs)
dm <- data.frame(variable = sort(rep(names(d), nrow(d))),
                 value = as.vector(as.matrix(d)),
                 depth = rep(1:max.depth, p))

for (i in 1:max.depth) {
  tot.depth <- sum(dm[dm$depth == i, ]$value)
  dm[dm$depth == i, ]$value <- dm[dm$depth == i, ]$value / tot.depth
}

g2 <- ggplot(dm, aes(x = variable, y = -depth, fill = value)) +
  geom_tile() +
  xlab("Variable") +
  ylab("Depth") +
  scale_fill_gradient(low = "white", high = "blue", limits = c(0, 1), "Frequency \n") +
  ggtitle("Split Frequencies: LLF") +
  theme(text = element_text(size = 15))
g2

LLF Prediction

Ridge parameter selection

The user can choose to specify a ridge regression parameter ll.lambda. When this variable is not set by the user, it will be selected by automatic parameter tuning. In general, we recommend letting the forest tune this parameter, or performing your own cross-validation loop. The exception to this would be for very large datasets.

results.llf.lambda <- predict(forest, X.test, ll.lambda = 0.1)
preds.llf.lambda <- results.llf.lambda$predictions
mean((preds.llf.lambda - truth)**2)
#> [1] 5.771045

results.llf.lambda <- predict(forest, X.test)  # automatic tuning
preds.llf.lambda <- results.llf.lambda$predictions
mean((preds.llf.lambda - truth)**2)
#> [1] 5.771045
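If you prefer the cross-validation route mentioned above, a minimal hand-rolled loop might look like the following (illustrative only: it uses a simple train/validation split and an arbitrary grid of candidate penalties, rather than the package's built-in tuning).

# Sketch of a manual validation loop for ll.lambda.
lambdas <- c(0.001, 0.01, 0.1, 1, 10)             # arbitrary candidate grid
train.idx <- sample(nrow(X), floor(nrow(X) / 2))
val.forest <- ll_regression_forest(X[train.idx, ], Y[train.idx])
val.mse <- sapply(lambdas, function(lam) {
  val.preds <- predict(val.forest, X[-train.idx, ], ll.lambda = lam)$predictions
  mean((val.preds - Y[-train.idx])^2)
})
lambdas[which.min(val.mse)]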

Linear Correction Variable Selection

Especially with many covariates, it is reasonable to restrict the local regression to only include a few features of interest. We recommend using the lasso.

# Train forest
forest <- ll_regression_forest(X, Y)

# Select covariates
lasso.mod <- cv.glmnet(X, Y, alpha = 1)
lasso.coef <- predict(lasso.mod, type = "nonzero")
selected <- lasso.coef[, 1]
selected
#> [1] 1 2 4 5

# Predict with all covariates
llf.all.preds <- predict(forest, X.test)
results.all <- llf.all.preds$predictions
mean((results.all - truth)**2)
#> [1] 5.800366

# Predict with just those covariates
llf.lasso.preds <- predict(forest, X.test, linear.correction.variables = selected)
results.llf.lasso <- llf.lasso.preds$predictions
mean((results.llf.lasso - truth)**2)
#> [1] 4.81568

Pointwise Confidence Intervals

Last, consider variance estimates and confidence intervals, which are analogous to grf variance estimates. We use our first data-generating process for easier visualization.

mu <- function(x) { log(1 + exp(6 * x)) }
p <- 20
X <- matrix(runif(n * p, -1, 1), nrow = n)
Y <- mu(X[, 1]) + sigma * rnorm(n)
X.test <- matrix(runif(n * p, -1, 1), nrow = n)
ticks <- seq(-1, 1, length = n)
X.test[, 1] <- ticks
truth <- mu(ticks)

# Select covariates
lasso.mod <- cv.glmnet(X, Y, alpha = 1)
lasso.coef <- predict(lasso.mod, type = "nonzero")
selected <- lasso.coef[, 1]
selected
#> [1] 1

ll.forest <- ll_regression_forest(X, Y, enable.ll.split = TRUE)
results.llf.var <- predict(ll.forest, X.test,
                           linear.correction.variables = selected,
                           estimate.variance = TRUE)
preds.llf.var <- results.llf.var$predictions
variance.estimates <- results.llf.var$variance.estimates

# find lower and upper bounds for 95% intervals
lower.llf <- preds.llf.var - 1.96 * sqrt(variance.estimates)
upper.llf <- preds.llf.var + 1.96 * sqrt(variance.estimates)

df <- data.frame(cbind(ticks, truth, preds.llf.var, lower.llf, upper.llf))
ggplot(df, aes(ticks)) +
  geom_point(aes(y = preds.llf.var, color = "Local Linear Forest"), show.legend = F, size = 0.6) +
  geom_line(aes(y = truth)) +
  geom_line(aes(y = lower.llf), color = "gray", lty = 2) +
  geom_line(aes(y = upper.llf), color = "gray", lty = 2) +
  xlab("x") +
  ylab("y") +
  theme_bw()

A Note on Larger Datasets

Although local linear forests are generally an improvement over regular random forests, when the number of dimensions is very high, training and predicting with them can take a long time. This is because with n_train training points and n_test test points, we are running n_test local regressions with n_train data points each. However, sometimes we still want to use random forests and correct for linear trends. In this case (datasets with roughly 100,000 or more observations, although this is always context-dependent), selecting a small number of linear correction variables is especially important. The current ridge parameter tuning will also take prohibitively long, so we recommend either setting the value to a constant such as 0.01, tuning it on a subset of the data, restricting the range of values considered, or cross-validating using a small number of shallow trees.

LLF Parameters and their function with large datasets.

| Parameters | Value Options | Default Value | Performance as n, p increase |
|---|---|---|---|
| Training parameters | | | |
| enable.ll.split | TRUE/FALSE | FALSE | For large n and p, ridge regressions in each leaf will take an especially long time. We therefore recommend either not using this feature, or setting the tuning parameters discussed below carefully. Please also note that this is still an experimental feature. |
| ll.split.weight.penalty | TRUE/FALSE | FALSE | Unaffected. |
| ll.split.lambda | Non-negative double | 0.1 | Unaffected. |
| ll.split.variables | Vector of covariate indexes | 1:p | For large p (over 50), we recommend either limiting ll.split.variables to a subset, using CART splits, or enforcing a reasonably large split cutoff (below). |
| ll.split.cutoff | Integer between 0 and n | Square root of n | Increasing this parameter can help speed up LL splits, and is recommended for users who want to use LL splits when training forests on large data. |
| Prediction parameters | | | |
| ll.lambda | Non-negative double | Tuned by default | Tuning by default takes a long time with large datasets; we recommend either writing your own, shorter cross-validation loop, or setting ll.lambda to a default value around 0.1 instead of using automatic tuning. |
| ll.weight.penalty | TRUE/FALSE | FALSE | Unaffected. |
| linear.correction.variables | Vector of covariate indexes | 1:p | Limiting linear correction variables is a crucial step for efficient predictions as n and p increase. We highly recommend using the lasso, stepwise regression, prior knowledge, etc. to select a fairly small number of linear correction variables for LLF prediction in this case. |

The following code is set to not run currently; users can expect it to take approximately 5-6 minutes. LLF predictions with all p variables will be very slow, as will automatic tuning with this scale of data. However, we can still use linear corrections, just with more careful parameters. Users can increase the number of trees for better performance from both methods.

# generate data
n <- 5e5
p <- 20
sigma <- sqrt(20)
f <- function(x) {
  10 * sin(pi * x[1] * x[2]) + 10 * (x[3] - 0.5)**2 + 10 * x[4] + 5 * x[5]
}

X <- matrix(runif(n * p, 0, 1), nrow = n, ncol = p)
Y <- apply(X, MARGIN = 1, FUN = f) + sigma * rnorm(n)
X.test <- matrix(runif(n * p, 0, 1), nrow = n, ncol = p)
truth.test <- apply(X.test, MARGIN = 1, FUN = f)

ptm <- proc.time()
forest <- regression_forest(X, Y, tune.parameters = "none", num.trees = 50)
time.train <- (proc.time() - ptm)[[3]]

ptm <- proc.time()
preds.grf <- predict(forest, X.test)$predictions
mse.grf <- mean((preds.grf - truth.test)**2)
time.grf <- (proc.time() - ptm)[[3]]

ptm <- proc.time()
ll.forest <- ll_regression_forest(X, Y, tune.parameters = "none", enable.ll.split = TRUE, num.trees = 50)
time.train.ll <- (proc.time() - ptm)[[3]]

ptm <- proc.time()
lasso.mod <- cv.glmnet(X, Y, alpha = 1)
lasso.coef <- predict(lasso.mod, type = "nonzero")
selected <- lasso.coef[, 1]
selected
time.lasso <- (proc.time() - ptm)[[3]]

ptm <- proc.time()
preds.llf <- predict(ll.forest, X.test, linear.correction.variables = selected, ll.lambda = 0.1)$predictions
mse.llf <- mean((preds.llf - truth.test)**2)
time.llf <- (proc.time() - ptm)[[3]]

cat(paste("GRF training took", round(time.train, 2), "seconds. \n",
          "GRF predictions took", round(time.grf, 2), "seconds. \n",
          "LLF training took", round(time.train.ll, 2), "seconds. \n",
          "Lasso selection took", round(time.lasso, 2), "seconds. \n",
          "LLF prediction took", round(time.llf, 2), "seconds. \n",
          "LLF and lasso all in all took", round(time.lasso + time.llf, 2), "seconds. \n"))
# GRF training took 342.14 seconds
# GRF predictions took 20.25 seconds
# LLF training took 215.64 seconds
# Lasso selection took 15.10 seconds
# LLF prediction took 45.40 seconds
# LLF and lasso all in all took 60.50 seconds

cat(paste("GRF predictions had MSE", round(mse.grf, 2), "\n",
          "LLF predictions had MSE", round(mse.llf, 2)))
# GRF predictions had MSE 0.81
# LLF predictions had MSE 0.69
