Getting Started With stacks

In this article, we’ll be working through an example of the workflow of model stacking with the stacks package. At a high level, the workflow looks something like this:

  1. Define candidate ensemble members using functionality from rsample, parsnip, workflows, recipes, and tune
  2. Initialize a data_stack object with stacks()
  3. Iteratively add candidate ensemble members to the data_stack with add_candidates()
  4. Evaluate how to combine their predictions with blend_predictions()
  5. Fit candidate ensemble members with non-zero stacking coefficients with fit_members()
  6. Predict on new data with predict()!
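
In code, the pipeline mirrors that list closely. As a schematic sketch only: candidates_1, candidates_2, and new_data below are hypothetical stand-ins for the tuning results and test data we build for real later in this article.

# schematic sketch of the full stacking pipeline; `candidates_1`,
# `candidates_2`, and `new_data` are hypothetical placeholders
model_stack <-
  stacks() |>
  add_candidates(candidates_1) |>
  add_candidates(candidates_2) |>
  blend_predictions() |>
  fit_members()

predict(model_stack, new_data)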

The package is closely integrated with the rest of the functionality in tidymodels; we’ll load those packages as well, in addition to some tidyverse packages to evaluate our results later on.

library(tidymodels)
library(stacks)

In this example, we’ll make use of the tree_frogs data exported with stacks, giving experimental results on hatching behavior of red-eyed tree frog embryos!

Red-eyed tree frog (RETF) embryos can hatch earlier than their normal 7ish days if they detect potential predator threat. Researchers wanted to determine how, and when, these tree frog embryos were able to detect stimulus from their environment. To do so, they subjected the embryos at varying developmental stages to “predator stimulus” by jiggling the embryos with a blunt probe. Beforehand, though, some of the embryos were treated with gentamicin, a compound that knocks out their lateral line (a sensory organ). Researcher Julie Jung and her crew found that these factors inform whether an embryo hatches prematurely or not!

We’ll start out with predicting latency (i.e., time to hatch) based on other attributes. We’ll need to filter out NAs (i.e., cases where the embryo did not hatch) first.

data("tree_frogs")# subset the datatree_frogs<- tree_frogs|>filter(!is.na(latency))|>select(-c(clutch, hatched))

Taking a quick look at the data, it seems like the hatch time is pretty closely related to some of our predictors!

theme_set(theme_bw())

ggplot(tree_frogs) +
  aes(x = age, y = latency, color = treatment) +
  geom_point() +
  labs(x = "Embryo Age (s)", y = "Time to Hatch (s)", col = "Treatment")

Let’s give this a go!

Define candidate ensemble members

At the highest level, ensembles are formed from model definitions. In this package, model definitions are an instance of a minimal workflow, containing a model specification (as defined in the parsnip package) and, optionally, a preprocessor (as defined in the recipes package). Model definitions specify the form of candidate ensemble members.

Defining the constituent model definitions is undoubtedly the longest part of building an ensemble with stacks. If you’re familiar with tidymodels “proper,” you’re probably fine to skip this section, keeping a few things in mind: each model definition must be fitted on a common set of resamples, using control settings that save the assessment set predictions and the fitted workflow (as described below).

We’ll first start out with splitting up the training data, generating resamples, and setting some options that will be used by each model definition.

# some setup: resampling and a basic recipe
set.seed(1)
tree_frogs_split <- initial_split(tree_frogs)
tree_frogs_train <- training(tree_frogs_split)
tree_frogs_test <- testing(tree_frogs_split)

set.seed(1)
folds <- rsample::vfold_cv(tree_frogs_train, v = 5)

tree_frogs_rec <-
  recipe(latency ~ ., data = tree_frogs_train)

metric <- metric_set(rmse)

Tuning and fitting results for use in ensembles need to be fitted with the control arguments save_pred = TRUE and save_workflow = TRUE: these settings ensure that the assessment set predictions, as well as the workflow used to fit the resamples, are stored in the resulting object. For convenience, stacks supplies some control_stack_*() functions to generate the appropriate objects for you.
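
These helpers are, roughly speaking, shorthand for switching on those two arguments in the corresponding control objects from tune. A sketch of the equivalence (the name ctrl is arbitrary):

# approximately what control_stack_grid() gives you: a tune control
# object that retains assessment predictions and the fitted workflow
ctrl <- control_grid(save_pred = TRUE, save_workflow = TRUE)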

In this example, we’ll be working with tune_grid() and fit_resamples() from the tune package, so we will use the following control settings:

ctrl_grid <- control_stack_grid()
ctrl_res <- control_stack_resamples()

We’ll define three different model definitions to try to predict time to hatch: a K-nearest neighbors model (with hyperparameters to tune), a linear model, and a support vector machine model (again, with hyperparameters to tune).

Starting out with K-nearest neighbors, we begin by creating a parsnip model specification:

# create a model definition
knn_spec <-
  nearest_neighbor(
    mode = "regression",
    neighbors = tune("k")
  ) |>
  set_engine("kknn")

knn_spec

Note that, since we are tuning over several possible numbers of neighbors, this model specification defines multiple model configurations. The specific form of those configurations will be determined when specifying the grid search in tune_grid().
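
To build intuition for what those configurations look like, here is an optional sketch that constructs a regular grid of four values of k by hand. This is illustration only: tune_grid() with grid = 4 generates its own space-filling grid, so the exact values it tries will differ.

# hand-built illustration of a 4-value grid over `k`; tune_grid()
# will generate its own grid below
knn_spec |>
  extract_parameter_set_dials() |>
  grid_regular(levels = 4)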

From here, we extend the basic recipe defined earlier to fully specify the form of the design matrix for use in a K-nearest neighbors model:

# extend the recipe
knn_rec <-
  tree_frogs_rec |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |>
  step_impute_mean(all_numeric_predictors()) |>
  step_normalize(all_numeric_predictors())

knn_rec

Starting with the basic recipe, we convert categorical variables to dummy variables, remove zero-variance columns (those taking on only a single value), impute missing values in numeric variables using the mean, and normalize numeric predictors. Pre-processing instructions for the remaining models are defined similarly.

Now, we combine the model specification and pre-processing instructions defined above to form a workflow object:

# add both to a workflow
knn_wflow <-
  workflow() |>
  add_model(knn_spec) |>
  add_recipe(knn_rec)

knn_wflow

Finally, we can make use of the workflow, training set resamples, metric set, and control object to tune our hyperparameters. Using the grid argument, we specify that we would like to optimize over four possible values of k using a grid search.

# tune k and fit to the 5-fold cv
set.seed(2020)
knn_res <-
  tune_grid(
    knn_wflow,
    resamples = folds,
    metrics = metric,
    grid = 4,
    control = ctrl_grid
  )

knn_res

This knn_res object fully specifies the candidate members, and is ready to be included in a stacks workflow.
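
As an optional aside, the standard tune accessors can be used to peek at how the candidate configurations performed on the resamples before stacking:

# resampled performance for each candidate configuration
collect_metrics(knn_res)

# the best-performing values of `k` by RMSE
show_best(knn_res, metric = "rmse")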

Now, specifying the linear model, note that we are not optimizing over any hyperparameters. Thus, we use the fit_resamples() function rather than tune_grid() or tune_bayes() when fitting to our resamples.

# create a model definition
lin_reg_spec <-
  linear_reg() |>
  set_engine("lm")

# extend the recipe
lin_reg_rec <-
  tree_frogs_rec |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors())

# add both to a workflow
lin_reg_wflow <-
  workflow() |>
  add_model(lin_reg_spec) |>
  add_recipe(lin_reg_rec)

# fit to the 5-fold cv
set.seed(2020)
lin_reg_res <-
  fit_resamples(
    lin_reg_wflow,
    resamples = folds,
    metrics = metric,
    control = ctrl_res
  )

lin_reg_res

Finally, putting together the model definition for the support vector machine:

# create a model definition
svm_spec <-
  svm_rbf(
    cost = tune("cost"),
    rbf_sigma = tune("sigma")
  ) |>
  set_engine("kernlab") |>
  set_mode("regression")

# extend the recipe
svm_rec <-
  tree_frogs_rec |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |>
  step_impute_mean(all_numeric_predictors()) |>
  step_corr(all_predictors()) |>
  step_normalize(all_numeric_predictors())

# add both to a workflow
svm_wflow <-
  workflow() |>
  add_model(svm_spec) |>
  add_recipe(svm_rec)

# tune cost and sigma and fit to the 5-fold cv
set.seed(2020)
svm_res <-
  tune_grid(
    svm_wflow,
    resamples = folds,
    grid = 6,
    metrics = metric,
    control = ctrl_grid
  )

svm_res

Altogether, we’ve created three model definitions, where the K-nearest neighbors model definition specifies 4 model configurations, the linear regression specifies 1, and the support vector machine specifies 6.

With these three model definitions fully specified, we are ready to begin stacking these model configurations. (Note that, in most applied settings, one would likely specify many more than 11 candidate members.)

Putting together a stack

The first step to building an ensemble with stacks is to create a data_stack object: in this package, data stacks are tibbles (with some extra attributes) that contain the assessment set predictions for each candidate ensemble member.

We can initialize a data stack using the stacks() function.

stacks()

The stacks() function works sort of like the ggplot() constructor from ggplot2, in that the function creates a basic structure that the object will be built on top of, except that you’ll pipe the outputs rather than adding them with +.

The add_candidates() function adds ensemble members to the stack.

tree_frogs_data_st <-
  stacks() |>
  add_candidates(knn_res) |>
  add_candidates(lin_reg_res) |>
  add_candidates(svm_res)

tree_frogs_data_st

As mentioned before, under the hood, a data_stack object is really just a tibble with some extra attributes. Checking out the actual data:

as_tibble(tree_frogs_data_st)

The first column gives the true response values, and the remaining columns give the assessment set predictions for each candidate ensemble member. Since we’re in the regression case, there’s only one column per ensemble member. In classification settings, there are as many columns per candidate ensemble member as there are levels of the outcome variable.
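
To see that mapping directly, we can list the column names: the first is the outcome, latency, and the rest give one prediction column per candidate configuration.

# the outcome column followed by one prediction column per candidate
colnames(as_tibble(tree_frogs_data_st))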

That’s it! We’re now ready to evaluate how to combine the predictions from each candidate ensemble member.

Fit the stack

The outputs from each of these candidate ensemble members are highly correlated, so the blend_predictions() function performs regularization to figure out how we can combine the outputs from the stack members to come up with a final prediction.

tree_frogs_model_st <-
  tree_frogs_data_st |>
  blend_predictions()

The blend_predictions() function determines how member model output will ultimately be combined in the final prediction by fitting a LASSO model on the data stack, predicting the true assessment set outcome using the predictions from each of the candidate members. Candidates with nonzero stacking coefficients become members.
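
In notation, writing $\hat{f}_j(x_i)$ for candidate $j$’s assessment set prediction on observation $i$, the blending step solves a penalized least squares problem along these lines (a sketch; by default, stacks constrains the stacking coefficients to be non-negative, so the LASSO penalty reduces to a plain sum of coefficients):

$$\hat{\beta} = \underset{\beta_j \ge 0}{\arg\min} \; \sum_i \Big( y_i - \sum_j \beta_j \hat{f}_j(x_i) \Big)^2 + \lambda \sum_j \beta_j$$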

To make sure that we have the right trade-off between minimizing the number of members and optimizing performance, we can use the autoplot() method:

autoplot(tree_frogs_model_st)

To show the relationship more directly:

autoplot(tree_frogs_model_st, type = "members")

If these results were not good enough, blend_predictions() could be called again with different values of penalty. As it is, blend_predictions() picks the penalty parameter with the numerically optimal results. To see the top results:

autoplot(tree_frogs_model_st, type = "weights")
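
If you did want to re-blend, a sketch of supplying your own penalty grid looks like the following; the object name and penalty values here are arbitrary, for illustration only.

# re-blend over a hand-picked penalty grid; values are arbitrary
tree_frogs_model_st_custom <-
  tree_frogs_data_st |>
  blend_predictions(penalty = 10^seq(-2, -0.5, length.out = 20))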

Now that we know how to combine our model output, we can fit the candidates with non-zero stacking coefficients on the full training set.

tree_frogs_model_st <-
  tree_frogs_model_st |>
  fit_members()

Model stacks can be thought of as a group of fitted member models and a set of instructions on how to combine their predictions.

To identify which model configurations were assigned what stacking coefficients, we can make use of the collect_parameters() function:

collect_parameters(tree_frogs_model_st, "svm_res")

This object is now ready to predict with new data!

tree_frogs_test <-
  bind_cols(tree_frogs_test, predict(tree_frogs_model_st, tree_frogs_test))

Juxtaposing the predictions with the true data:

ggplot(tree_frogs_test) +
  aes(x = latency, y = .pred) +
  geom_point() +
  coord_obs_pred()

Looks like our predictions were pretty strong! How do the stack’s predictions perform, though, as compared to the members’ predictions? We can use the members = TRUE argument to generate predictions from each of the ensemble members.

member_preds <-
  tree_frogs_test |>
  select(latency) |>
  bind_cols(predict(tree_frogs_model_st, tree_frogs_test, members = TRUE))

Now, evaluating the root mean squared error from each model:

map(member_preds, rmse_vec, truth = member_preds$latency) |>
  as_tibble()

As we can see, the stacked ensemble outperforms each of the member models, though it is closely followed by one of its members.

Voila! You’ve now made use of the stacks package to predict red-eyed tree frog embryo hatching using a stacked ensemble! The full visual outline for these steps can be found here.

