| Title: | Tidy Tuning Tools |
| Version: | 2.0.1 |
| Description: | The ability to tune models is important. 'tune' contains functions and classes to be used in conjunction with other 'tidymodels' packages for finding reasonable values of hyper-parameters in models, pre-processing methods, and post-processing steps. |
| License: | MIT + file LICENSE |
| URL: | https://tune.tidymodels.org/, https://github.com/tidymodels/tune |
| BugReports: | https://github.com/tidymodels/tune/issues |
| Depends: | R (≥ 4.1) |
| Imports: | cli (≥ 3.3.0), dials (≥ 1.3.0.9000), dplyr (≥ 1.1.0), generics (≥ 0.1.2), ggplot2, glue (≥ 1.6.2), GPfit, hardhat (≥ 1.4.2), parallel, parsnip (≥ 1.2.1.9003), purrr (≥ 1.0.0), recipes (≥ 1.1.0.9001), rlang (≥ 1.1.4), rsample (≥ 1.3.0.9003), tailor (≥ 0.1.0), tibble (≥ 3.1.0), tidyr (≥ 1.2.0), tidyselect (≥ 1.1.2), vctrs (≥ 0.6.1), withr, workflows (≥ 1.3.0), yardstick (≥ 1.3.0) |
| Suggests: | C50, censored (≥ 0.3.0), covr, future (≥ 1.33.0), future.apply, kernlab, kknn, knitr, mgcv, mirai (≥ 2.4.0), modeldata, probably, scales, spelling, splines2, survival, testthat (≥ 3.0.0), xgboost, xml2 |
| Config/Needs/website: | pkgdown, tidymodels, kknn, tidyverse/tidytemplate |
| Config/testthat/edition: | 3 |
| Encoding: | UTF-8 |
| Language: | en-US |
| LazyData: | true |
| RoxygenNote: | 7.3.2 |
| NeedsCompilation: | no |
| Packaged: | 2025-10-17 13:00:40 UTC; max |
| Author: | Max Kuhn |
| Maintainer: | Max Kuhn <max@posit.co> |
| Repository: | CRAN |
| Date/Publication: | 2025-10-17 14:10:02 UTC |
Various accessor functions
Description
These functions return different attributes from objects with class tune_results.
Usage
.get_tune_parameters(x)

.get_tune_parameter_names(x)

.get_extra_col_names(x)

.get_tune_metrics(x)

.get_tune_metric_names(x)

.get_tune_eval_times(x)

.get_tune_eval_time_target(x)

.get_tune_outcome_names(x)

.get_tune_workflow(x)

## S3 method for class 'tune_results'
.get_fingerprint(x, ...)

Arguments
x | An object of class tune_results. |
Value
.get_tune_parameters() returns a dials parameters object or a tibble.

.get_tune_parameter_names(), .get_tune_metric_names(), and .get_tune_outcome_names() return a character string.

.get_tune_metrics() returns a metric set or NULL.

.get_tune_workflow() returns the workflow used to fit the resamples (if save_workflow was set to TRUE during fitting) or NULL.
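A minimal sketch of calling these accessors, assuming the example_ames_knn results that ship with the package:

library(tune)
data("example_ames_knn")

# Names of the tuning parameters and metrics stored in the results
.get_tune_parameter_names(ames_grid_search)
.get_tune_metric_names(ames_grid_search)

# The workflow is only available if save_workflow = TRUE was used;
# otherwise NULL is returned
.get_tune_workflow(ames_grid_search)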
Save most recent results to search path
Description
Save most recent results to search path
Usage
.stash_last_result(x)

Arguments
x | An object. |
Details
The function will assign x to .Last.tune.result and put it in the search path.
Value
NULL, invisibly.
Determine if case weights should be passed on to yardstick
Description
This S3 method defines the logic for deciding when a case weight vector should be passed to yardstick metric functions and used to measure model performance. The current logic is that frequency weights (i.e., hardhat::frequency_weights()) are the only situation where this should occur.
Usage
.use_case_weights_with_yardstick(x)

## S3 method for class 'hardhat_importance_weights'
.use_case_weights_with_yardstick(x)

## S3 method for class 'hardhat_frequency_weights'
.use_case_weights_with_yardstick(x)

Arguments
x | A vector |
Value
A single TRUE or FALSE.
Examples
library(parsnip)
library(dplyr)

frequency_weights(1:10) |>
  .use_case_weights_with_yardstick()

importance_weights(seq(1, 10, by = .1)) |>
  .use_case_weights_with_yardstick()

Augment data with holdout predictions
Description
For tune objects that use resampling, these augment() methods will add one or more columns for the hold-out predictions (i.e., from the assessment set(s)).
Usage
## S3 method for class 'tune_results'
augment(x, ..., parameters = NULL)

## S3 method for class 'resample_results'
augment(x, ...)

## S3 method for class 'last_fit'
augment(x, ...)

Arguments
x | An object resulting from one of the tune_*() functions, fit_resamples(), or last_fit(). |
... | Not currently used. |
parameters | A data frame with a single row that indicates what tuning parameters should be used to generate the predictions (for tune_*() objects only). |
Details
For some resampling methods where rows may be replicated in multiple assessment sets, the prediction columns will be averages of the holdout results. Also, for these methods, it is possible that not all rows of the original data have holdout predictions (like a single bootstrap resample). In this case, all rows are returned and a warning is issued.
For objects created by last_fit(), the test set data and predictions are returned.
Unlike other augment() methods, the predicted values for regression models are in a column called .pred instead of .fitted (to be consistent with other tidymodels conventions).
For regression problems, an additional .resid column is added to the results.
Value
A data frame with one or more additional columns for model predictions.
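A hedged sketch of augment() on resampling results (not part of the original page); it assumes predictions were saved with save_pred = TRUE:

library(parsnip)
library(rsample)

set.seed(1)
folds <- vfold_cv(mtcars, v = 3)

res <- fit_resamples(
  linear_reg(),
  mpg ~ .,
  resamples = folds,
  control = control_resamples(save_pred = TRUE)
)

# Adds .pred and .resid columns to the original mtcars rows
augment(res)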
Plot tuning search results
Description
Plot tuning search results
Usage
## S3 method for class 'tune_results'
autoplot(
  object,
  type = c("marginals", "parameters", "performance"),
  metric = NULL,
  eval_time = NULL,
  width = NULL,
  call = rlang::current_env(),
  ...
)

Arguments
object | A tibble of results from tune_grid() or tune_bayes(). |
type | A single character value. Choices are "marginals", "parameters", or "performance". |
metric | A character vector or NULL for which metric(s) to plot. |
eval_time | A numeric vector of time points where dynamic event time metrics should be chosen (e.g. the time-dependent ROC curve, etc). The values should be consistent with the values used to create object. |
width | A number for the width of the confidence interval bars when type = "performance". |
call | The call to be displayed in warnings or errors. |
... | For plots with a regular grid, this is passed to |
Details
When the results of tune_grid() are used with autoplot(), it tries to determine whether a regular grid was used.
Regular grids
For regular grids with one or more numeric tuning parameters, the parameter with the most unique values is used on the x-axis. If there are categorical parameters, the first is used to color the geometries. All other parameters are used in column faceting.
The plot has the performance metric(s) on the y-axis. If there are multiple metrics, these are row-faceted.
If there are more than five tuning parameters, the "marginal effects" plots are used instead.
Irregular grids
For space-filling or random grids, a marginal effect plot is created. A panel is made for each numeric parameter so that each parameter is on the x-axis and performance is on the y-axis. If there are multiple metrics, these are row-faceted.
A single categorical parameter is shown as colors. If there are two or more non-numeric parameters, an error is given. A similar result occurs if only non-numeric parameters are in the grid. In these cases, we suggest using collect_metrics() and ggplot() to create a plot that is appropriate for the data.
If a parameter has a transformation associated with it (as determined by the parameter object used to create it), the plot shows the values in the transformed units (and is labeled with the transformation type).
Parameters are labeled using the labels found in the parameter object except when an identifier was used (e.g. neighbors = tune("K")).
Value
A ggplot2 object.
See Also
Examples
# For grid search:
data("example_ames_knn")

# Plot the tuning parameter values versus performance
autoplot(ames_grid_search, metric = "rmse")

# For iterative search:
# Plot the tuning parameter values versus performance
autoplot(ames_iter_search, metric = "rmse", type = "marginals")

# Plot tuning parameters versus iterations
autoplot(ames_iter_search, metric = "rmse", type = "parameters")

# Plot performance over iterations
autoplot(ames_iter_search, metric = "rmse", type = "performance")

Get colors for tune text.
Description
These are not intended for use by the general public.
Usage
check_rset(x)

check_parameters(wflow, pset = NULL, data, grid_names = character(0))

check_workflow(x, ..., pset = NULL, check_dials = FALSE, call = caller_env())

check_metrics(x, object)

check_initial(
  x,
  pset,
  wflow,
  resamples,
  metrics,
  eval_time,
  ctrl,
  checks = "grid"
)

val_class_or_null(x, cls = "numeric", where = NULL)

val_class_and_single(x, cls = "numeric", where = NULL)

.config_key_from_metrics(x)

estimate_tune_results(x, ..., col_name = ".metrics")

metrics_info(x)

new_iteration_results(
  x,
  parameters,
  metrics,
  eval_time,
  eval_time_target,
  outcomes = character(0),
  rset_info,
  workflow
)

get_tune_colors()

encode_set(x, pset, ..., as_matrix = FALSE)

check_time(origin, limit)

pull_rset_attributes(x)

empty_ellipses(...)

is_recipe(x)

is_preprocessor(x)

is_workflow(x)

Arguments
x | An object. |
wflow | A workflow object. |
pset | A parameters object. |
data | The training data. |
grid_names | A character vector of column names from the grid. |
... | Other options |
check_dials | A logical to check for a NULL parameter object. |
object | A |
resamples | An rset object. |
metrics | A metric set. |
eval_time | A numeric vector of time points where dynamic event time metrics should be computed (e.g. the time-dependent ROC curve, etc). |
ctrl | A control object. |
cls | A character vector of possible classes |
where | A character string for the calling function. |
parameters | A |
outcomes | A character vector of outcome names. |
rset_info | Attributes from an rset object. |
workflow | The workflow used to fit the iteration results. |
as_matrix | A logical for the return type. |
origin | The calculation start time. |
limit | The allowable time (in minutes). |
Tools for selecting metrics and evaluation times
Description
Tools for selecting metrics and evaluation times
Usage
choose_metric(x, metric, ..., call = rlang::caller_env())

check_metric_in_tune_results(mtr_info, metric, ..., call = rlang::caller_env())

choose_eval_time(
  x,
  metric,
  ...,
  eval_time = NULL,
  quietly = FALSE,
  call = rlang::caller_env()
)

maybe_choose_eval_time(x, mtr_set, eval_time)

first_metric(mtr_set)

first_eval_time(
  mtr_set,
  ...,
  metric = NULL,
  eval_time = NULL,
  quietly = FALSE,
  call = rlang::caller_env()
)

.filter_perf_metrics(x, metric, eval_time)

check_metrics_arg(mtr_set, wflow, ..., call = rlang::caller_env())

check_eval_time_arg(eval_time, mtr_set, ..., call = rlang::caller_env())

Arguments
x | An object with class tune_results. |
metric | A character value for which metric is being used. |
... | These dots are for future extensions and must be empty. |
call | The call to be displayed in warnings or errors. |
eval_time | An optional vector of times to compute dynamic and/or integrated metrics. |
quietly | Logical. Should warnings be muffled? |
mtr_set | A metric set. |
wflow | A workflow object. |
Details
These are developer-facing functions used to compute and validate choices for performance metrics. For survival analysis models, there are similar functions for the evaluation time(s) required for dynamic and/or integrated metrics.
choose_metric() is used with functions such as show_best() or select_best() where a single valid metric is required to rank models. If no value is given by the user, the first metric value is used (with a warning).
For evaluation times, one is only required when the metric type is dynamic (e.g. yardstick::brier_survival() or yardstick::roc_auc_survival()). For these metrics, we require a single numeric value that was originally given to the function used to produce x (such as tune_grid()).
If a time is required and none is given, the first value in the vector originally given in the eval_time argument is used (with a warning).
maybe_choose_eval_time() is for cases where multiple evaluation times are acceptable but you need to choose a good default. The "maybe" is because the function that would use maybe_choose_eval_time() can accept multiple metrics (like autoplot()).
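A hedged sketch for developers, assuming the example_ames_knn results that ship with the package:

data("example_ames_knn")

# An explicit, valid choice is checked and returned as-is
choose_metric(ames_grid_search, metric = "rmse")

# With metric = NULL, the first metric should be used (with a warning)
choose_metric(ames_grid_search, metric = NULL)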
Obtain and format results produced by tuning functions
Description
Obtain and format results produced by tuning functions
Usage
collect_predictions(x, ...)

## Default S3 method:
collect_predictions(x, ...)

## S3 method for class 'tune_results'
collect_predictions(x, ..., summarize = FALSE, parameters = NULL)

collect_metrics(x, ...)

## S3 method for class 'tune_results'
collect_metrics(x, ..., summarize = TRUE, type = c("long", "wide"))

collect_notes(x, ...)

## S3 method for class 'tune_results'
collect_notes(x, ...)

collect_extracts(x, ...)

## S3 method for class 'tune_results'
collect_extracts(x, ...)

Arguments
x | The results of tune_grid(), tune_bayes(), fit_resamples(), or last_fit(). |
... | Not currently used. |
summarize | A logical; should metrics be summarized over resamples (TRUE) or return the values for each individual resample (FALSE)? |
parameters | An optional tibble of tuning parameter values that can be used to filter the predicted values before processing. This tibble should only have columns for each tuning parameter identifier (e.g. "my_param" if tune("my_param") was used). |
type | One of "long" (the default) or "wide". When type = "wide", each metric has its own column (see the Value section below). |
Value
A tibble. The column names depend on the results and the mode of the model.
For collect_metrics() and collect_predictions(), when unsummarized, there are columns for each tuning parameter (using the id from tune(), if any).
collect_metrics() also has columns .metric and .estimator by default. For collect_metrics() methods that have a type argument, supplying type = "wide" will pivot the output such that each metric has its own column. When the results are summarized, there are columns for mean, n, and std_err. When not summarized, there are additional columns for the resampling identifier(s) and .estimate.
For collect_predictions(), there are additional columns for the resampling identifier(s), columns for the predicted values (e.g., .pred, .pred_class, etc.), and a column for the outcome(s) using the original column name(s) in the data.
collect_predictions() can summarize the various results over replicate out-of-sample predictions. For example, when using the bootstrap, each row in the original training set has multiple holdout predictions (across assessment sets). To convert these results to a format where every training set sample has a single predicted value, the results are averaged over replicate predictions.
For regression cases, the numeric predictions are simply averaged.
For classification models, the problem is more complex. When class probabilities are used, these are averaged and then re-normalized to make sure that they add to one. If hard class predictions also exist in the data, then these are determined from the summarized probability estimates (so that they match). If only hard class predictions are in the results, then the mode is used to summarize.
With censored outcome models, the predicted survival probabilities (if any) are averaged while the static predicted event times are summarized using the median.
collect_notes() returns a tibble with columns for the resampling indicators, the location (preprocessor, model, etc.), type (error or warning), and the notes.
collect_extracts() collects objects extracted from fitted workflows via the extract argument to control functions. The function returns a tibble with columns for the resampling indicators, the location (preprocessor, model, etc.), and extracted objects.
Hyperparameters and extracted objects
When making use of submodels, tune can generate predictions and calculate metrics for multiple model .configurations using only one model fit. However, this means that if a function was supplied to a control function's extract argument, tune can only execute that extraction on the one model that was fitted. As a result, in the collect_extracts() output, tune opts to associate the extracted objects with the hyperparameter combination used to fit that one model workflow, rather than the hyperparameter combination of a submodel. In the output, this appears like a hyperparameter entry is recycled across many .config entries; this is intentional.
See https://parsnip.tidymodels.org/articles/Submodels.html to learn more about submodels.
Examples
data("example_ames_knn")
# The parameters for the model:
extract_parameter_set_dials(ames_wflow)

# Summarized over resamples
collect_metrics(ames_grid_search)

# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)

# ---------------------------------------------------------------------------

library(parsnip)
library(rsample)
library(dplyr)
library(recipes)
library(tibble)

lm_mod <- linear_reg() |> set_engine("lm")
set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE, extract = extract_fit_engine)

spline_rec <- recipe(mpg ~ ., data = mtcars) |>
  step_spline_natural(disp, deg_free = tune("df"))

grid <- tibble(df = 3:6)

resampled <- lm_mod |>
  tune_grid(spline_rec, resamples = car_folds, control = ctrl, grid = grid)

collect_predictions(resampled) |> arrange(.row)
collect_predictions(resampled, summarize = TRUE) |> arrange(.row)
collect_predictions(
  resampled,
  summarize = TRUE,
  parameters = grid[1, ]
) |> arrange(.row)

collect_extracts(resampled)

Calculate and format metrics from tuning functions
Description
This function computes metrics from tuning results. The arguments and output formats are closely related to those from collect_metrics(), but this function additionally takes a metrics argument with a metric set for new metrics to compute. This allows for computing new performance metrics without requiring users to re-evaluate models against resamples.
Note that the control option save_pred = TRUE must have been supplied when generating x.
Usage
compute_metrics(x, metrics, summarize, event_level, ...)

## Default S3 method:
compute_metrics(x, metrics, summarize = TRUE, event_level = "first", ...)

## S3 method for class 'tune_results'
compute_metrics(x, metrics, ..., summarize = TRUE, event_level = "first")

Arguments
x | The results of a tuning function like tune_grid() or fit_resamples(), generated with the control option save_pred = TRUE. |
metrics | A metric set of new metrics to compute. See the "Details" section below for more information. |
summarize | A single logical value indicating whether metrics should be summarized over resamples (TRUE) or returned for each individual resample (FALSE). |
event_level | A single string containing either "first" or "second", specifying which level of the outcome is considered the "event" when class predictions are made. |
... | Not currently used. |
Details
Each metric in the set supplied to the metrics argument must have a metric type (usually "numeric", "class", or "prob") that matches some metric evaluated when generating x. For example, if x was generated with only hard "class" metrics, this function can't compute metrics that take in class probabilities ("prob"). By default, the tuning functions used to generate x compute metrics of all needed types.
Value
A tibble. See collect_metrics() for more details on the return value.
Examples
# load needed packages:
library(parsnip)
library(rsample)
library(yardstick)

# evaluate a linear regression against resamples.
# note that we pass `save_pred = TRUE`:
res <- fit_resamples(
  linear_reg(),
  mpg ~ cyl + hp,
  bootstraps(mtcars, 5),
  control = control_grid(save_pred = TRUE)
)

# to return the metrics supplied to `fit_resamples()`:
collect_metrics(res)

# to compute new metrics:
compute_metrics(res, metric_set(mae))

# if `metrics` is the same as that passed to `fit_resamples()`,
# then `collect_metrics()` and `compute_metrics()` give the same
# output, though `compute_metrics()` is quite a bit slower:
all.equal(
  collect_metrics(res),
  compute_metrics(res, metric_set(rmse, rsq))
)

Compute average confusion matrix across resamples
Description
For classification problems, conf_mat_resampled() computes a separate confusion matrix for each resample and then averages the cell counts.
Usage
conf_mat_resampled(x, ..., parameters = NULL, tidy = TRUE)

Arguments
x | An object with class tune_results from a classification model, generated with the control option save_pred = TRUE. |
... | Currently unused, must be empty. |
parameters | A tibble with a single tuning parameter combination. Only one tuning parameter combination (if any were used) is allowed here. |
tidy | Should the results come back in a tibble (TRUE) or a conf_mat object (FALSE)? |
Value
A tibble or conf_mat with the average cell count across resamples.
Examples
# example code
library(parsnip)
library(rsample)
library(dplyr)

data(two_class_dat, package = "modeldata")

set.seed(2393)
res <- logistic_reg() |>
  set_engine("glm") |>
  fit_resamples(
    Class ~ .,
    resamples = vfold_cv(two_class_dat, v = 3),
    control = control_resamples(save_pred = TRUE)
  )

conf_mat_resampled(res)
conf_mat_resampled(res, tidy = FALSE)

Control aspects of the Bayesian search process
Description
Control aspects of the Bayesian search process
Usage
control_bayes(
  verbose = FALSE,
  verbose_iter = FALSE,
  no_improve = 10L,
  uncertain = Inf,
  seed = sample.int(10^5, 1),
  extract = NULL,
  save_pred = FALSE,
  time_limit = NA,
  pkgs = NULL,
  save_workflow = FALSE,
  save_gp_scoring = FALSE,
  event_level = "first",
  parallel_over = NULL,
  backend_options = NULL,
  allow_par = TRUE
)

Arguments
verbose | A logical for logging results (other than warnings and errors, which are always shown) as they are generated during training in a single R process. When using most parallel backends, this argument typically will not result in any logging. If using a dark IDE theme, some logging messages might be hard to see; try setting the tidymodels.dark option with options(tidymodels.dark = TRUE) to print lighter colors. |
verbose_iter | A logical for logging results of the Bayesian search process. Defaults to FALSE. If using a dark IDE theme, some logging messages might be hard to see; try setting the tidymodels.dark option with options(tidymodels.dark = TRUE) to print lighter colors. |
no_improve | The integer cutoff for the number of iterations withoutbetter results. |
uncertain | The number of iterations with no improvement before an uncertainty sample is created where a sample with high predicted variance is chosen (i.e., in a region that has not yet been explored). The iteration counter is reset after each uncertainty sample. For example, if uncertain = 10, this condition is triggered every 10 samples with no improvement. |
seed | An integer for controlling the random number stream. Tuning functions are sensitive to both the state of RNG set outside of tuning functions with set.seed() as well as the value of this argument. |
extract | An optional function with at least one argument (or NULL) that can be used to retain arbitrary objects from the model fit object, recipe, or other elements of the workflow. |
save_pred | A logical for whether the out-of-sample predictions should be saved for each model evaluated. |
time_limit | A number for the minimum number of minutes (elapsed) that the function should execute. The elapsed time is evaluated at internal checkpoints and, if over time, the results at that time are returned (with a warning). This means that the | Note that timing begins immediately on execution. Thus, if the |
pkgs | An optional character string of R package names that should be loaded (by namespace) during parallel processing. |
save_workflow | A logical for whether the workflow should be appended to the output as an attribute. |
save_gp_scoring | A logical to save the intermediate Gaussian process models for each iteration of the search. These are saved to |
event_level | A single string containing either "first" or "second". This argument is passed on to yardstick metric functions when any type of class prediction is made, and specifies which level of the outcome is considered the "event". |
parallel_over | A single string containing either "resamples" or "everything" describing how to use parallel processing; alternatively, NULL (the default) chooses between them automatically. If | If | If | Note that switching between |
backend_options | An object of class "tune_backend_options" as created by tune::new_backend_options(), used to pass arguments to specific tuning backends. Defaults to NULL. |
allow_par | A logical to allow parallel processing (if a parallelbackend is registered). |
Details
For extract, this function can be used to output the model object, the recipe (if used), or some components of either or both. When evaluated, the function's sole argument is a fitted workflow. If the formula method is used, the recipe element will be NULL.
The results of the extract function are added to a list column in the output called .extracts. Each element of this list is a tibble with tuning parameter columns and a list column (also called .extracts) that contains the results of the function. If no extraction function is used, there is no .extracts column in the resulting object. See tune_bayes() for more specific details.
Note that for collect_predictions(), it is possible that each row of the original data might be represented multiple times per tuning parameter. For example, if the bootstrap or repeated cross-validation are used, there will be multiple rows since the same data point has been evaluated multiple times. This may cause issues when merging the predictions with the original data.
Hyperparameters and extracted objects
When making use of submodels, tune can generate predictions and calculate metrics for multiple model .configurations using only one model fit. However, this means that if a function was supplied to a control function's extract argument, tune can only execute that extraction on the one model that was fitted. As a result, in the collect_extracts() output, tune opts to associate the extracted objects with the hyperparameter combination used to fit that one model workflow, rather than the hyperparameter combination of a submodel. In the output, this appears like a hyperparameter entry is recycled across many .config entries; this is intentional.
See https://parsnip.tidymodels.org/articles/Submodels.html to learn more about submodels.
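A brief, hedged sketch of building a control object with some of the options above; the values are illustrative, not defaults:

ctrl <- control_bayes(
  verbose_iter = TRUE,
  no_improve = 15L,
  save_pred = TRUE,
  extract = extract_fit_engine
)

# `ctrl` would then be passed to tune_bayes(..., control = ctrl)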
Control aspects of the grid search process
Description
Control aspects of the grid search process
Usage
control_grid(
  verbose = FALSE,
  allow_par = TRUE,
  extract = NULL,
  save_pred = FALSE,
  pkgs = NULL,
  save_workflow = FALSE,
  event_level = "first",
  parallel_over = NULL,
  backend_options = NULL
)

control_resamples(
  verbose = FALSE,
  allow_par = TRUE,
  extract = NULL,
  save_pred = FALSE,
  pkgs = NULL,
  save_workflow = FALSE,
  event_level = "first",
  parallel_over = NULL,
  backend_options = NULL
)

new_backend_options(..., class = character())

Arguments
verbose | A logical for logging results (other than warnings and errors, which are always shown) as they are generated during training in a single R process. When using most parallel backends, this argument typically will not result in any logging. If using a dark IDE theme, some logging messages might be hard to see; try setting the tidymodels.dark option with options(tidymodels.dark = TRUE) to print lighter colors. |
allow_par | A logical to allow parallel processing (if a parallelbackend is registered). |
extract | An optional function with at least one argument (or NULL) that can be used to retain arbitrary objects from the model fit object, recipe, or other elements of the workflow. |
save_pred | A logical for whether the out-of-sample predictions should be saved for each model evaluated. |
pkgs | An optional character string of R package names that should be loaded (by namespace) during parallel processing. |
save_workflow | A logical for whether the workflow should be appended to the output as an attribute. |
event_level | A single string containing either "first" or "second". This argument is passed on to yardstick metric functions when any type of class prediction is made, and specifies which level of the outcome is considered the "event". |
parallel_over | A single string containing either "resamples" or "everything" describing how to use parallel processing; alternatively, NULL (the default) chooses between them automatically. If | If | If | Note that switching between |
backend_options | An object of class "tune_backend_options" as created by tune::new_backend_options(), used to pass arguments to specific tuning backends. Defaults to NULL. |
Details
For extract, this function can be used to output the model object, the recipe (if used), or some components of either or both. When evaluated, the function's sole argument is a fitted workflow. If the formula method is used, the recipe element will be NULL.
The results of the extract function are added to a list column in the output called .extracts. Each element of this list is a tibble with tuning parameter columns and a list column (also called .extracts) that contains the results of the function. If no extraction function is used, there is no .extracts column in the resulting object. See tune_bayes() for more specific details.
Note that for collect_predictions(), it is possible that each row of the original data might be represented multiple times per tuning parameter. For example, if the bootstrap or repeated cross-validation are used, there will be multiple rows since the same data point has been evaluated multiple times. This may cause issues when merging the predictions with the original data.
control_resamples() is an alias for control_grid() and is meant to be used with fit_resamples().
Hyperparameters and extracted objects
When making use of submodels, tune can generate predictions and calculate metrics for multiple model .configurations using only one model fit. However, this means that if a function was supplied to a control function's extract argument, tune can only execute that extraction on the one model that was fitted. As a result, in the collect_extracts() output, tune opts to associate the extracted objects with the hyperparameter combination used to fit that one model workflow, rather than the hyperparameter combination of a submodel. In the output, this appears like a hyperparameter entry is recycled across many .config entries; this is intentional.
See https://parsnip.tidymodels.org/articles/Submodels.html to learn more about submodels.
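A brief, hedged sketch of these options in use; the specific values are illustrative, not defaults:

# Compute across resamples and, within them, across parameter candidates
ctrl <- control_grid(parallel_over = "everything", save_pred = TRUE)

# control_resamples() takes the same options and is used with fit_resamples()
ctrl_rs <- control_resamples(save_pred = TRUE)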
Control aspects of the last fit process
Description
Control aspects of the last fit process
Usage
control_last_fit(verbose = FALSE, event_level = "first", allow_par = FALSE)

Arguments
verbose | A logical for logging results (other than warnings and errors, which are always shown) as they are generated during training in a single R process. When using most parallel backends, this argument typically will not result in any logging. If using a dark IDE theme, some logging messages might be hard to see; try setting the tidymodels.dark option with options(tidymodels.dark = TRUE) to print lighter colors. |
event_level | A single string containing either "first" or "second". This argument is passed on to yardstick metric functions when any type of class prediction is made, and specifies which level of the outcome is considered the "event". |
allow_par | A logical to allow parallel processing (if a parallelbackend is registered). |
Details
control_last_fit() is a wrapper around control_resamples() and is meant to be used with last_fit().
Use same scale for plots of observed vs predicted values
Description
For regression models, coord_obs_pred() can be used in a ggplot to make the x- and y-axes have the same exact scale along with an aspect ratio of one.
Usage
coord_obs_pred(ratio = 1, xlim = NULL, ylim = NULL, expand = TRUE, clip = "on")

Arguments
ratio | Aspect ratio, expressed as y / x. |
xlim,ylim | Limits for the x and y axes. |
expand | Not currently used. |
clip | Should drawing be clipped to the extent of the plot panel? A setting of "on" (the default) means yes, and a setting of "off" means no. In most cases, the default of "on" should not be changed, as setting clip = "off" can cause unexpected results. |
Value
A ggproto object.
Examples
# example code
data(solubility_test, package = "modeldata")

library(ggplot2)
p <- ggplot(solubility_test, aes(x = solubility, y = prediction)) +
  geom_abline(lty = 2) +
  geom_point(alpha = 0.5)

p
p + coord_fixed()
p + coord_obs_pred()

Example Analysis of Ames Housing Data
Description
Example Analysis of Ames Housing Data
Details
These objects are the results of an analysis of the Ames housing data. A K-nearest neighbors model was used with a small predictor set that included natural spline transformations of the Longitude and Latitude predictors. The code used to generate these examples was:
library(tidymodels)
library(tune)
library(AmesHousing)

# ------------------------------------------------------------------------------

ames <- make_ames()

set.seed(4595)
data_split <- initial_split(ames, strata = "Sale_Price")

ames_train <- training(data_split)

set.seed(2453)
rs_splits <- vfold_cv(ames_train, strata = "Sale_Price")

# ------------------------------------------------------------------------------

ames_rec <- recipe(Sale_Price ~ ., data = ames_train) |>
  step_log(Sale_Price, base = 10) |>
  step_YeoJohnson(Lot_Area, Gr_Liv_Area) |>
  step_other(Neighborhood, threshold = .1) |>
  step_dummy(all_nominal()) |>
  step_zv(all_predictors()) |>
  step_spline_natural(Longitude, deg_free = tune("lon")) |>
  step_spline_natural(Latitude, deg_free = tune("lat"))

knn_model <- nearest_neighbor(
  mode = "regression",
  neighbors = tune("K"),
  weight_func = tune(),
  dist_power = tune()
) |>
  set_engine("kknn")

ames_wflow <- workflow() |>
  add_recipe(ames_rec) |>
  add_model(knn_model)

ames_set <- extract_parameter_set_dials(ames_wflow) |>
  update(K = neighbors(c(1, 50)))

set.seed(7014)
ames_grid <- ames_set |>
  grid_max_entropy(size = 10)

ames_grid_search <- tune_grid(
  ames_wflow,
  resamples = rs_splits,
  grid = ames_grid
)

set.seed(2082)
ames_iter_search <- tune_bayes(
  ames_wflow,
  resamples = rs_splits,
  param_info = ames_set,
  initial = ames_grid_search,
  iter = 15
)

Important note: Since the rsample split columns contain a reference to the same data, saving them to disk can result in large object sizes when the object is later used. In essence, R replaces all of those references with the actual data. For this reason, we saved zero-row tibbles in their place. This doesn't affect how we use these objects in examples but be advised that using some rsample functions on them will cause issues.
Value
ames_wflow | A workflow object |
ames_grid_search,ames_iter_search | Results of model tuning. |
Examples
library(tune)

ames_grid_search
ames_iter_search

Exponential decay function
Description
expo_decay() can be used to increase or decrease a function exponentially over iterations. This can be used to dynamically set parameters for acquisition functions as iterations of Bayesian optimization proceed.
Usage
expo_decay(iter, start_val, limit_val, slope = 1/5)

Arguments
iter | An integer for the current iteration number. |
start_val | The number returned for the first iteration. |
limit_val | The number that the process converges to over iterations. |
slope | A coefficient for the exponent to control the rate of decay. The sign of the slope controls the direction of decay. |
Details
Note that, when used with the acquisition functions in tune(), a wrapper would be required since only the first argument would be evaluated during tuning.
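For instance, a minimal sketch of such a wrapper (the name and constants are illustrative) that fixes every argument except the iteration number:

decay_trade_off <- function(iter) {
  expo_decay(iter, start_val = 0.5, limit_val = 0, slope = 1 / 4)
}

decay_trade_off(1)
decay_trade_off(20)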
Value
A single numeric value.
Examples
library(tibble)
library(purrr)
library(ggplot2)
library(dplyr)

tibble(
  iter = 1:40,
  value = map_dbl(
    1:40,
    expo_decay,
    start_val = .1,
    limit_val = 0,
    slope = 1 / 5
  )
) |>
  ggplot(aes(x = iter, y = value)) +
  geom_path()

Extract elements of tune objects
Description
These functions extract various elements from a tune object. If they do not exist yet, an error is thrown.
extract_preprocessor() returns the formula, recipe, or variable expressions used for preprocessing.

extract_spec_parsnip() returns the parsnip model specification.

extract_fit_parsnip() returns the parsnip model fit object.

extract_fit_engine() returns the engine-specific fit embedded within a parsnip model fit. For example, when using parsnip::linear_reg() with the "lm" engine, this returns the underlying lm object.

extract_mold() returns the preprocessed "mold" object returned from hardhat::mold(). It contains information about the preprocessing, including either the prepped recipe, the formula terms object, or variable selectors.

extract_recipe() returns the recipe. The estimated argument specifies whether the fitted or original recipe is returned.

extract_workflow() returns the workflow object if the control option save_workflow = TRUE was used. The workflow will only have been estimated for objects produced by last_fit().
Usage
## S3 method for class 'last_fit'
extract_workflow(x, ...)

## S3 method for class 'tune_results'
extract_workflow(x, ...)

## S3 method for class 'tune_results'
extract_spec_parsnip(x, ...)

## S3 method for class 'tune_results'
extract_recipe(x, ..., estimated = TRUE)

## S3 method for class 'tune_results'
extract_fit_parsnip(x, ...)

## S3 method for class 'tune_results'
extract_fit_engine(x, ...)

## S3 method for class 'tune_results'
extract_mold(x, ...)

## S3 method for class 'tune_results'
extract_preprocessor(x, ...)

Arguments
x | A tune_results object. |
... | Not currently used. |
estimated | A logical for whether the original (unfit) recipe or thefitted recipe should be returned. |
Details
These functions supersede extract_model().
Value
The extracted value from the tune_results object, x, as described in the description section.
Examples
# example code
library(recipes)
library(rsample)
library(parsnip)

set.seed(6735)
tr_te_split <- initial_split(mtcars)

spline_rec <- recipe(mpg ~ ., data = mtcars) |>
  step_spline_natural(disp)

lin_mod <- linear_reg() |>
  set_engine("lm")

spline_res <- last_fit(lin_mod, spline_rec, split = tr_te_split)

extract_preprocessor(spline_res)

# The `spec` is the parsnip spec before it has been fit.
# The `fit` is the fitted parsnip model.
extract_spec_parsnip(spline_res)
extract_fit_parsnip(spline_res)
extract_fit_engine(spline_res)

# The mold is returned from `hardhat::mold()`, and contains the
# predictors, outcomes, and information about the preprocessing
# for use on new data at `predict()` time.
extract_mold(spline_res)

# A useful shortcut is to extract the fitted recipe from the workflow
extract_recipe(spline_res)

# That is identical to
identical(
  extract_mold(spline_res)$blueprint$recipe,
  extract_recipe(spline_res)
)

Remove some tuning parameter results
Description
For objects produced by the tune_*() functions, there may only be a subset of tuning parameter combinations of interest. For large data sets, it might be helpful to be able to remove some results. This function trims the .metrics column of unwanted results as well as the columns .predictions and .extracts (if they were requested).
Usage
filter_parameters(x, ..., parameters = NULL)

Arguments
x | An object of class tune_results. |
... | Expressions that return a logical value, and are defined in terms of the tuning parameter values. If multiple expressions are included, they are combined with the &amp; operator. Only rows for which all conditions evaluate to TRUE are kept. |
parameters | A tibble of tuning parameter values that can be used to filter the predicted values before processing. This tibble should only have columns for tuning parameter identifiers (e.g. "my_param" if tune("my_param") was used). |
Details
Removing some parameter combinations might affect the results of autoplot() for the object.
Value
A version of x where the list columns only retain the parameter combinations in parameters or satisfied by the filtering logic.
Examples
library(dplyr)
library(tibble)

# For grid search:
data("example_ames_knn")

## -----------------------------------------------------------------------------
# select all combinations using the 'rank' weighting scheme

ames_grid_search |> collect_metrics()

filter_parameters(ames_grid_search, weight_func == "rank") |>
  collect_metrics()

rank_only <- tibble::tibble(weight_func = "rank")

filter_parameters(ames_grid_search, parameters = rank_only) |>
  collect_metrics()

## -----------------------------------------------------------------------------
# Keep only the results from the numerically best combination

ames_iter_search |> collect_metrics()

best_param <- select_best(ames_iter_search, metric = "rmse")

ames_iter_search |>
  filter_parameters(parameters = best_param) |>
  collect_metrics()

Splice final parameters into objects
Description
The finalize_* functions take a list or tibble of tuning parameter values and update objects with those values.
Usage
finalize_model(x, parameters)

finalize_recipe(x, parameters)

finalize_workflow(x, parameters)

finalize_tailor(x, parameters)

Arguments
x | A recipe, parsnip model specification, tailor postprocessor, or workflow. |
parameters | A list or 1-row tibble of parameter values. Note that the column names of the tibble should be the id fields attached to tune(). For example, in the Examples section below, the model has tune("K"). In this case, the parameter tibble should have a column named "K" and not "neighbors". |
Value
An updated version of x.
Examples
data("example_ames_knn")

library(parsnip)
knn_model <- nearest_neighbor(
  mode = "regression",
  neighbors = tune("K"),
  weight_func = tune(),
  dist_power = tune()
) |>
  set_engine("kknn")

lowest_rmse <- select_best(ames_grid_search, metric = "rmse")
lowest_rmse

knn_model
finalize_model(knn_model, lowest_rmse)

Fit a model to the numerically optimal configuration
Description
fit_best() takes the results from model tuning and fits it to the training set using tuning parameters associated with the best performance.
Usage
fit_best(x, ...)

## Default S3 method:
fit_best(x, ...)

## S3 method for class 'tune_results'
fit_best(
  x,
  ...,
  metric = NULL,
  eval_time = NULL,
  parameters = NULL,
  verbose = FALSE,
  add_validation_set = NULL
)

Arguments
x | The results of class tune_results (coming from functions such as tune_grid(), tune_bayes(), etc). The control option save_workflow = TRUE should have been used. |
... | Not currently used, must be empty. |
metric | A character string (or NULL) for the metric to optimize. If NULL, the first metric is used. |
eval_time | A single numeric time point where dynamic event time metrics should be chosen (e.g., the time-dependent ROC curve, etc). The values should be consistent with the values used to create x. |
parameters | An optional 1-row tibble of tuning parameter settings, with a column for each tuning parameter. This tibble should have columns for each tuning parameter identifier (e.g. "my_param" if tune("my_param") was used). If NULL, the numerically best parameters are selected via select_best(). |
verbose | A logical for printing logging. |
add_validation_set | When the resamples embedded in x are a split into training set and validation set, should the validation set be included in the data set used to train the model? If not, only the training set is used. |
Details
This function is a shortcut for the manual steps of:
  best_param <- select_best(tune_results, metric) # or other `select_*()`
  wflow <- finalize_workflow(wflow, best_param)   # or just `finalize_model()`
  wflow_fit <- fit(wflow, data_set)
Value
A fitted workflow.
Case Weights
Some models can utilize case weights during training. tidymodels currently supports two types of case weights: importance weights (doubles) and frequency weights (integers). Frequency weights are used during model fitting and evaluation, whereas importance weights are only used during fitting.
To know if your model is capable of using case weights, create a model spec and test it using parsnip::case_weights_allowed().
To use them, you will need a numeric column in your data set that has been passed through either hardhat::importance_weights() or hardhat::frequency_weights().
For functions such as fit_resamples() and the tune_*() functions, the model must be contained inside of a workflows::workflow(). To declare that case weights are used, invoke workflows::add_case_weights() with the corresponding (unquoted) column name.
From there, the packages will appropriately handle the weights during model fitting and (if appropriate) performance estimation.
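As a hedged illustration of the setup described above (the column name wts and the weight values are made up for this sketch):

library(parsnip)
library(workflows)
library(hardhat)

mt <- mtcars
set.seed(1)
mt$wts <- frequency_weights(sample(1:3, nrow(mt), replace = TRUE))

# The weights column is declared on the workflow, not in the formula
wf <- workflow() |>
  add_case_weights(wts) |>
  add_formula(mpg ~ disp + wt) |>
  add_model(linear_reg())

fit(wf, data = mt)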
See also
last_fit() is closely related to fit_best(). They both give you access to a workflow fitted on the training data but are situated somewhat differently in the modeling workflow. fit_best() picks up after a tuning function like tune_grid() to take you from tuning results to fitted workflow, ready for you to predict and assess further. last_fit() assumes you have made your choice of hyperparameters and finalized your workflow to then take you from finalized workflow to fitted workflow and further to performance assessment on the test data. While fit_best() gives a fitted workflow, last_fit() gives you the performance results. If you want the fitted workflow, you can extract it from the result of last_fit() via extract_workflow().
Examples
library(recipes)
library(rsample)
library(parsnip)
library(dplyr)

data(meats, package = "modeldata")
meats <- meats |> select(-water, -fat)

set.seed(1)
meat_split <- initial_split(meats)
meat_train <- training(meat_split)
meat_test <- testing(meat_split)

set.seed(2)
meat_rs <- vfold_cv(meat_train, v = 10)

pca_rec <- recipe(protein ~ ., data = meat_train) |>
  step_normalize(all_numeric_predictors()) |>
  step_pca(all_numeric_predictors(), num_comp = tune())

knn_mod <- nearest_neighbor(neighbors = tune()) |>
  set_mode("regression")

ctrl <- control_grid(save_workflow = TRUE)

set.seed(128)
knn_pca_res <- tune_grid(
  knn_mod,
  pca_rec,
  resamples = meat_rs,
  grid = 10,
  control = ctrl
)

knn_fit <- fit_best(knn_pca_res, verbose = TRUE)
predict(knn_fit, meat_test)

Fit multiple models via resampling
Description
fit_resamples() computes a set of performance metrics across one or more resamples. It does not perform any tuning (see tune_grid() and tune_bayes() for that), and is instead used for fitting a single model+recipe or model+formula combination across many resamples.
Usage
fit_resamples(object, ...)

## S3 method for class 'model_spec'
fit_resamples(
  object,
  preprocessor,
  resamples,
  ...,
  metrics = NULL,
  eval_time = NULL,
  control = control_resamples()
)

## S3 method for class 'workflow'
fit_resamples(
  object,
  resamples,
  ...,
  metrics = NULL,
  eval_time = NULL,
  control = control_resamples()
)

Arguments
object | A parsnip model specification or an unfitted workflows::workflow(). No tuning parameters are allowed. |
... | Currently unused. |
preprocessor | A traditional model formula or a recipe created using recipes::recipe(). |
resamples | An rset resampling object created from an rsample function, such as rsample::vfold_cv(). |
metrics | A yardstick::metric_set(), or NULL to compute a standard set of metrics. |
eval_time | A numeric vector of time points where dynamic event time metrics should be computed (e.g. the time-dependent ROC curve, etc). The values must be non-negative and should probably be no greater than the largest event time in the training set (see Details below). |
control | A control_resamples() object used to fine tune the resampling process. |
Case Weights
Some models can utilize case weights during training. tidymodels currently supports two types of case weights: importance weights (doubles) and frequency weights (integers). Frequency weights are used during model fitting and evaluation, whereas importance weights are only used during fitting.
To know if your model is capable of using case weights, create a model spec and test it using parsnip::case_weights_allowed().
To use them, you will need a numeric column in your data set that has been passed through either hardhat::importance_weights() or hardhat::frequency_weights().
For functions such as fit_resamples() and the tune_*() functions, the model must be contained inside of a workflows::workflow(). To declare that case weights are used, invoke workflows::add_case_weights() with the corresponding (unquoted) column name.
From there, the packages will appropriately handle the weights during model fitting and (if appropriate) performance estimation.
Censored Regression Models
Three types of metrics can be used to assess the quality of censored regression models:
static: the prediction is independent of time.
dynamic: the prediction is a time-specific probability (e.g., survival probability) and is measured at one or more particular times.
integrated: same as the dynamic metric but returns the integral of the different metrics from each time point.
Which metrics are chosen by the user affects how many evaluation times should be specified. For example:
# Needs no `eval_time` value
metric_set(concordance_survival)

# Needs at least one `eval_time`
metric_set(brier_survival)
metric_set(brier_survival, concordance_survival)

# Needs at least two `eval_time` values
metric_set(brier_survival_integrated, concordance_survival)
metric_set(brier_survival_integrated, concordance_survival, brier_survival)
Values of eval_time should be less than the largest observed event time in the training data. For many non-parametric models, the results beyond the largest time corresponding to an event are constant (or NA).
Performance Metrics
To use your own performance metrics, the yardstick::metric_set() function can be used to pick what should be measured for each model. If multiple metrics are desired, they can be bundled. For example, to estimate the area under the ROC curve as well as the sensitivity and specificity (under the typical probability cutoff of 0.50), the metrics argument could be given:
metrics = metric_set(roc_auc, sens, spec)
Each metric is calculated for each candidate model.
If no metric set is provided, one is created:
For regression models, the root mean squared error and coefficient of determination are computed.
For classification, the area under the ROC curve and overall accuracy are computed.
Note that the metrics also determine what type of predictions are estimated during tuning. For example, in a classification problem, if metrics are used that are all associated with hard class predictions, the classification probabilities are not created.
The out-of-sample estimates of these metrics are contained in a list column called .metrics. This tibble contains a row for each metric and columns for the value, the estimator type, and so on.
collect_metrics() can be used for these objects to collapse the results over the resamples (to obtain the final resampling estimates per tuning parameter combination).
Obtaining Predictions
When control_grid(save_pred = TRUE), the output tibble contains a list column called .predictions that has the out-of-sample predictions for each parameter combination in the grid and each fold (which can be very large).
The elements of the tibble are tibbles with columns for the tuning parameters, the row number from the original data object (.row), the outcome data (with the same name(s) of the original data), and any columns created by the predictions. For example, for simple regression problems, this function generates a column called .pred and so on. As noted above, the prediction columns that are returned are determined by the type of metric(s) requested.
This list column can be unnested using tidyr::unnest() or using the convenience function collect_predictions(), as in the sketch below.
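A hedged sketch, assuming a result like spline_res from the Examples below (created with save_pred = TRUE):

library(tidyr)

# Manually expand the list column ...
spline_res |> unnest(.predictions)

# ... or use the convenience function, which returns a similar tibble
collect_predictions(spline_res)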
Extracting Information
The extract control option will result in an additional list column called .extracts being returned. This is a list column that has tibbles containing the results of the user's function for each tuning parameter combination. This can enable returning each model and/or recipe object that is created during resampling. Note that this could result in a large return object, depending on what is returned.
The control function contains an option (extract) that can be used to retain any model or recipe that was created within the resamples. This argument should be a function with a single argument. The value of the argument that is given to the function in each resample is a workflow object (see workflows::workflow() for more information). Several helper functions can be used to easily pull out the preprocessing and/or model information from the workflow, such as extract_preprocessor() and extract_fit_parsnip().
As an example, if there is interest in getting each parsnip model fit back, one could use:
extract = function (x) extract_fit_parsnip(x)
Note that the function given to the extract argument is evaluated on every model that is fit (as opposed to every model that is evaluated). As noted above, in some cases, model predictions can be derived for sub-models so that, in these cases, not every row in the tuning parameter grid has a separate R object associated with it.
Finally, it is a good idea to include calls to require() for packages that are used in the function. This helps prevent failures when using parallel processing.
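A hedged sketch of such an extraction function; the helper name get_lm_coefs and the use of broom::tidy() are illustrative choices, not part of this page:

get_lm_coefs <- function(x) {
  require(broom)
  # `x` is a fitted workflow; pull the engine fit and tidy its coefficients
  tidy(extract_fit_engine(x))
}

ctrl <- control_resamples(extract = get_lm_coefs)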
See Also
control_resamples(),collect_predictions(),collect_metrics()
Examples
library(recipes)
library(rsample)
library(parsnip)
library(workflows)

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

spline_rec <- recipe(mpg ~ ., data = mtcars) |>
  step_spline_natural(disp) |>
  step_spline_natural(wt)

lin_mod <- linear_reg() |>
  set_engine("lm")

control <- control_resamples(save_pred = TRUE)

spline_res <- fit_resamples(lin_mod, spline_rec, folds, control = control)

spline_res

show_best(spline_res, metric = "rmse")

# You can also wrap up a preprocessor and a model into a workflow, and
# supply that to `fit_resamples()` instead. Here, a workflows "variables"
# preprocessor is used, which lets you supply terms using dplyr selectors.
# The variables are used as-is, no preprocessing is done to them.
wf <- workflow() |>
  add_variables(outcomes = mpg, predictors = everything()) |>
  add_model(lin_mod)

wf_res <- fit_resamples(wf, folds)

Internal functions used by other tidymodels packages
Description
These are not meant to be invoked directly by users.
Usage
forge_from_workflow(new_data, workflow)

finalize_workflow_preprocessor(workflow, grid_preprocessor)

.estimate_metrics(
  dat,
  metric,
  param_names,
  outcome_name,
  event_level,
  metrics_info = metrics_info(metrics)
)

.load_namespace(x)

initialize_catalog(control, env = rlang::caller_env(), workflow = NULL)

.catch_and_log(.expr, ..., bad_only = FALSE, notes, catalog = TRUE)

Arguments
new_data | A data frame or matrix of predictors to process. |
workflow | A workflow. |
grid_preprocessor | A tibble with parameter information. |
dat | A data set. |
metric | A metric set. |
param_names | A character vector of tuning parameter names. |
outcome_name | A character string for the column of |
event_level | A logical passed from the control function. |
metrics_info | The output of |
x | A character vector of package names. |
.expr | Code to execute. |
... | Object to pass to the internal |
bad_only | A logical for whether warnings and errors should be caught. |
notes | Character data to add to the logging. |
catalog | A logical passed to |
Get time for analysis of dynamic survival metrics
Description
Get time for analysis of dynamic survival metrics
Usage
get_metric_time(metrics, eval_time)

Arguments
metrics | A metric set. |
eval_time | A vector of evaluation times. |
Internal functions to help use parallel processing
Description
Internal functions to help use parallel processing
Usage
has_non_par_pkgs(object, control, verbose = FALSE)

future_installed()

mirai_installed()

get_future_workers(verbose)

get_mirai_workers(verbose)

choose_framework(
  object = NULL,
  control = NULL,
  verbose = FALSE,
  default = "mirai"
)

get_parallel_seeds(workers)

eval_mirai(.x, .f, ..., .args)

par_fns(framework)

Arguments
object | A workflow. |
control | A control object |
verbose | A logical for printing |
default | The default parallel processor. |
workers | The number of existing workers |
.x | A list. |
.f | A function |
...,.args | Options to pass to other functions. |
Bootstrap confidence intervals for performance metrics
Description
Using out-of-sample predictions, the bootstrap is used to create percentile confidence intervals.
Usage
## S3 method for class 'tune_results'
int_pctl(
  .data,
  metrics = NULL,
  eval_time = NULL,
  times = 1001,
  parameters = NULL,
  alpha = 0.05,
  allow_par = TRUE,
  event_level = "first",
  keep_replicates = FALSE,
  ...
)

Arguments
.data | An object with class tune_results where the control option save_pred = TRUE was used. |
metrics | A yardstick::metric_set(); if NULL (the default), the metrics used to produce .data are used. |
eval_time | A vector of evaluation times for censored regression models. |
times | The number of bootstrap samples. |
parameters | An optional tibble of tuning parameter values that can be used to filter the predicted values before processing. This tibble should only have columns for each tuning parameter identifier (e.g. "my_param" if tune("my_param") was used). |
alpha | Level of significance. |
allow_par | A logical to allow parallel processing (if a parallelbackend is registered). |
event_level | A single string. Either "first" or "second" to specify which level of the outcome is considered the "event". |
keep_replicates | A logical for saving the individual estimates from each bootstrap sample (as a list column called .values). |
... | Not currently used. |
Details
For each model configuration (if any), this function takes bootstrap samples of the out-of-sample predicted values. For each bootstrap sample, the metrics are computed and these are used to compute confidence intervals. See rsample::int_pctl() and the references therein for more details.
Note that the .estimate column is likely to be different from the results given by collect_metrics() since a different estimator is used. Since random numbers are used in sampling, set the random number seed prior to running this function.
The number of bootstrap samples should be large to have reliable intervals. The defaults reflect the fewest samples that should be used.
The computations for each configuration can be extensive. To increase computational efficiency, parallel processing can be used. The future package is used here. To execute the resampling iterations in parallel, specify a plan with future first (see the sketch below). The allow_par argument can be used to avoid parallelism.
Also, if a censored regression model used numerous evaluation times, the computations can take a long time unless the times are filtered with the eval_time argument.
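A minimal sketch of registering a future plan before calling int_pctl(); the backend and worker count are illustrative:

library(future)

plan(multisession, workers = 2)
# ... call int_pctl() here ...
plan(sequential)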
Value
A tibble of metrics with additional columns for .lower and .upper (and potentially, .values).
References
Davison, A., &amp; Hinkley, D. (1997). Bootstrap Methods and their Application. Cambridge: Cambridge University Press. doi:10.1017/CBO9780511802843
See Also
Examples
if (rlang::is_installed("modeldata")) {
  data(Sacramento, package = "modeldata")
  library(rsample)
  library(parsnip)

  set.seed(13)
  sac_rs <- vfold_cv(Sacramento)

  lm_res <-
    linear_reg() |>
    fit_resamples(
      log10(price) ~ beds + baths + sqft + type + latitude + longitude,
      resamples = sac_rs,
      control = control_resamples(save_pred = TRUE)
    )

  set.seed(31)
  int_pctl(lm_res)
}

Fit the final best model to the training set and evaluate the test set
Description
last_fit() emulates the process where, after determining the best model, the final fit on the entire training set is needed and is then evaluated on the test set.
Usage
last_fit(object, ...)

## S3 method for class 'model_spec'
last_fit(
  object,
  preprocessor,
  split,
  ...,
  metrics = NULL,
  eval_time = NULL,
  control = control_last_fit(),
  add_validation_set = FALSE
)

## S3 method for class 'workflow'
last_fit(
  object,
  split,
  ...,
  metrics = NULL,
  eval_time = NULL,
  control = control_last_fit(),
  add_validation_set = FALSE
)

Arguments
object | A parsnip model specification or an unfitted workflows::workflow(). No tuning parameters are allowed. |
... | Currently unused. |
preprocessor | A traditional model formula or a recipe created using recipes::recipe(). |
split | An rsplit object created from rsample::initial_split() or rsample::initial_validation_split(). |
metrics | A yardstick::metric_set(), or NULL to compute a standard set of metrics. |
eval_time | A numeric vector of time points where dynamic event time metrics should be computed (e.g. the time-dependent ROC curve, etc). The values must be non-negative and should probably be no greater than the largest event time in the training set (see Details below). |
control | A control_last_fit() object used to fine tune the last fit process. |
add_validation_set | For 3-way splits into training, validation, and test set via rsample::initial_validation_split(), should the validation set be included in the data set used to train the model? If not (the default), only the training set is used. |
Details
This function is intended to be used after fitting a variety of models and the final tuning parameters (if any) have been finalized. The next step would be to fit using the entire training set and verify performance using the test data.
Value
A single row tibble that emulates the structure of fit_resamples(). However, a list column called .workflow is also attached with the fitted model (and recipe, if any) that used the training set. Helper functions for formatting tuning results like collect_metrics() and collect_predictions() can be used with last_fit() output.
Case Weights
Some models can utilize case weights during training. tidymodels currently supports two types of case weights: importance weights (doubles) and frequency weights (integers). Frequency weights are used during model fitting and evaluation, whereas importance weights are only used during fitting.
To know if your model is capable of using case weights, create a model spec and test it using parsnip::case_weights_allowed().
To use them, you will need a numeric column in your data set that has been passed through either hardhat::importance_weights() or hardhat::frequency_weights().
For functions such as fit_resamples() and the tune_*() functions, the model must be contained inside of a workflows::workflow(). To declare that case weights are used, invoke workflows::add_case_weights() with the corresponding (unquoted) column name.
From there, the packages will appropriately handle the weights during model fitting and (if appropriate) performance estimation.
Censored Regression Models
Three types of metrics can be used to assess the quality of censoredregression models:
static: the prediction is independent of time.
dynamic: the prediction is a time-specific probability (e.g., survival probability) and is measured at one or more particular times.
integrated: same as the dynamic metric but returns the integral of the different metrics from each time point.
Which metrics are chosen by the user affects how many evaluation times should be specified. For example:
# Needs no `eval_time` value
metric_set(concordance_survival)

# Needs at least one `eval_time`
metric_set(brier_survival)
metric_set(brier_survival, concordance_survival)

# Needs at least two `eval_time` values
metric_set(brier_survival_integrated, concordance_survival)
metric_set(brier_survival_integrated, concordance_survival, brier_survival)
Values of eval_time should be less than the largest observed event time in the training data. For many non-parametric models, the results beyond the largest time corresponding to an event are constant (or NA).
See Also
last_fit() is closely related to fit_best(). They both give you access to a workflow fitted on the training data but are situated somewhat differently in the modeling workflow. fit_best() picks up after a tuning function like tune_grid() to take you from tuning results to fitted workflow, ready for you to predict and assess further. last_fit() assumes you have made your choice of hyperparameters and finalized your workflow to then take you from finalized workflow to fitted workflow and further to performance assessment on the test data. While fit_best() gives a fitted workflow, last_fit() gives you the performance results. If you want the fitted workflow, you can extract it from the result of last_fit() via extract_workflow().
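As a small sketch, using the spline_res object created in the example below:

# pull the fitted workflow out of the `last_fit()` results and predict with it
fitted_wf <- extract_workflow(spline_res)
predict(fitted_wf, mtcars[1:3, ])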
Examples
library(recipes)
library(rsample)
library(parsnip)

set.seed(6735)
tr_te_split <- initial_split(mtcars)

spline_rec <- recipe(mpg ~ ., data = mtcars) |>
  step_spline_natural(disp)

lin_mod <- linear_reg() |>
  set_engine("lm")

spline_res <- last_fit(lin_mod, spline_rec, split = tr_te_split)
spline_res

# test set metrics
collect_metrics(spline_res)

# test set predictions
collect_predictions(spline_res)

# or use a workflow
library(workflows)
spline_wfl <- workflow() |>
  add_recipe(spline_rec) |>
  add_model(lin_mod)

last_fit(spline_wfl, split = tr_te_split)
Quietly load package namespace
Description
For one or more packages, load the namespace. This is used during parallel processing since the different parallel backends handle the package environments differently.
Usage
load_pkgs(x, ..., infra = TRUE)
Arguments
x | A character vector of packages. |
infra | Should base tidymodels packages be loaded as well? |
Value
An invisible NULL.
Merge parameter grid values into objects
Description
merge() can be used to easily update any of the arguments in a parsnip model or recipe.
Usage
## S3 method for class 'recipe'
merge(x, y, ...)

## S3 method for class 'model_spec'
merge(x, y, ...)
Arguments
x | A recipe or model specification object. |
y | A data frame or a parameter grid resulting from one of the |
... | Not used but required for S3 completeness. |
Value
A tibble with a column x that has as many rows as were in y.
Examples
library(tibble)
library(recipes)
library(parsnip)
library(dials)

pca_rec <-
  recipe(mpg ~ ., data = mtcars) |>
  step_impute_knn(all_predictors(), neighbors = tune()) |>
  step_pca(all_predictors(), num_comp = tune())

pca_grid <-
  tribble(
    ~neighbors, ~num_comp,
             1,         1,
             5,         1,
             1,         2,
             5,         2
  )

merge(pca_rec, pca_grid)

spline_rec <-
  recipe(mpg ~ ., data = mtcars) |>
  step_spline_natural(disp, deg_free = tune("disp df")) |>
  step_spline_natural(wt, deg_free = tune("wt df"))

spline_grid <-
  tribble(
    ~"disp df", ~"wt df",
             3,        3,
             5,        3,
             3,        5,
             5,        5
  )

merge(spline_rec, spline_grid)

data(hpc_data, package = "modeldata")

xgb_mod <-
  boost_tree(trees = tune(), min_n = tune()) |>
  set_engine("xgboost")

set.seed(254)
xgb_grid <-
  extract_parameter_set_dials(xgb_mod) |>
  finalize(hpc_data) |>
  grid_max_entropy(size = 3)

merge(xgb_mod, xgb_grid)
Write a message that respects the line width
Description
Write a message that respects the line width
Usage
message_wrap(
  x,
  width = options()$width - 2,
  prefix = "",
  color_text = NULL,
  color_prefix = color_text
)
Arguments
x | A character string of the message text. |
width | An integer for the width. |
prefix | An optional string to go on the first line of the message. |
color_text,color_prefix | A function (or |
Value
The processed text is returned (invisibly) but a message is written.
Examples
library(cli)

Gaiman <-
  paste(
    '"Good point." Bod was pleased with himself, and glad he had thought of',
    "asking the poet for advice. Really, he thought, if you couldn't trust a",
    "poet to offer sensible advice, who could you trust?",
    collapse = ""
  )

message_wrap(Gaiman)
message_wrap(Gaiman, width = 20, prefix = "-")
message_wrap(Gaiman, width = 30, prefix = "-", color_text = cli::col_silver)
message_wrap(
  Gaiman,
  width = 30,
  prefix = "-",
  color_text = cli::style_underline,
  color_prefix = cli::col_green
)
Determine the minimum set of model fits
Description
min_grid() determines exactly what models should be fit in order to evaluate the entire set of tuning parameter combinations. This is for internal use only and the API may change in the near future.
Usage
## S3 method for class 'model_spec'
min_grid(x, grid, ...)

fit_max_value(x, grid, ...)

## S3 method for class 'boost_tree'
min_grid(x, grid, ...)

## S3 method for class 'linear_reg'
min_grid(x, grid, ...)

## S3 method for class 'logistic_reg'
min_grid(x, grid, ...)

## S3 method for class 'mars'
min_grid(x, grid, ...)

## S3 method for class 'multinom_reg'
min_grid(x, grid, ...)

## S3 method for class 'proportional_hazards'
min_grid(x, grid, ...)

## S3 method for class 'nearest_neighbor'
min_grid(x, grid, ...)

## S3 method for class 'cubist_rules'
min_grid(x, grid, ...)

## S3 method for class 'C5_rules'
min_grid(x, grid, ...)

## S3 method for class 'rule_fit'
min_grid(x, grid, ...)

## S3 method for class 'pls'
min_grid(x, grid, ...)

## S3 method for class 'poisson_reg'
min_grid(x, grid, ...)
Arguments
x | A model specification. |
grid | A tibble with tuning parameter combinations. |
... | Not currently used. |
Details
fit_max_value() can be used in other packages to implement a min_grid() method.
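For example, a sketch of how an extension package might define a method for a hypothetical model class (my_model_spec is a placeholder) that has no submodel trick:

min_grid.my_model_spec <- function(x, grid, ...) {
  # fit every row of the grid; no submodel shortcut is available
  tune::fit_max_value(x, grid, ...)
}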
Value
A tibble with the minimum tuning parameters to fit and an additional list column with the parameter combinations used for prediction.
Examples
library(dplyr)
library(dials)
library(parsnip)

## -----------------------------------------------------------------------------
## No ability to exploit submodels:

svm_spec <-
  svm_poly(cost = tune(), degree = tune()) |>
  set_engine("kernlab") |>
  set_mode("regression")

svm_grid <-
  svm_spec |>
  extract_parameter_set_dials() |>
  grid_regular(levels = 3)

min_grid(svm_spec, svm_grid)

## -----------------------------------------------------------------------------
## Can use submodels

xgb_spec <-
  boost_tree(trees = tune(), min_n = tune()) |>
  set_engine("xgboost") |>
  set_mode("regression")

xgb_grid <-
  xgb_spec |>
  extract_parameter_set_dials() |>
  grid_regular(levels = 3)

min_grid(xgb_spec, xgb_grid)
Determine names of the outcome data in a workflow
Description
Determine names of the outcome data in a workflow
Usage
outcome_names(x, ...)

## S3 method for class 'terms'
outcome_names(x, ...)

## S3 method for class 'formula'
outcome_names(x, ...)

## S3 method for class 'recipe'
outcome_names(x, ...)

## S3 method for class 'workflow'
outcome_names(x, ...)

## S3 method for class 'tune_results'
outcome_names(x, ...)

## S3 method for class 'workflow_variables'
outcome_names(x, data = NULL, ...)
Arguments
x | An object. |
... | Not used. |
Value
A character string of variable names
Examples
library(dplyr)

lm(cbind(mpg, wt) ~ ., data = mtcars) |>
  purrr::pluck(terms) |>
  outcome_names()
Support for parallel processing in tune
Description
tune can enable simultaneous parallel computations. Tierney (2008) defined different classes of parallel processing techniques:
Implicit parallelization is when a function uses low-level tools to perform a calculation that is small in scope in parallel. Examples are using multithreaded linear algebra libraries (e.g., BLAS) or basic R vectorization functions.
Explicit parallelization occurs when the user requests that some calculations should be run by generating multiple new R (sub)processes. These calculations can be more complex than those for implicit parallel processing.
For example, some decision tree libraries can implicitly parallelize their search for the optimal splitting routine using multiple threads.
Alternatively, if you are resampling a model B times, you can explicitly create B new R jobs to train B boosted trees in parallel and return their resampling results to the main R process (e.g., fit_resamples()).
There are two frameworks that can be used to explicitly parallelize your work in tune: the future package and the mirai package. Previously, you could use the foreach package, but it has been deprecated as of version 1.2.1 of tune.
By default, no parallelism is used to process models in tune; you have to opt in.
Using future
You should install the future package and choose your flavor of parallelism using the plan() function. This allows you to specify the number of worker processes and the specific technology to use.
For example, you can use:
library(future)
plan(multisession, workers = 4)
and work will be conducted simultaneously (unless there is an exception; see the section below).
If you had previously used foreach, this would replace your existing code that probably looked like:
library(doBackend)
registerDoBackend(cores = 4)
See future::plan() for possible options other than multisession.
Note that tune resets the maximum memory limit for global variables (e.g., attached packages) to be greater than the default when the package is loaded. This value can be altered using options(future.globals.maxSize).
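For example, to raise the limit yourself (the option value is in bytes; roughly 2 GB here):

options(future.globals.maxSize = 2 * 1024^3)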
If you want future to use mirai parallel workers, you can install and load the future.mirai package.
Using mirai
To set the specifications for parallel processing with mirai, use the mirai::daemons() function. The first argument, n, determines the number of parallel workers. Using daemons(0) reverts to sequential processing.
The arguments url and remote are used to set up and launch parallel processes over the network for distributed computing. See the mirai::daemons() documentation for more details.
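A minimal local-machine sketch:

library(mirai)

daemons(4)   # start four local parallel workers
# ... run tune_grid(), tune_bayes(), or fit_resamples() here ...
daemons(0)   # shut down the workers and revert to sequential processing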
Reverting to sequential processing
There are a few times when you might specify that you wish to use parallel processing, but it will revert to sequential execution:
Many of the control functions (e.g., control_grid()) have an argument called allow_par. If this is set to FALSE, parallel backends will always be ignored (a small opt-out sketch follows this list).
Some packages, such as rJava and keras, are not compatible with explicit parallelization. If any of these packages are used, sequential processing occurs.
If you specify fewer than two workers, or if there is only a single task, the computations will occur sequentially.
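As a minimal sketch of the first point above (wflow and folds are placeholders for a real workflow and resamples):

ctrl <- control_grid(allow_par = FALSE)
# tune_grid(wflow, resamples = folds, grid = 10, control = ctrl)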
Expectations for reproducibility
We advise that you always run set.seed() with a seed value just prior to using a function that uses (or might use) random numbers (a minimal pattern is sketched at the end of this section). Given this:
You should expect to get the same results if you run that section of code repeatedly, conditional on using version 1.4.0 of tune.
You should expect differences in results between version 1.4.0 of tune and previous versions.
When using last_fit(), you should be able to get the same results as manually using generics::fit() and predict() to do the same work.
When running with or without parallel processing (using any backend package), you should be able to achieve the same results from fit_resamples() and the various tuning functions.
Specific exceptions:
For SVM classification models using the kernlab package, the random number generator is independent of R, and there is no argument to control it. Unfortunately, it is likely to give you different results from run to run.
For some deep learning packages (e.g., tensorflow, keras, and torch), it is very difficult to achieve reproducible results. This is especially true when using GPUs for computations. Additionally, we have seen differences in computations (stochastic or non-random) between platforms due to the packages' use of different numerical tolerance constants across operating systems.
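A minimal pattern consistent with the advice above (a sketch only; wflow and folds stand in for a real workflow and resamples):

set.seed(123)
res <- fit_resamples(wflow, resamples = folds)
# re-running these two lines should reproduce `res`, with or without a
# parallel backend registered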
Handling package dependencies
tune knows what packages are required to fit a workflow object.
When computations are run sequentially, an initial check is made to see if the required packages are installed. This triggers the packages to be loaded, but they are not attached to the search path.
In parallel processing, the required packages are fully loaded (i.e., loaded and attached to the search path) in the worker processes (but not the main R session), as they previously were with foreach.
References
https://www.tmwr.org/grid-search#parallel-processing
Tierney, Luke. "Implicit and explicit parallel computing in R." COMPSTAT2008: Proceedings in Computational Statistics. Physica-Verlag HD, 2008.
Acquisition function for scoring parameter combinations
Description
These functions can be used to score candidate tuning parameter combinations as a function of their predicted mean and variation.
Usage
prob_improve(trade_off = 0, eps = .Machine$double.eps)

exp_improve(trade_off = 0, eps = .Machine$double.eps)

conf_bound(kappa = 0.1)
Arguments
trade_off | A number or function that describes the trade-off between exploitation and exploration. Smaller values favor exploitation. |
eps | A small constant to avoid division by zero. |
kappa | A positive number (or function) that corresponds to the multiplier of the standard deviation in a confidence bound (e.g. 1.96 in normal-theory 95 percent confidence intervals). Smaller values lean more towards exploitation. |
Details
The acquisition functions often combine the mean and variance predictions from the Gaussian process model into an objective to be optimized.
For this documentation, we assume that the metric in question is better when maximized (e.g. accuracy, the coefficient of determination, etc).
The expected improvement of a point x is based on the predicted mean and variation at that point as well as the current best value (denoted here as x_b). The vignette linked below contains the formulas for this acquisition function. When the trade_off parameter is greater than zero, the acquisition function will down-play the effect of the mean prediction and give more weight to the variation. This has the effect of searching for new parameter combinations that are in areas that have yet to be sampled.
Note that for exp_improve() and prob_improve(), the trade_off value is in the units of the outcome. The functions are parameterized so that the trade_off value should always be non-negative.
The confidence bound function does not take into account the current best results in the data.
If a function is passed to exp_improve() or prob_improve(), the function can have multiple arguments but only the first (the current iteration number) is given to the function. In other words, the function argument should have defaults for all but the first argument. See expo_decay() as an example of a function.
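For instance, a decaying trade-off can be built by wrapping expo_decay() so that only the iteration number is required; this is a sketch that assumes expo_decay()'s start_val, limit_val, and slope arguments:

decaying_trade_off <- function(iter) {
  # only `iter` is supplied during the search; the other values are fixed here
  expo_decay(iter, start_val = 0.1, limit_val = 0, slope = 1 / 4)
}

exp_improve(trade_off = decaying_trade_off)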
Value
An object of class prob_improve, exp_improve, or conf_bounds along with an extra class of acquisition_function.
See Also
Examples
prob_improve()
Objects exported from other packages
Description
These objects are imported from other packages. Follow the links below to see their documentation.
- dials
- dplyr
- generics
- ggplot2
- hardhat
extract_fit_engine, extract_fit_parsnip, extract_mold, extract_parameter_set_dials, extract_preprocessor, extract_recipe, extract_spec_parsnip, extract_workflow, tune
- rsample
Schedule a grid
Description
Schedule a grid
Usage
schedule_grid(grid, wflow)
Arguments
grid | A tibble containing the parameter grid. |
wflow | The workflow object for which we schedule the grid. |
Value
A schedule object, inheriting from one of 'single_schedule', 'grid_schedule', or 'resample_schedule'.
Investigate best tuning parameters
Description
show_best() displays the top sub-models and their performance estimates.
select_best() finds the tuning parameter combination with the best performance values.
select_by_one_std_err() uses the "one-standard error rule" (Breiman et al., 1984) that selects the simplest model that is within one standard error of the numerically optimal results.
select_by_pct_loss() selects the simplest model whose loss of performance is within some acceptable limit.
Usage
show_best(x, ...)

## Default S3 method:
show_best(x, ...)

## S3 method for class 'tune_results'
show_best(
  x,
  ...,
  metric = NULL,
  eval_time = NULL,
  n = 5,
  call = rlang::current_env()
)

select_best(x, ...)

## Default S3 method:
select_best(x, ...)

## S3 method for class 'tune_results'
select_best(x, ..., metric = NULL, eval_time = NULL)

select_by_pct_loss(x, ...)

## Default S3 method:
select_by_pct_loss(x, ...)

## S3 method for class 'tune_results'
select_by_pct_loss(x, ..., metric = NULL, eval_time = NULL, limit = 2)

select_by_one_std_err(x, ...)

## Default S3 method:
select_by_one_std_err(x, ...)

## S3 method for class 'tune_results'
select_by_one_std_err(x, ..., metric = NULL, eval_time = NULL)
Arguments
x | The results of |
... | For |
metric | A character value for the metric that will be used to sort the models. (See https://yardstick.tidymodels.org/articles/metric-types.html for more details). Not required if a single metric exists in |
eval_time | A single numeric time point where dynamic event time metrics should be chosen (e.g., the time-dependent ROC curve, etc). The values should be consistent with the values used to create |
n | An integer for the number of top results/rows to return. |
call | The call to be shown in errors and warnings. |
limit | The limit of loss of performance that is acceptable (in percent units). See details below. |
Details
For percent loss, suppose the best model has an RMSE of 0.75 and a simpler model has an RMSE of 1. The percent loss would be (1.00 - 0.75)/1.00 * 100, or 25 percent. Note that loss will always be non-negative.
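The same computation in R:

# percent loss of the simpler model relative to the best RMSE
(1.00 - 0.75) / 1.00 * 100  # 25 percent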
Value
A tibble with columns for the parameters. show_best() also includes columns for performance metrics.
References
Breiman, Leo; Friedman, J. H.; Olshen, R. A.; Stone, C. J. (1984). Classification and Regression Trees. Monterey, CA: Wadsworth.
Examples
data("example_ames_knn")

show_best(ames_iter_search, metric = "rmse")

select_best(ames_iter_search, metric = "rsq")

# To find the least complex model within one std error of the numerically
# optimal model, the number of nearest neighbors are sorted from the largest
# number of neighbors (the least complex class boundary) to the smallest
# (corresponding to the most complex model).
select_by_one_std_err(ames_grid_search, metric = "rmse", desc(K))

# Now find the least complex model that has no more than a 5% loss of RMSE:
select_by_pct_loss(
  ames_grid_search,
  metric = "rmse",
  limit = 5,
  desc(K)
)
Display distinct errors from tune objects
Description
Display distinct errors from tune objects
Usage
show_notes(x, n = 10)
Arguments
x | An object of class |
n | An integer for how many unique notes to show. |
Value
Invisibly, x. The function is called for its side effects and printing.
Bayesian optimization of model parameters.
Description
tune_bayes() uses models to generate new candidate tuning parameter combinations based on previous results.
Usage
tune_bayes(object, ...)

## S3 method for class 'model_spec'
tune_bayes(
  object,
  preprocessor,
  resamples,
  ...,
  iter = 10,
  param_info = NULL,
  metrics = NULL,
  eval_time = NULL,
  objective = exp_improve(),
  initial = 5,
  control = control_bayes()
)

## S3 method for class 'workflow'
tune_bayes(
  object,
  resamples,
  ...,
  iter = 10,
  param_info = NULL,
  metrics = NULL,
  eval_time = NULL,
  objective = exp_improve(),
  initial = 5,
  control = control_bayes()
)
Arguments
object | A |
... | Options to pass to |
preprocessor | A traditional model formula or a recipe created using |
resamples | An |
iter | The maximum number of search iterations. |
param_info | A |
metrics | A |
eval_time | A numeric vector of time points where dynamic event time metrics should be computed (e.g. the time-dependent ROC curve, etc). The values must be non-negative and should probably be no greater than the largest event time in the training set (See Details below). |
objective | A character string for what metric should be optimized oran acquisition function object. |
initial | An initial set of results in a tidy format (as would result from |
control | A control object created by |
Details
The optimization starts with a set of initial results, such as those generated by tune_grid(). If none exist, the function will create several combinations and obtain their performance estimates.
Using one of the performance estimates as the model outcome, a Gaussian process (GP) model is created where the previous tuning parameter combinations are used as the predictors.
A large grid of potential hyperparameter combinations is predicted using the model and scored using an acquisition function. These functions usually combine the predicted mean and variance of the GP to decide the best parameter combination to try next. For more information, see the documentation for exp_improve() and the corresponding package vignette.
The best combination is evaluated using resampling and the process continues.
Value
A tibble of results that mirror those generated by tune_grid(). However, these results contain an .iter column and replicate the rset object multiple times over iterations (at limited additional memory costs).
Parallel Processing
tune supports parallel processing with the future package. To execute the resampling iterations in parallel, specify a plan with future first. The allow_par argument can be used to avoid parallelism.
For the most part, warnings generated during training are shown as they occur and are associated with a specific resample when control_bayes(verbose = TRUE). They are (usually) not aggregated until the end of processing.
For Bayesian optimization, parallel processing is used to estimate the resampled performance values once a new candidate set of values is estimated.
Initial Values
The results of tune_grid(), or a previous run of tune_bayes(), can be used in the initial argument. initial can also be a positive integer. In this case, a space-filling design will be used to populate a preliminary set of results. For good results, the number of initial values should be more than the number of parameters being optimized.
The tuning parameter combinations that were tested are called candidates. Each candidate has a unique .config value that, for the initial grid search, has the pattern pre{num}_mod{num}_post{num}. The numbers include a zero when that element was static. For example, a value of pre0_mod3_post4 means no preprocessors were tuned and the model and postprocessor(s) had at least three and four candidates, respectively. The iterative part of the search uses the pattern iter{num}. In each case, the numbers are zero-padded to enable proper sorting.
Parameter Ranges and Values
In some cases, the tuning parameter values depend on the dimensions of the data (they are said to contain unknown values). For example, mtry in random forest models depends on the number of predictors. In such cases, the unknowns in the tuning parameter object must be determined beforehand and passed to the function via the param_info argument. dials::finalize() can be used to derive the data-dependent parameters. Otherwise, a parameter set can be created via dials::parameters(), and the dials update() function can be used to specify the ranges or values.
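As a sketch (assuming the ranger engine is installed), the data-dependent mtry range can be finalized from the predictors and then supplied via param_info:

library(parsnip)
library(dials)

rf_spec <-
  rand_forest(mtry = tune(), min_n = tune()) |>
  set_engine("ranger") |>
  set_mode("regression")

rf_params <-
  rf_spec |>
  extract_parameter_set_dials() |>
  finalize(mtcars[, -1])  # use the predictor columns to set the upper bound of `mtry`

# rf_params can now be passed to `tune_bayes()` via `param_info`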
Performance Metrics
To use your own performance metrics, the yardstick::metric_set() function can be used to pick what should be measured for each model. If multiple metrics are desired, they can be bundled. For example, to estimate the area under the ROC curve as well as the sensitivity and specificity (under the typical probability cutoff of 0.50), the metrics argument could be given:
metrics = metric_set(roc_auc, sens, spec)
Each metric is calculated for each candidate model.
If no metric set is provided, one is created:
For regression models, the root mean squared error and coefficient of determination are computed.
For classification, the area under the ROC curve and overall accuracy are computed (a rough equivalent is sketched below).
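These defaults roughly correspond to the following metric sets (a sketch, not the internal code):

library(yardstick)

metric_set(rmse, rsq)          # regression default
metric_set(roc_auc, accuracy)  # classification default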
Note that the metrics also determine what type of predictions are estimated during tuning. For example, in a classification problem, if metrics are used that are all associated with hard class predictions, the classification probabilities are not created.
The out-of-sample estimates of these metrics are contained in a list column called .metrics. This tibble contains a row for each metric and columns for the value, the estimator type, and so on.
collect_metrics() can be used for these objects to collapse the results over the resamples (to obtain the final resampling estimates per tuning parameter combination).
Obtaining Predictions
When control_bayes(save_pred = TRUE), the output tibble contains a list column called .predictions that has the out-of-sample predictions for each parameter combination in the grid and each fold (which can be very large).
The elements of the tibble are tibbles with columns for the tuning parameters, the row number from the original data object (.row), the outcome data (with the same name(s) of the original data), and any columns created by the predictions. For example, for simple regression problems, this function generates a column called .pred and so on. As noted above, the prediction columns that are returned are determined by the type of metric(s) requested.
This list column can be unnested using tidyr::unnest() or using the convenience function collect_predictions().
Case Weights
Some models can utilize case weights during training. tidymodels currently supports two types of case weights: importance weights (doubles) and frequency weights (integers). Frequency weights are used during model fitting and evaluation, whereas importance weights are only used during fitting.
To know if your model is capable of using case weights, create a model spec and test it using parsnip::case_weights_allowed().
To use them, you will need a numeric column in your data set that has been passed through either hardhat::importance_weights() or hardhat::frequency_weights().
For functions such as fit_resamples() and the tune_*() functions, the model must be contained inside of a workflows::workflow(). To declare that case weights are used, invoke workflows::add_case_weights() with the corresponding (unquoted) column name.
From there, the packages will appropriately handle the weights during model fitting and (if appropriate) performance estimation.
Censored Regression Models
Three types of metrics can be used to assess the quality of censored regression models:
static: the prediction is independent of time.
dynamic: the prediction is a time-specific probability (e.g., survival probability) and is measured at one or more particular times.
integrated: same as the dynamic metric but returns the integral of the different metrics from each time point.
Which metrics are chosen by the user affects how many evaluation times should be specified. For example:
# Needs no `eval_time` value
metric_set(concordance_survival)

# Needs at least one `eval_time`
metric_set(brier_survival)
metric_set(brier_survival, concordance_survival)

# Needs at least two `eval_time` values
metric_set(brier_survival_integrated, concordance_survival)
metric_set(brier_survival_integrated, concordance_survival, brier_survival)
Values of eval_time should be less than the largest observed event time in the training data. For many non-parametric models, the results beyond the largest time corresponding to an event are constant (or NA).
Optimizing Censored Regression Models
With dynamic performance metrics (e.g. Brier or ROC curves), performance is calculated for every value of eval_time but the first evaluation time given by the user (e.g., eval_time[1]) is used to guide the optimization.
Extracting Information
The extract control option will result in an additional column to be returned called .extracts. This is a list column that has tibbles containing the results of the user's function for each tuning parameter combination. This can enable returning each model and/or recipe object that is created during resampling. Note that this could result in a large return object, depending on what is returned.
The control function contains an option (extract) that can be used to retain any model or recipe that was created within the resamples. This argument should be a function with a single argument. The value of the argument that is given to the function in each resample is a workflow object (see workflows::workflow() for more information). Several helper functions can be used to easily pull out the preprocessing and/or model information from the workflow, such as extract_preprocessor() and extract_fit_parsnip().
As an example, if there is interest in getting each parsnip model fit back,one could use:
extract = function (x) extract_fit_parsnip(x)
Note that the function given to the extract argument is evaluated on every model that is fit (as opposed to every model that is evaluated). As noted above, in some cases, model predictions can be derived for sub-models so that, in these cases, not every row in the tuning parameter grid has a separate R object associated with it.
Finally, it is a good idea to include calls to require() for packages that are used in the function. This helps prevent failures when using parallel processing.
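For example, a hedged sketch of an extractor that loads its own dependency and keeps only the trained recipe:

ctrl <- control_bayes(
  extract = function(x) {
    require(recipes)          # load what the extractor needs on the workers
    extract_preprocessor(x)   # keep only the (trained) recipe
  }
)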
See Also
control_bayes(), tune(), autoplot.tune_results(), show_best(), select_best(), collect_predictions(), collect_metrics(), prob_improve(), exp_improve(), conf_bound(), fit_resamples()
Examples
library(recipes)
library(rsample)
library(parsnip)

# define resamples and minimal recipe on mtcars
set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

car_rec <-
  recipe(mpg ~ ., data = mtcars) |>
  step_normalize(all_predictors())

# define an svm with parameters to tune
svm_mod <-
  svm_rbf(cost = tune(), rbf_sigma = tune()) |>
  set_engine("kernlab") |>
  set_mode("regression")

# use a space-filling design with 6 points
set.seed(3254)
svm_grid <- tune_grid(svm_mod, car_rec, folds, grid = 6)

show_best(svm_grid, metric = "rmse")

# use bayesian optimization to evaluate at 6 more points
set.seed(8241)
svm_bayes <- tune_bayes(svm_mod, car_rec, folds, initial = svm_grid, iter = 6)

# note that bayesian optimization evaluated parameterizations
# similar to those that previously decreased rmse in svm_grid
show_best(svm_bayes, metric = "rmse")

# specifying `initial` as a numeric rather than previous tuning results
# will result in `tune_bayes` initially evaluating a space-filling
# grid using `tune_grid` with `grid = initial`
set.seed(0239)
svm_init <- tune_bayes(svm_mod, car_rec, folds, initial = 6, iter = 6)

show_best(svm_init, metric = "rmse")
Model tuning via grid search
Description
tune_grid() computes a set of performance metrics (e.g. accuracy or RMSE) for a pre-defined set of tuning parameters that correspond to a model or recipe across one or more resamples of the data.
Usage
tune_grid(object, ...)

## S3 method for class 'model_spec'
tune_grid(
  object,
  preprocessor,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  eval_time = NULL,
  control = control_grid()
)

## S3 method for class 'workflow'
tune_grid(
  object,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  eval_time = NULL,
  control = control_grid()
)
Arguments
object | A |
... | Not currently used. |
preprocessor | A traditional model formula or a recipe created using |
resamples | An |
param_info | A |
grid | A data frame of tuning combinations or a positive integer. The data frame should have columns for each parameter being tuned and rows for tuning parameter candidates. An integer denotes the number of candidate parameter sets to be created automatically. |
metrics | A |
eval_time | A numeric vector of time points where dynamic event time metrics should be computed (e.g. the time-dependent ROC curve, etc). The values must be non-negative and should probably be no greater than the largest event time in the training set (See Details below). |
control | An object used to modify the tuning process, likely created by |
Details
Suppose there are m tuning parameter combinations. tune_grid() may not require all m model/recipe fits across each resample. For example:
In cases where a single model fit can be used to make predictions for different parameter values in the grid, only one fit is used. For example, for some boosted trees, if 100 iterations of boosting are requested, the model object for 100 iterations can be used to make predictions on iterations less than 100 (if all other parameters are equal).
When the model is being tuned in conjunction with pre-processing and/or post-processing parameters, the minimum number of fits is used. For example, if the number of PCA components in a recipe step is being tuned over three values (along with model tuning parameters), only three recipes are trained. The alternative would be to re-train the same recipe multiple times for each model tuning parameter.
tune supports parallel processing with the future package. To execute the resampling iterations in parallel, specify a plan with future first. The allow_par argument can be used to avoid parallelism.
For the most part, warnings generated during training are shown as they occur and are associated with a specific resample when control_grid(verbose = TRUE). They are (usually) not aggregated until the end of processing.
Value
An updated version of resamples with extra list columns for .metrics and .notes (optional columns are .predictions and .extracts). .notes contains warnings and errors that occur during execution.
Parameter Grids
If no tuning grid is provided, a grid (via dials::grid_space_filling()) is created with 10 candidate parameter combinations.
When provided, the grid should have column names for each parameter and these should be named by the parameter name or id. For example, if a parameter is marked for optimization using penalty = tune(), there should be a column named penalty. If the optional identifier is used, such as penalty = tune(id = 'lambda'), then the corresponding column name should be lambda.
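For instance, a minimal sketch of a manual grid whose column name matches the id above:

library(tibble)

# for a model specified with `penalty = tune(id = 'lambda')`
manual_grid <- tibble(lambda = 10^seq(-3, 0, length.out = 5))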
In some cases, the tuning parameter values depend on the dimensions of the data. For example, mtry in random forest models depends on the number of predictors. In this case, the default tuning parameter object requires an upper range. dials::finalize() can be used to derive the data-dependent parameters. Otherwise, a parameter set can be created (via dials::parameters()) and the dials update() function can be used to change the values. This updated parameter set can be passed to the function via the param_info argument.
The rows of the grid are called tuning parameter candidates. Each candidate has a unique .config value that, for grid search, has the pattern pre{num}_mod{num}_post{num}. The numbers include a zero when that element was static. For example, a value of pre0_mod3_post4 means no preprocessors were tuned and the model and postprocessor(s) had at least three and four candidates, respectively. Also, the numbers are zero-padded to enable proper sorting.
Performance Metrics
To use your own performance metrics, the yardstick::metric_set() function can be used to pick what should be measured for each model. If multiple metrics are desired, they can be bundled. For example, to estimate the area under the ROC curve as well as the sensitivity and specificity (under the typical probability cutoff of 0.50), the metrics argument could be given:
metrics = metric_set(roc_auc, sens, spec)
Each metric is calculated for each candidate model.
If no metric set is provided, one is created:
For regression models, the root mean squared error and coefficient of determination are computed.
For classification, the area under the ROC curve and overall accuracy are computed.
Note that the metrics also determine what type of predictions are estimated during tuning. For example, in a classification problem, if metrics are used that are all associated with hard class predictions, the classification probabilities are not created.
The out-of-sample estimates of these metrics are contained in a list column called .metrics. This tibble contains a row for each metric and columns for the value, the estimator type, and so on.
collect_metrics() can be used for these objects to collapse the results over the resamples (to obtain the final resampling estimates per tuning parameter combination).
Obtaining Predictions
When control_grid(save_pred = TRUE), the output tibble contains a list column called .predictions that has the out-of-sample predictions for each parameter combination in the grid and each fold (which can be very large).
The elements of the tibble are tibbles with columns for the tuning parameters, the row number from the original data object (.row), the outcome data (with the same name(s) of the original data), and any columns created by the predictions. For example, for simple regression problems, this function generates a column called .pred and so on. As noted above, the prediction columns that are returned are determined by the type of metric(s) requested.
This list column can be unnested using tidyr::unnest() or using the convenience function collect_predictions().
Extracting Information
The extract control option will result in an additional column to be returned called .extracts. This is a list column that has tibbles containing the results of the user's function for each tuning parameter combination. This can enable returning each model and/or recipe object that is created during resampling. Note that this could result in a large return object, depending on what is returned.
The control function contains an option (extract) that can be used to retain any model or recipe that was created within the resamples. This argument should be a function with a single argument. The value of the argument that is given to the function in each resample is a workflow object (see workflows::workflow() for more information). Several helper functions can be used to easily pull out the preprocessing and/or model information from the workflow, such as extract_preprocessor() and extract_fit_parsnip().
As an example, if there is interest in getting each parsnip model fit back,one could use:
extract = function (x) extract_fit_parsnip(x)
Note that the function given to the extract argument is evaluated on every model that is fit (as opposed to every model that is evaluated). As noted above, in some cases, model predictions can be derived for sub-models so that, in these cases, not every row in the tuning parameter grid has a separate R object associated with it.
Finally, it is a good idea to include calls to require() for packages that are used in the function. This helps prevent failures when using parallel processing.
Case Weights
Some models can utilize case weights during training. tidymodels currently supports two types of case weights: importance weights (doubles) and frequency weights (integers). Frequency weights are used during model fitting and evaluation, whereas importance weights are only used during fitting.
To know if your model is capable of using case weights, create a model spec and test it using parsnip::case_weights_allowed().
To use them, you will need a numeric column in your data set that has been passed through either hardhat::importance_weights() or hardhat::frequency_weights().
For functions such as fit_resamples() and the tune_*() functions, the model must be contained inside of a workflows::workflow(). To declare that case weights are used, invoke workflows::add_case_weights() with the corresponding (unquoted) column name.
From there, the packages will appropriately handle the weights during model fitting and (if appropriate) performance estimation.
Censored Regression Models
Three types of metrics can be used to assess the quality of censored regression models:
static: the prediction is independent of time.
dynamic: the prediction is a time-specific probability (e.g., survival probability) and is measured at one or more particular times.
integrated: same as the dynamic metric but returns the integral of the different metrics from each time point.
Which metrics are chosen by the user affects how many evaluation times should be specified. For example:
# Needs no `eval_time` value
metric_set(concordance_survival)

# Needs at least one `eval_time`
metric_set(brier_survival)
metric_set(brier_survival, concordance_survival)

# Needs at least two `eval_time` values
metric_set(brier_survival_integrated, concordance_survival)
metric_set(brier_survival_integrated, concordance_survival, brier_survival)
Values of eval_time should be less than the largest observed event time in the training data. For many non-parametric models, the results beyond the largest time corresponding to an event are constant (or NA).
See Also
control_grid(), tune(), fit_resamples(), autoplot.tune_results(), show_best(), select_best(), collect_predictions(), collect_metrics()
Examples
library(recipes)
library(rsample)
library(parsnip)
library(workflows)
library(ggplot2)

# ---------------------------------------------------------------------------

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

# ---------------------------------------------------------------------------
# tuning recipe parameters:

spline_rec <-
  recipe(mpg ~ ., data = mtcars) |>
  step_spline_natural(disp, deg_free = tune("disp")) |>
  step_spline_natural(wt, deg_free = tune("wt"))

lin_mod <- linear_reg() |>
  set_engine("lm")

# manually create a grid
spline_grid <- expand.grid(disp = 2:5, wt = 2:5)

# Warnings will occur from making spline terms on the holdout data that are
# extrapolations.
spline_res <-
  tune_grid(lin_mod, spline_rec, resamples = folds, grid = spline_grid)
spline_res

show_best(spline_res, metric = "rmse")

# ---------------------------------------------------------------------------
# tune model parameters only (example requires the `kernlab` package)

car_rec <-
  recipe(mpg ~ ., data = mtcars) |>
  step_normalize(all_predictors())

svm_mod <-
  svm_rbf(cost = tune(), rbf_sigma = tune()) |>
  set_engine("kernlab") |>
  set_mode("regression")

# Use a space-filling design with 7 points
set.seed(3254)
svm_res <- tune_grid(svm_mod, car_rec, resamples = folds, grid = 7)
svm_res

show_best(svm_res, metric = "rmse")

autoplot(svm_res, metric = "rmse") +
  scale_x_log10()

# ---------------------------------------------------------------------------
# Using a variables preprocessor with a workflow

# Rather than supplying a preprocessor (like a recipe) and a model directly
# to `tune_grid()`, you can also wrap them up in a workflow and pass
# that along instead (note that this doesn't do any preprocessing to
# the variables, it passes them along as-is).
wf <-
  workflow() |>
  add_variables(outcomes = mpg, predictors = everything()) |>
  add_model(svm_mod)

set.seed(3254)
svm_res_wf <- tune_grid(wf, resamples = folds, grid = 7)