| Title: | Fit 'TabNet' Models for Classification and Regression |
| Version: | 0.7.0 |
| Description: | Implements the 'TabNet' model by Sercan O. Arik et al. (2019) <doi:10.48550/arXiv.1908.07442> with 'Coherent Hierarchical Multi-label Classification Networks' by Giunchiglia et al. <doi:10.48550/arXiv.2010.10151> and provides a consistent interface for fitting and creating predictions. It's also fully compatible with the 'tidymodels' ecosystem. |
| License: | MIT + file LICENSE |
| URL: | https://mlverse.github.io/tabnet/, https://github.com/mlverse/tabnet |
| BugReports: | https://github.com/mlverse/tabnet/issues |
| Depends: | R (≥ 3.6) |
| Imports: | coro, data.tree, dials, dplyr, ggplot2, hardhat (≥ 1.3.0), magrittr, Matrix, methods, parsnip, progress, purrr, rlang, stats, stringr, tibble, tidyr, torch (≥ 0.4.0), tune, utils, vctrs, withr, zeallot |
| Suggests: | cli, knitr, modeldata, patchwork, recipes, rmarkdown, rsample, spelling, testthat (≥ 3.0.0), tidymodels, tidyverse, vip, visdat, workflows, yardstick |
| VignetteBuilder: | knitr |
| Config/testthat/edition: | 3 |
| Config/testthat/parallel: | false |
| Config/testthat/start-first: | interface, explain, params |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.2 |
| Language: | en-US |
| NeedsCompilation: | no |
| Packaged: | 2025-04-16 22:08:41 UTC; creg |
| Author: | Daniel Falbel [aut], RStudio [cph], Christophe Regouby [cre, ctb], Egill Fridgeirsson [ctb], Philipp Haarmeyer [ctb], Sven Verweij |
| Maintainer: | Christophe Regouby <christophe.regouby@free.fr> |
| Repository: | CRAN |
| Date/Publication: | 2025-04-17 00:10:02 UTC |
Pipe operator
Description
See magrittr::%>% for details.
Usage
lhs %>% rhs
Value
Returns rhs(lhs).
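A minimal illustration of the re-exported pipe, using a base dataset (any pair of functions composes the same way):

library(tabnet)
mtcars %>% subset(cyl == 4) %>% head()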
Parameters for the tabnet model
Description
Parameters for the tabnet model
Usage
attention_width(range = c(8L, 64L), trans = NULL)

decision_width(range = c(8L, 64L), trans = NULL)

feature_reusage(range = c(1, 2), trans = NULL)

momentum(range = c(0.01, 0.4), trans = NULL)

mask_type(values = c("sparsemax", "entmax"))

num_independent(range = c(1L, 5L), trans = NULL)

num_shared(range = c(1L, 5L), trans = NULL)

num_steps(range = c(3L, 10L), trans = NULL)
Arguments
range | the default range for the parameter value |
trans | an optional trans object from the scales package to transform the parameter values |
values | possible values for factor parameters. These functions are used with tune() to generate candidate values when tuning TabNet models. |
Value
A dials parameter to be used when tuning TabNet models.
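As a quick sketch, these parameter objects can also be inspected or adjusted on their own before tuning; value_seq() is a dials helper, and the narrowed range is illustrative:

library(tabnet)
library(dials)
# default tuning range of the decision width
decision_width()
# narrow the candidate range before building a tuning grid
decision_width(range = c(16L, 32L))
# draw candidate values from the parameter
value_seq(decision_width(), n = 4)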
Examples
model <- tabnet(attention_width = tune(), feature_reusage = tune(), momentum = tune(),
                penalty = tune(), rate_step_size = tune()) %>%
  parsnip::set_mode("regression") %>%
  parsnip::set_engine("torch")
Plot tabnet_explain mask importance heatmap
Description
Plot tabnet_explain mask importance heatmap
Usage
autoplot.tabnet_explain(
  object,
  type = c("mask_agg", "steps"),
  quantile = 1,
  ...
)
Arguments
object | A tabnet_explain object to be plotted |
type | a character value. Either "mask_agg" (the default), plotting the aggregated mask heatmap, or "steps", faceting the heatmap along the model steps |
quantile | numerical value between 0 and 1. Provides quantile clipping of the mask values |
... | not used. |
Details
Plot the tabnet_explain object mask importance per variable along the predicted dataset. type = "mask_agg" outputs a single heatmap of aggregated mask values; type = "steps" provides a plot faceted along the n_steps masks present in the model. quantile = .995 may be used for strong outlier clipping, in order to better highlight low values. quantile = 1, the default, does not clip any values.
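For instance, assuming the attrition_explain object built in the Examples below, the per-step masks can be plotted with strong outlier clipping:

autoplot(attrition_explain, type = "steps", quantile = 0.995)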
Value
Aggplot object.
Examples
## Not run: 
library(ggplot2)
data("attrition", package = "modeldata")

## Single-outcome binary classification of `Attrition` in `attrition` dataset
attrition_fit <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 11)
attrition_explain <- tabnet_explain(attrition_fit, attrition)
# Plot the model aggregated mask interpretation heatmap
autoplot(attrition_explain)

## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset
data("ames", package = "modeldata")
x <- ames[, -which(names(ames) %in% c("Sale_Price", "Pool_Area"))]
y <- ames[, c("Sale_Price", "Pool_Area")]
ames_fit <- tabnet_fit(x, y, epochs = 1, verbose = TRUE)
ames_explain <- tabnet_explain(ames_fit, x)
autoplot(ames_explain, quantile = 0.99)

## End(Not run)
Plot tabnet_fit model loss along epochs
Description
Plot tabnet_fit model loss along epochs
Usage
autoplot.tabnet_fit(object, ...)

autoplot.tabnet_pretrain(object, ...)
Arguments
object | A tabnet_fit or tabnet_pretrain object to be plotted |
... | not used. |
Details
Plot the training loss along epochs, and the validation loss along epochs if any. A dot is added on epochs where a model snapshot is available, helping the choice of the from_epoch value for a later model training resume.
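A sketch of that resume workflow, assuming the attrition data from the Examples below; epoch numbers are illustrative:

fit <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 20,
                  valid_split = 0.2, checkpoint_epochs = 5)
autoplot(fit)  # snapshot dots suggest a good from_epoch value
fit <- tabnet_fit(Attrition ~ ., data = attrition, epochs = 10,
                  valid_split = 0.2, tabnet_model = fit, from_epoch = 10)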
Value
A ggplot object.
Examples
## Not run: 
library(ggplot2)
data("attrition", package = "modeldata")
attrition_fit <- tabnet_fit(Attrition ~ ., data = attrition, valid_split = 0.2, epochs = 11)
# Plot the model loss over epochs
autoplot(attrition_fit)

## End(Not run)
Non-tunable parameters for the tabnet model
Description
Non-tunable parameters for the tabnet model
Usage
cat_emb_dim(range = NULL, trans = NULL)

checkpoint_epochs(range = NULL, trans = NULL)

drop_last(range = NULL, trans = NULL)

encoder_activation(range = NULL, trans = NULL)

lr_scheduler(range = NULL, trans = NULL)

mlp_activation(range = NULL, trans = NULL)

mlp_hidden_multiplier(range = NULL, trans = NULL)

num_independent_decoder(range = NULL, trans = NULL)

num_shared_decoder(range = NULL, trans = NULL)

optimizer(range = NULL, trans = NULL)

penalty(range = NULL, trans = NULL)

verbose(range = NULL, trans = NULL)

virtual_batch_size(range = NULL, trans = NULL)
Arguments
range | unused |
trans | unused |
Check that Node object names are compliant
Description
Check that Node object names are compliant
Usage
check_compliant_node(node)Arguments
node | the Node object, or a dataframe ready to be parsed by data.tree::as.Node() |
Value
node if it is compliant, else an Error with the column names to fix
Examples
library(dplyr)
library(data.tree)
data(starwars)
starwars_tree <- starwars %>%
  mutate(pathString = paste("tree", species, homeworld, `name`, sep = "/"))
# pre as.Node() check
try(check_compliant_node(starwars_tree))
# post as.Node() check
check_compliant_node(as.Node(starwars_tree))
Determine the minimum set of model fits
Description
min_grid() determines exactly what models should be fit in order to evaluate the entire set of tuning parameter combinations. This is for internal use only and the API may change in the near future.
Usage
## S3 method for class 'tabnet'
min_grid(x, grid, ...)
Arguments
x | A model specification. |
grid | A tibble with tuning parameter combinations. |
... | Not currently used. |
Details
fit_max_value() can be used in other packages to implement a min_grid() method.
Value
A tibble with the minimum tuning parameters to fit and an additionallist column with the parameter combinations used for prediction.
Examples
library(dials)
library(tune)
library(parsnip)

tabnet_spec <- tabnet(decision_width = tune(), attention_width = tune()) %>%
  set_mode("regression") %>%
  set_engine("torch")

tabnet_grid <- tabnet_spec %>%
  extract_parameter_set_dials() %>%
  grid_regular(levels = 3)

min_grid(tabnet_spec, tabnet_grid)
Prune top layer(s) of a tabnet network
Description
Prune the head_size last layers of a tabnet network in order to use the pruned module as a sequential embedding module.
Usage
## S3 method for class 'tabnet_fit'
nn_prune_head(x, head_size)

## S3 method for class 'tabnet_pretrain'
nn_prune_head(x, head_size)
Arguments
x | nn_network to prune |
head_size | number of nn_layers to prune, should be less than 2 |
Value
a tabnet network with the top nn_layer removed
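As a sketch, the pruned network is a plain torch module whose printed structure confirms the head removal (using the ames_pretrain object from the Examples below):

embedding_module <- torch::nn_prune_head(ames_pretrain, head_size = 1)
embedding_module  # printing the module shows the remaining layers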
Examples
data("ames", package = "modeldata")x <- ames[,-which(names(ames) == "Sale_Price")]y <- ames$Sale_Price# pretrain a tabnet model on ames datasetames_pretrain <- tabnet_pretrain(x, y, epoch = 2, checkpoint_epochs = 1)# prune classification head to get an embedding modelpruned_pretrain <- torch::nn_prune_head(ames_pretrain, 1)Turn a Node object into predictor and outcome.
Description
Turn a Node object into predictor and outcome.
Usage
node_to_df(x, drop_last_level = TRUE)Arguments
x | Node object |
drop_last_level | (logical, default TRUE) currently unused |
Value
a named list of x and y, being respectively the predictor data-frame and the outcomes data-frame, as expected inputs for the hardhat::mold() function.
Examples
library(dplyr)
library(data.tree)
data(starwars)
starwars_tree <- starwars %>%
  mutate(pathString = paste("tree", species, homeworld, `name`, sep = "/")) %>%
  as.Node()
node_to_df(starwars_tree)$x %>% head()
node_to_df(starwars_tree)$y %>% head()
Parsnip compatible tabnet model
Description
Parsnip compatible tabnet model
Usage
tabnet(
  mode = "unknown",
  cat_emb_dim = NULL,
  decision_width = NULL,
  attention_width = NULL,
  num_steps = NULL,
  mask_type = NULL,
  num_independent = NULL,
  num_shared = NULL,
  num_independent_decoder = NULL,
  num_shared_decoder = NULL,
  penalty = NULL,
  feature_reusage = NULL,
  momentum = NULL,
  epochs = NULL,
  batch_size = NULL,
  virtual_batch_size = NULL,
  learn_rate = NULL,
  optimizer = NULL,
  loss = NULL,
  clip_value = NULL,
  drop_last = NULL,
  lr_scheduler = NULL,
  rate_decay = NULL,
  rate_step_size = NULL,
  checkpoint_epochs = NULL,
  verbose = NULL,
  importance_sample_size = NULL,
  early_stopping_monitor = NULL,
  early_stopping_tolerance = NULL,
  early_stopping_patience = NULL,
  skip_importance = NULL,
  tabnet_model = NULL,
  from_epoch = NULL
)
Arguments
mode | A single character string for the type of model. Possible valuesfor this model are "unknown", "regression", or "classification". |
cat_emb_dim | Size of the embedding of categorical features. If int, all categorical features will have the same embedding size; if list of int, every corresponding feature will have a specific embedding size. |
decision_width | (int) Width of the decision prediction layer. Bigger values give more capacity to the model, with the risk of overfitting. Values typically range from 8 to 64. |
attention_width | (int) Width of the attention embedding for each mask. According to the paper, n_d = n_a is usually a good choice. (default=8) |
num_steps | (int) Number of steps in the architecture (usually between 3 and 10) |
mask_type | (character) Final layer of feature selector in the attentive_transformer block, either "sparsemax" or "entmax" (default: "sparsemax") |
num_independent | Number of independent Gated Linear Units layers at each step of the encoder.Usual values range from 1 to 5. |
num_shared | Number of shared Gated Linear Units at each step of the encoder. Usual values range from 1 to 5. |
num_independent_decoder | For pretraining, number of independent Gated Linear Units layers at each step of the decoder. Usual values range from 1 to 5. |
num_shared_decoder | For pretraining, number of shared Gated Linear Units at each step of the decoder. Usual values range from 1 to 5. |
penalty | This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help (default 1e-3). |
feature_reusage | (num) This is the coefficient for feature reusage in the masks. A value close to 1 will make mask selection least correlated between layers. Values range from 1 to 2. |
momentum | Momentum for batch normalization, typically ranges from 0.01 to 0.4 (default=0.02) |
epochs | (int) Number of training epochs. |
batch_size | (int) Number of examples per batch, large batch sizes arerecommended. (default: 1024^2) |
virtual_batch_size | (int) Size of the mini batches used for"Ghost Batch Normalization" (default=256^2) |
learn_rate | initial learning rate for the optimizer. |
optimizer | the optimization method. "adam" is the only supported character value; a torch optimizer function can also be supplied. |
loss | (character or function) Loss function for training (default to mse for regression and cross entropy for classification) |
clip_value | If a num is given, this will clip the gradient at clip_value. Pass NULL to not clip. |
drop_last | (logical) Whether to drop last batch if not complete during training |
lr_scheduler | if NULL, no learning rate scheduler is used. If "step", the learning rate is decayed by rate_decay every rate_step_size epochs; a torch::lr_scheduler function can also be supplied. |
rate_decay | multiplies the initial learning rate by rate_decay every rate_step_size epochs. Unused if lr_scheduler is not "step". |
rate_step_size | the learning rate scheduler step size. Unused if lr_scheduler is not "step". |
checkpoint_epochs | checkpoint model weights and architecture every checkpoint_epochs epochs (default 10). Use 0 to disable checkpoints. |
verbose | (logical) Whether to print progress and loss values during training. |
importance_sample_size | sample of the dataset to compute importance metrics. If the dataset is larger than 1e5 obs, we will use a sample of size 1e5 and display a warning. |
early_stopping_monitor | Metric to monitor for early_stopping. One of "valid_loss", "train_loss" or "auto" (defaults to "auto"). |
early_stopping_tolerance | Minimum relative improvement to reset the patience counter. 0.01 for 1% tolerance (default 0) |
early_stopping_patience | Number of epochs without improving until stopping training. (default=5) |
skip_importance | if feature importance calculation should be skipped (default: FALSE) |
tabnet_model | A previously fitted TabNet model object to continue the fitting on. If NULL (the default) a brand new model is initialized. |
from_epoch | When a tabnet_model is provided, restore the network weights from a specific epoch. Default is the last available checkpoint for a model restored from file, or the last epoch for an in-memory model. |
Value
A TabNet parsnip instance. It can be used to fit tabnet models using parsnip machinery.
Threading
TabNet uses torch as its backend for computation and torch uses all available threads by default.
You can control the number of threads used by torch with:
torch::torch_set_num_threads(1)
torch::torch_set_num_interop_threads(1)
See Also
tabnet_fit
Examples
library(parsnip)
data("ames", package = "modeldata")
model <- tabnet() %>%
  set_mode("regression") %>%
  set_engine("torch")
model %>%
  fit(Sale_Price ~ ., data = ames)
Configuration for TabNet models
Description
Configuration for TabNet models
Usage
tabnet_config(
  batch_size = 1024^2,
  penalty = 0.001,
  clip_value = NULL,
  loss = "auto",
  epochs = 5,
  drop_last = FALSE,
  decision_width = NULL,
  attention_width = NULL,
  num_steps = 3,
  feature_reusage = 1.3,
  mask_type = "sparsemax",
  virtual_batch_size = 256^2,
  valid_split = 0,
  learn_rate = 0.02,
  optimizer = "adam",
  lr_scheduler = NULL,
  lr_decay = 0.1,
  step_size = 30,
  checkpoint_epochs = 10,
  cat_emb_dim = 1,
  num_independent = 2,
  num_shared = 2,
  num_independent_decoder = 1,
  num_shared_decoder = 1,
  momentum = 0.02,
  pretraining_ratio = 0.5,
  verbose = FALSE,
  device = "auto",
  importance_sample_size = NULL,
  early_stopping_monitor = "auto",
  early_stopping_tolerance = 0,
  early_stopping_patience = 0L,
  num_workers = 0L,
  skip_importance = FALSE
)
Arguments
batch_size | (int) Number of examples per batch, large batch sizes arerecommended. (default: 1024^2) |
penalty | This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help (default 1e-3). |
clip_value | If a num is given, this will clip the gradient at clip_value. Pass NULL to not clip. |
loss | (character or function) Loss function for training (default to mse for regression and cross entropy for classification) |
epochs | (int) Number of training epochs. |
drop_last | (logical) Whether to drop last batch if not complete during training |
decision_width | (int) Width of the decision prediction layer. Bigger values give more capacity to the model, with the risk of overfitting. Values typically range from 8 to 64. |
attention_width | (int) Width of the attention embedding for each mask. According to the paper, n_d = n_a is usually a good choice. (default=8) |
num_steps | (int) Number of steps in the architecture (usually between 3 and 10) |
feature_reusage | (num) This is the coefficient for feature reusage in the masks. A value close to 1 will make mask selection least correlated between layers. Values range from 1 to 2. |
mask_type | (character) Final layer of feature selector in the attentive_transformer block, either "sparsemax" or "entmax" (default: "sparsemax") |
virtual_batch_size | (int) Size of the mini batches used for"Ghost Batch Normalization" (default=256^2) |
valid_split | In [0, 1). The fraction of the dataset used for validation.(default = 0 means no split) |
learn_rate | initial learning rate for the optimizer. |
optimizer | the optimization method. "adam" is the only supported character value; a torch optimizer function can also be supplied. |
lr_scheduler | if NULL, no learning rate scheduler is used. If "step", the learning rate is decayed by lr_decay every step_size epochs; a torch::lr_scheduler function can also be supplied. |
lr_decay | multiplies the initial learning rate by lr_decay every step_size epochs. Unused if lr_scheduler is not "step". |
step_size | the learning rate scheduler step size. Unused if lr_scheduler is not "step". |
checkpoint_epochs | checkpoint model weights and architecture every checkpoint_epochs epochs (default 10). Use 0 to disable checkpoints. |
cat_emb_dim | Size of the embedding of categorical features. If int, all categorical features will have the same embedding size; if list of int, every corresponding feature will have a specific embedding size. |
num_independent | Number of independent Gated Linear Units layers at each step of the encoder.Usual values range from 1 to 5. |
num_shared | Number of shared Gated Linear Units at each step of the encoder. Usual values range from 1 to 5. |
num_independent_decoder | For pretraining, number of independent Gated Linear Units layers at each step of the decoder. Usual values range from 1 to 5. |
num_shared_decoder | For pretraining, number of shared Gated Linear Units at each step of the decoder. Usual values range from 1 to 5. |
momentum | Momentum for batch normalization, typically ranges from 0.01 to 0.4 (default=0.02) |
pretraining_ratio | Ratio of features to mask for reconstruction duringpretraining. Ranges from 0 to 1 (default=0.5) |
verbose | (logical) Whether to print progress and loss values during training. |
device | the device to use for training, "cpu" or "cuda". The default ("auto") uses "cuda" if it's available, otherwise uses "cpu". |
importance_sample_size | sample of the dataset to compute importance metrics. If the dataset is larger than 1e5 obs, we will use a sample of size 1e5 and display a warning. |
early_stopping_monitor | Metric to monitor for early_stopping. One of "valid_loss", "train_loss" or "auto" (defaults to "auto"). |
early_stopping_tolerance | Minimum relative improvement to reset the patience counter. 0.01 for 1% tolerance (default 0) |
early_stopping_patience | Number of epochs without improving until stopping training. (default=5) |
num_workers | (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0) |
skip_importance | if feature importance calculation should be skipped (default: FALSE) |
Value
A named list with all hyperparameters of the TabNet implementation.
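A minimal sketch combining a reusable config with a per-call override; hyperparameters passed directly to tabnet_fit() take precedence over the config values (the data comes from the Examples below, and the values are illustrative):

config <- tabnet_config(
  epochs = 20,
  valid_split = 0.2,
  early_stopping_monitor = "valid_loss",
  early_stopping_patience = 3L
)
# learn_rate given here overrides the value carried by config
fit <- tabnet_fit(Sale_Price ~ ., data = ames, config = config, learn_rate = 0.005)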
Examples
data("ames", package = "modeldata")# change the model config for an faster ignite optimizerconfig <- tabnet_config(optimizer = torch::optim_ignite_adamw)## Single-outcome regression using formula specificationfit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1, config = config)Interpretation metrics from a TabNet model
Description
Interpretation metrics from a TabNet model
Usage
tabnet_explain(object, new_data)

## Default S3 method:
tabnet_explain(object, new_data)

## S3 method for class 'tabnet_fit'
tabnet_explain(object, new_data)

## S3 method for class 'tabnet_pretrain'
tabnet_explain(object, new_data)

## S3 method for class 'model_fit'
tabnet_explain(object, new_data)
Arguments
object | a TabNet fit object |
new_data | a data.frame to obtain interpretation metrics. |
Value
Returns a list with
M_explain: the aggregated feature importance masks as detailed in TabNet's paper.
masks: a list containing the masks for each step.
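A short sketch of inspecting that list, assuming the ex object built in the Examples below:

ex$M_explain       # aggregated feature importance masks
length(ex$masks)   # one entry per attention step (here num_steps = 1)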
Examples
set.seed(2021)
n <- 256
x <- data.frame(
  x = rnorm(n),
  y = rnorm(n),
  z = rnorm(n)
)
y <- x$x

fit <- tabnet_fit(x, y, epochs = 10, num_steps = 1, batch_size = 512,
                  attention_width = 1, num_shared = 1, num_independent = 1)
ex <- tabnet_explain(fit, x)
Tabnet model
Description
Fits the TabNet: Attentive Interpretable Tabular Learning model
Usage
tabnet_fit(x, ...)

## Default S3 method:
tabnet_fit(x, ...)

## S3 method for class 'data.frame'
tabnet_fit(
  x,
  y,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL,
  weights = NULL
)

## S3 method for class 'formula'
tabnet_fit(
  formula,
  data,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL,
  weights = NULL
)

## S3 method for class 'recipe'
tabnet_fit(
  x,
  data,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL,
  weights = NULL
)

## S3 method for class 'Node'
tabnet_fit(
  x,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL
)
Arguments
x | Depending on the context: a data frame of predictors, a recipe specifying predictors and outcome(s), or a data.tree Node object for hierarchical outcomes.
The predictor data should be standardized (e.g. centered or scaled). The model treats categorical predictors internally, thus you don't need to make any treatment. The model treats missing values internally, thus you don't need to make any treatment. |
... | Model hyperparameters. Any hyperparameters set here will update those set by the config argument. See tabnet_config() for the available hyperparameters. |
y | When x is a data frame or matrix, y is the outcome: a vector, or a data frame with one or more columns (all numeric or all categorical) for multi-outcome models. |
tabnet_model | A previously fitted TabNet model object to continue the fitting on. If NULL (the default) a brand new model is initialized. |
config | A set of hyperparameters created using the tabnet_config() function. |
from_epoch | When a tabnet_model is provided, restore the network weights from a specific epoch. Default is the last available checkpoint for a model restored from file, or the last epoch for an in-memory model. |
weights | Unused. Placeholder for hardhat::importance_weights() variables. |
formula | A formula specifying the outcome terms on the left-hand side,and the predictor terms on the right-hand side. |
data | When a recipe or formula is used, data is a data frame containing the predictors and (when applicable) the outcome(s). |
Value
A TabNet model object. It can be used for serialization, predictions, or further fitting.
Fitting a pre-trained model
When providing a parent tabnet_model parameter, the model fitting resumes from that model's weights at the following epoch:
last fitted epoch for a model already in torch context
last model checkpoint epoch for a model loaded from file
the epoch related to a checkpoint matching or preceding the from_epoch value if provided (see the sketch below)
The model fitting metrics append on top of the parent metrics in the returned TabNet model.
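A sketch of such a resumed fit, assuming the ames data from the Examples section; epoch values are illustrative:

fit_1 <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 20, checkpoint_epochs = 5)
# resume from the checkpoint matching or preceding epoch 10
fit_2 <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 10,
                    tabnet_model = fit_1, from_epoch = 10)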
Multi-outcome
TabNet allows multi-outcome prediction, which is usually named multi-label classification, or multi-output regression when outcomes are numerical. Multi-outcome currently expects outcomes to be either all numeric or all categorical.
Threading
TabNet uses torch as its backend for computation and torch uses all available threads by default.
You can control the number of threads used by torch with:
torch::torch_set_num_threads(1)
torch::torch_set_num_interop_threads(1)
Examples
## Not run: data("ames", package = "modeldata")data("attrition", package = "modeldata")## Single-outcome regression using formula specificationfit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 4)## Single-outcome classification using data-frame specificationattrition_x <- attrition[ids,-which(names(attrition) == "Attrition")]fit <- tabnet_fit(attrition_x, attrition$Attrition, epochs = 4, verbose = TRUE)## Multi-outcome regression on `Sale_Price` and `Pool_Area` in `ames` dataset using formula,ames_fit <- tabnet_fit(Sale_Price + Pool_Area ~ ., data = ames, epochs = 4, valid_split = 0.2)## Multi-label classification on `Attrition` and `JobSatisfaction` in## `attrition` dataset using recipelibrary(recipes)rec <- recipe(Attrition + JobSatisfaction ~ ., data = attrition) %>% step_normalize(all_numeric(), -all_outcomes())attrition_fit <- tabnet_fit(rec, data = attrition, epochs = 4, valid_split = 0.2)## Hierarchical classification on `acme`data(acme, package = "data.tree")acme_fit <- tabnet_fit(acme, epochs = 4, verbose = TRUE)# Note: Model's number of epochs should be increased for publication-level results.## End(Not run)TabNet Model Architecture
Description
This is an nn_module representing the TabNet architecture from the TabNet: Attentive Interpretable Tabular Learning paper.
Usage
tabnet_nn(
  input_dim,
  output_dim,
  n_d = 8,
  n_a = 8,
  n_steps = 3,
  gamma = 1.3,
  cat_idxs = c(),
  cat_dims = c(),
  cat_emb_dim = 1,
  n_independent = 2,
  n_shared = 2,
  epsilon = 1e-15,
  virtual_batch_size = 128,
  momentum = 0.02,
  mask_type = "sparsemax"
)
Arguments
input_dim | Initial number of features. |
output_dim | Dimension of network output. Examples: 1 for regression, 2 for binary classification, etc. Vector of those dimensions in case of multi-output. |
n_d | Dimension of the prediction layer (usually between 4 and 64). |
n_a | Dimension of the attention layer (usually between 4 and 64). |
n_steps | Number of successive steps in the network (usually between 3 and 10). |
gamma | Scaling factor for attention updates (usually between 1 and 2). |
cat_idxs | Index of each categorical column in the dataset. |
cat_dims | Number of categories in each categorical column. |
cat_emb_dim | Size of the embedding of categorical features. If int, all categorical features will have the same embedding size; if list of int, every corresponding feature will have a specific size. |
n_independent | Number of independent GLU layer in each GLU block of the encoder. |
n_shared | Number of shared GLU layer in each GLU block of the encoder. |
epsilon | Avoid log(0), this should be kept very low. |
virtual_batch_size | Batch size for Ghost Batch Normalization. |
momentum | Numerical value between 0 and 1 which will be used for momentum in all batch norm. |
mask_type | Either "sparsemax" or "entmax" : this is the masking function to use. |
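A minimal sketch instantiating the raw module directly (dimensions are illustrative; most users should go through tabnet_fit() instead):

library(tabnet)
net <- tabnet_nn(input_dim = 10, output_dim = 1, n_d = 8, n_a = 8, n_steps = 3)
net  # printing shows the module structure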
Tabnet model
Description
Pretrain the TabNet: Attentive Interpretable Tabular Learning model on the predictor data exclusively (unsupervised training).
Usage
tabnet_pretrain(x, ...)

## Default S3 method:
tabnet_pretrain(x, ...)

## S3 method for class 'data.frame'
tabnet_pretrain(
  x,
  y,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL
)

## S3 method for class 'formula'
tabnet_pretrain(
  formula,
  data,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL
)

## S3 method for class 'recipe'
tabnet_pretrain(
  x,
  data,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL
)

## S3 method for class 'Node'
tabnet_pretrain(
  x,
  tabnet_model = NULL,
  config = tabnet_config(),
  ...,
  from_epoch = NULL
)
Arguments
x | Depending on the context: a data frame of predictors, a recipe specifying predictors and outcome(s), or a data.tree Node object for hierarchical outcomes.
The predictor data should be standardized (e.g. centered or scaled). The model treats categorical predictors internally, thus you don't need to make any treatment. The model treats missing values internally, thus you don't need to make any treatment. |
... | Model hyperparameters. Any hyperparameters set here will update those set by the config argument. See tabnet_config() for the available hyperparameters. |
y | (optional) When x is a data frame or matrix, y is the outcome; by design it is ignored during pre-training (see the outcome section below). |
tabnet_model | A pretrained TabNet model object to continue the pretraining on. If NULL (the default) a brand new model is initialized. |
config | A set of hyperparameters created using the tabnet_config() function. |
from_epoch | When a tabnet_model is provided, restore the network weights from a specific epoch. Default is the last available checkpoint for a model restored from file, or the last epoch for an in-memory model. |
formula | A formula specifying the outcome terms on the left-hand side,and the predictor terms on the right-hand side. |
data | When a recipe or formula is used, data is a data frame containing the predictors and (when applicable) the outcome(s). |
Value
A TabNet model object. It can be used for serialization, predictions, or further fitting.
outcome
Outcome values are accepted here only for consistent syntax with tabnet_fit, but by design the outcome, if present, is ignored during pre-training.
pre-training from a previous model
When providing a parent tabnet_model parameter, the model pretraining resumes from that model's weights at the following epoch:
last pretrained epoch for a model already in torch context
last model checkpoint epoch for a model loaded from file
the epoch related to a checkpoint matching or preceding the from_epoch value if provided
The model pretraining metrics append on top of the parent metrics in the returned TabNet model.
Threading
TabNet uses torch as its backend for computation and torch uses all available threads by default.
You can control the number of threads used by torch with:
torch::torch_set_num_threads(1)
torch::torch_set_num_interop_threads(1)
Examples
data("ames", package = "modeldata")pretrained <- tabnet_pretrain(Sale_Price ~ ., data = ames, epochs = 1)