| Title: | Higher Level 'API' for 'torch' |
| Version: | 0.5.1 |
| Description: | A high level interface for 'torch' providing utilities to reduce the amount of code needed for common tasks, abstract away torch details and make the same code work on both the 'CPU' and 'GPU'. It's flexible enough to support expressing a large range of models. It's heavily inspired by 'fastai' by Howard et al. (2020) <doi:10.48550/arXiv.2002.04688>, 'Keras' by Chollet et al. (2015) and 'PyTorch Lightning' by Falcon et al. (2019) <doi:10.5281/zenodo.3828935>. |
| License: | MIT + file LICENSE |
| URL: | https://mlverse.github.io/luz/, https://github.com/mlverse/luz |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.2 |
| Imports: | torch (≥ 0.11.9000), magrittr, zeallot, rlang (≥ 1.0.0), coro, glue, progress, R6, generics, purrr, fs, prettyunits, cli |
| Suggests: | knitr, rmarkdown, testthat (≥ 3.0.0), covr, Metrics, withr, vdiffr, ggplot2 (≥ 3.0.0), dplyr, torchvision, tfevents (≥ 0.0.2), tidyr |
| VignetteBuilder: | knitr |
| Config/testthat/edition: | 3 |
| Collate: | 'accelerator.R' 'as_dataloader.R' 'utils.R' 'callbacks.R' 'callbacks-amp.R' 'callbacks-interrupt.R' 'callbacks-mixup.R' 'callbacks-monitor-metrics.R' 'callbacks-profile.R' 'callbacks-resume.R' 'callbacks-tfevents.R' 'context.R' 'losses.R' 'lr-finder.R' 'metrics.R' 'metrics-auc.R' 'module-plot.R' 'module-print.R' 'module.R' 'reexports.R' 'serialization.R' |
| BugReports: | https://github.com/mlverse/luz/issues |
| NeedsCompilation: | no |
| Packaged: | 2025-10-30 11:37:48 UTC; dfalbel |
| Author: | Daniel Falbel [aut, cre, cph], Christophe Regouby [ctb], RStudio [cph] |
| Maintainer: | Daniel Falbel <daniel@rstudio.com> |
| Repository: | CRAN |
| Date/Publication: | 2025-10-30 12:30:02 UTC |
Pipe operator
Description
See magrittr::%>% for details.
Usage
lhs %>% rhs
Create an accelerator
Description
Create an accelerator
Usage
accelerator(
  device_placement = TRUE,
  cpu = FALSE,
  cuda_index = torch::cuda_current_device()
)
Arguments
device_placement | (logical) whether the accelerator object should handle device placement. Default: TRUE. |
cpu | (logical) whether the training procedure should run on the CPU. |
cuda_index | (integer) index of the CUDA device to use if multiple GPUs are available. Default: the result of torch::cuda_current_device(). |
Creates a dataloader from its input
Description
as_dataloader is used internally by luz to convert input data and valid_data as passed to fit.luz_module_generator() to a torch::dataloader.
Usage
as_dataloader(x, ...)

## S3 method for class 'dataset'
as_dataloader(x, ..., batch_size = 32)

## S3 method for class 'iterable_dataset'
as_dataloader(x, ..., batch_size = 32)

## S3 method for class 'list'
as_dataloader(x, ...)

## S3 method for class 'dataloader'
as_dataloader(x, ...)

## S3 method for class 'matrix'
as_dataloader(x, ...)

## S3 method for class 'numeric'
as_dataloader(x, ...)

## S3 method for class 'array'
as_dataloader(x, ...)

## S3 method for class 'torch_tensor'
as_dataloader(x, ...)

Arguments
x | the input object. |
... | Passed to torch::dataloader(). |
batch_size | (int, optional): how many samples per batch to load (default: 32). |
Details
as_dataloader methods should have sensible defaults for batch_size, parallel workers, etc.
It allows users to quickly experiment with fit.luz_module_generator() by not requiring them to create a torch::dataset and a torch::dataloader in simple experiments.
Methods (by class)
as_dataloader(dataset): Converts a torch::dataset() to a torch::dataloader().
as_dataloader(iterable_dataset): Converts a torch::iterable_dataset() into a torch::dataloader().
as_dataloader(list): Converts a list of tensors or arrays with the same size in the first dimension to a torch::dataloader().
as_dataloader(dataloader): Returns the same dataloader.
as_dataloader(matrix): Converts the matrix to a dataloader.
as_dataloader(numeric): Converts the numeric vector to a dataloader.
as_dataloader(array): Converts the array to a dataloader.
as_dataloader(torch_tensor): Converts the tensor to a dataloader.
Overriding
You can implement your own as_dataloader S3 method if you want your data structure to be automatically supported by luz's fit.luz_module_generator(). The method must satisfy the following conditions:
The method should return a torch::dataloader().
The only required argument is x. You should have good defaults for all other arguments.
It's better to avoid implementing as_dataloader methods for common S3 classes like data.frames. In this case, it's better to assign a different class to the inputs and implement as_dataloader for it.
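For illustration, here is a minimal sketch of such a method, assuming a hypothetical my_data class that wraps a predictor matrix and a response vector (both the class and the helper are invented for this example):

library(torch)
library(luz)

# Hypothetical container class, invented for this sketch.
new_my_data <- function(x, y) {
  structure(list(x = x, y = y), class = "my_data")
}

# The method returns a torch::dataloader() and only requires `x`;
# every other argument has a default, as required by luz.
as_dataloader.my_data <- function(x, ..., batch_size = 32) {
  ds <- torch::tensor_dataset(
    x = torch::torch_tensor(x$x),
    y = torch::torch_tensor(x$y)
  )
  torch::dataloader(ds, batch_size = batch_size, ...)
}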
Context object
Description
Context object storing information about the model training context. See also ctx.
Public fields
buffers: a list of buffers that callbacks can use to write temporary information into ctx.
Active bindings
records: stores information about values logged with self$log.
device: allows querying the current accelerator device.
callbacks: list of callbacks that will be called.
iter: current iteration.
batch: the current batch data; a list with input data and targets.
input: a shortcut for ctx$batch[[1]].
target: a shortcut for ctx$batch[[2]].
min_epochs: the minimum number of epochs that the model will run for.
max_epochs: the maximum number of epochs that the model will run for.
hparams: a list of hyperparameters that were used to initialize ctx$model.
opt_hparams: a list of hyperparameters used to initialize the ctx$optimizers.
train_data: a dataloader that is used for training the model.
valid_data: a dataloader used during model validation.
accelerator: an accelerator() used to move data, model, etc. to the correct device.
optimizers: a named list of optimizers that will be used during model training.
verbose: bool, whether the process is in verbose mode or not.
handlers: list of error handlers that can be used. See rlang::try_fetch() for more info.
epoch_handlers: list of error handlers that can be used inside the epochs loop. See rlang::try_fetch() for more info.
training: a bool indicating if the model is in training or validation mode.
model: the model being trained.
pred: last predicted values.
opt: current optimizer.
opt_name: current optimizer name.
data: current dataloader in use.
loss_fn: loss function used to train the model.
loss: last computed loss values, detached from the graph.
loss_grad: last computed loss value, not detached, so you can do additional transformations.
epoch: current epoch.
metrics: list of metrics that are tracked by the process.
step_opt: defines how step is called for the optimizer. It must be a function taking an optimizer as argument.
Methods
Public methods
Method new()
Initializes the context object with minimal necessary information.
Usage
context$new(verbose, accelerator, callbacks, training)
Arguments
verbose: whether the context should be in verbose mode or not.
accelerator: a luz accelerator() that configures device placement and others.
callbacks: a list of callbacks used by the model. See luz_callback().
training: a boolean that indicates if the context is in training mode or not.
Method log()
Allows logging arbitrary information in the ctx.
Usage
context$log(what, set, value, index = NULL, append = TRUE)
Arguments
what: (string) what you are logging.
set: (string) usually 'train' or 'valid', indicating the set you want to log to. But can be arbitrary info.
value: arbitrary value to log.
index: index that this value should be logged at. If NULL, the value is added to the end of the list, otherwise the index is used.
append: if TRUE and a value in the corresponding index already exists, then the new value is appended to the current value. If FALSE, the current value is overwritten by the new value.
Method log_metric()
Log a metric by its name and value. Metric values are indexed by epoch.
Usage
context$log_metric(name, value)
Arguments
name: name of the metric.
value: arbitrary value to log.
Method get_log()
Get a specific value from the log.
Usage
context$get_log(what, set, index = NULL)
Arguments
what: (string) what you are logging.
set: (string) usually 'train' or 'valid', indicating the set you want to log to. But can be arbitrary info.
index: index that this value should be logged at. If NULL, the value is added to the end of the list, otherwise the index is used.
Method get_metrics()
Get all metrics given an epoch and set.
Usage
context$get_metrics(set, epoch = NULL)
Arguments
set: (string) usually 'train' or 'valid', indicating the set you want to log to. But can be arbitrary info.
epoch: the epoch you want to extract metrics from.
Method get_metric()
Get the value of a metric given its name, epoch and set.
Usage
context$get_metric(name, set, epoch = NULL)
Arguments
name: name of the metric.
set: (string) usually 'train' or 'valid', indicating the set you want to log to. But can be arbitrary info.
epoch: the epoch you want to extract metrics from.
Method get_formatted_metrics()
Get formatted metrics values.
Usage
context$get_formatted_metrics(set, epoch = NULL)
Arguments
set: (string) usually 'train' or 'valid', indicating the set you want to log to. But can be arbitrary info.
epoch: the epoch you want to extract metrics from.
Method get_metrics_df()
Get a data.frame containing all metrics.
Usage
context$get_metrics_df()
Method set_verbose()
Allows setting the verbose attribute.
Usage
context$set_verbose(verbose = NULL)
Arguments
verbose: boolean. If TRUE verbose mode is used; if FALSE non-verbose mode. If NULL we use the result of interactive().
Method clean()
Removes unnecessary information from the context object.
Usage
context$clean()
Method call_callbacks()
Call the selected callbacks, where name is the callback type to call, e.g. 'on_epoch_begin'.
Usage
context$call_callbacks(name)
Arguments
name: name of the callback breakpoint to call.
Method state_dict()
Returns a list containing minimal information from the context. Used to create the returned values.
Usage
context$state_dict()
Method unsafe_set_records()
Are you sure you know what you are doing?
Usage
context$unsafe_set_records(records)
Arguments
records: new set of records to be set.
Method clone()
The objects of this class are cloneable with this method.
Usage
context$clone(deep = FALSE)
Arguments
deep: whether to make a deep clone.
Context object
Description
Context objects used in luz to share information between model methods, metrics and callbacks.
Details
The ctx object is used in luz to share information between the training loop and callbacks, model methods, and metrics. The table below describes the information available in the ctx by default. Other callbacks could potentially modify these attributes or add new ones.
Context attributes
| Attribute | Description |
verbose | The value (TRUE or FALSE) attributed to the verbose argument in fit. |
accelerator | Accelerator object used to query the correct device to place models, data, etc. It assumes the value passed to the accelerator parameter in fit. |
model | Initialized nn_module object that will be trained during the fit procedure. |
optimizers | A named list of optimizers used during training. |
data | The currently in-use dataloader. When training it's ctx$train_data, when doing validation it's ctx$valid_data. It can also be the prediction dataset when in predict. |
train_data | Dataloader passed to the data argument in fit. Modified to yield data in the selected device. |
valid_data | Dataloader passed to the valid_data argument in fit. Modified to yield data in the selected device. |
min_epochs | Minimum number of epochs the model will be trained for. |
max_epochs | Maximum number of epochs the model will be trained for. |
epoch | Current training epoch. |
iter | Current training iteration. It's reset every epoch and when going from training to validation. |
training | Whether the model is in training or validation mode. See also help("luz_callback_train_valid"). |
callbacks | List of callbacks that will be called during the training procedure. It's the union of the list passed to the callbacks parameter and the default callbacks. |
step | Closure that will be used to do one step of the model. It's used for both training and validation. Takes no argument, but can access the ctx object. |
call_callbacks | Call callbacks by name. For example, call_callbacks("on_train_begin") will call all callbacks that provide methods for this point. |
batch | Last batch obtained by the dataloader. A batch is a list() with 2 elements, one that is used as input and the other as target. |
input | First element of the last batch obtained by the current dataloader. |
target | Second element of the last batch obtained by the current dataloader. |
pred | Last predictions obtained by ctx$model$forward. Note: can be potentially modified by previously run callbacks. Also note that this might not be available if you used a custom training step. |
loss_fn | The active loss function that will be minimized during training. |
loss | Last computed loss from the model. Note: this might not be available if you modified the training or validation step. |
opt | Current optimizer, i.e. the optimizer that will be used to do the next step to update parameters. |
opt_nm | Current optimizer name. By default it's opt, but can change if your model uses more than one optimizer depending on the set of parameters being optimized. |
metrics | list() with current metric objects that are updated at every on_train_batch_end() or on_valid_batch_end(). See also help("luz_callback_metrics"). |
records | list() recording metric values for training and validation for each epoch. See also help("luz_callback_metrics"). Also records profiling metrics. See help("luz_callback_profile") for more information. |
handlers | A named list() of handlers that is passed to rlang::with_handlers() during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. |
epoch_handlers | A named list of handlers that is used with rlang::with_handlers(). Those handlers are used inside the epochs loop, thus you can handle epoch-specific conditions that won't necessarily end training. |
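As an illustration of how these attributes are typically consumed, below is a minimal sketch of a callback that reads a few ctx fields at the end of every training batch (the callback itself is invented for this example):

library(luz)

inspect_ctx <- luz_callback(
  name = "inspect_ctx",
  on_train_batch_end = function() {
    # `ctx` is available inside callback methods.
    cat(sprintf(
      "epoch %d, iter %d, input shape: %s\n",
      ctx$epoch, ctx$iter,
      paste(dim(ctx$input), collapse = "x")
    ))
  }
)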
See Also
Context object: context
Evaluates a fitted model on a dataset
Description
Evaluates a fitted model on a dataset
Usage
evaluate(
  object,
  data,
  ...,
  metrics = NULL,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)
Arguments
object | A fitted model to evaluate. |
data | (dataloader, dataset or list) A dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list. |
... | Currently unused. |
metrics | A list of luz metrics to be tracked during evaluation. If NULL (default) then the same metrics that were used during fitting are tracked. |
callbacks | (list, optional) A list of callbacks defined with luz_callback() that will be called during evaluation. |
accelerator | (accelerator, optional) An optional accelerator() object used to configure device placement. |
verbose | (logical, optional) An optional boolean value indicating if the procedure should emit output to the console. By default, it will produce output if interactive() is TRUE. |
dataloader_options | Options used when creating a dataloader. See torch::dataloader(). |
Details
Once a model has been trained you might want to evaluate its performance on a different dataset. For that reason, luz provides the evaluate() function that takes a fitted model and a dataset and computes the metrics attached to the model.
Evaluate returns a luz_module_evaluation object that you can query for metrics using the get_metrics function or simply print to see the results.
For example:
evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)
See Also
Other training: fit.luz_module_generator(), predict.luz_module_fitted(), setup()
Fit an nn_module
Description
Fit an nn_module
Usage
## S3 method for class 'luz_module_generator'
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)
Arguments
object | An nn_module that has been setup(). |
data | (dataloader, dataset or list) A dataloader created with torch::dataloader() used for training the model, a dataset created with torch::dataset(), or a list. |
epochs | (int) The maximum number of epochs for training the model. If a single value is provided, this is taken to be the max_epochs and min_epochs is set to 0. If a vector of two numbers is provided, the first is min_epochs and the second max_epochs. |
callbacks | (list, optional) A list of callbacks defined with luz_callback() that will be called during the training procedure. |
valid_data | (dataloader, dataset, list or scalar value; optional) A dataloader created with torch::dataloader(), a dataset, or a list used for validation. If a scalar value between 0 and 1 is provided, a random sample of that proportion of the training data is used for validation. |
accelerator | (accelerator, optional) An optional accelerator() object used to configure device placement. |
verbose | (logical, optional) An optional boolean value indicating if the fitting procedure should emit output to the console during training. By default, it will produce output if interactive() is TRUE. |
... | Currently unused. |
dataloader_options | Options used when creating a dataloader. See torch::dataloader(). |
Value
A fitted object that can be saved with luz_save() and can be printed with print() and plotted with plot().
See Also
predict.luz_module_fitted() for how to create predictions. setup() to find out how to create modules that can be trained with fit.
Other training: evaluate(), predict.luz_module_fitted(), setup()
Get metrics from the object
Description
Get metrics from the object
Usage
get_metrics(object, ...)

## S3 method for class 'luz_module_fitted'
get_metrics(object, ...)

Arguments
object | The object to query for metrics. |
... | Currently unused. |
Value
A data.frame containing the metric values.
Methods (by class)
get_metrics(luz_module_fitted): Extract metrics from a luz fitted model.
Learning Rate Finder
Description
Learning Rate Finder
Usage
lr_finder(
  object,
  data,
  steps = 100,
  start_lr = 1e-07,
  end_lr = 0.1,
  log_spaced_intervals = TRUE,
  ...,
  verbose = NULL
)
Arguments
object | An nn_module that has been setup(). |
data | (dataloader) A dataloader created with torch::dataloader() used for learning rate finding. |
steps | (integer) The number of steps to iterate over in the learning rate finder. Default: 100. |
start_lr | (float) The smallest learning rate. Default: 1e-7. |
end_lr | (float) The highest learning rate. Default: 1e-1. |
log_spaced_intervals | (logical) Whether to divide the range between start_lr and end_lr into log-spaced intervals (alternative: uniform intervals). Default: TRUE |
... | Other arguments passed to fit. |
verbose | Whether to show a progress bar during the process. |
Value
A data.frame with two columns: learning rate and loss.
Examples
if (torch::torch_is_installed()) {
  library(torch)
  ds <- torch::tensor_dataset(x = torch_randn(100, 10), y = torch_randn(100, 1))
  dl <- torch::dataloader(ds, batch_size = 32)
  model <- torch::nn_linear
  model <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    ) %>%
    set_hparams(in_features = 10, out_features = 1)
  records <- lr_finder(model, dl, verbose = FALSE)
  plot(records)
}
Create a new callback
Description
Create a new callback
Usage
luz_callback(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)
Arguments
name | name of the callback |
... | Public methods of the callback. The name of the methods is used to know how they should be called. See the details section. |
private | An optional list of private members, which can be functionsand non-functions. |
active | An optional list of active binding functions. |
parent_env | An environment to use as the parent of newly-created objects. |
inherit | A R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in parent_env. |
Details
Let's implement a callback that prints 'Iteration n' (where n is the iteration number) for every batch in the training set and 'Done' when an epoch is finished. For that task we use the luz_callback function:
print_callback <- luz_callback(
  name = "print_callback",
  initialize = function(message) {
    self$message <- message
  },
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat(self$message, "\n")
  }
)
luz_callback() takes named functions as ... arguments, where the name indicates the moment at which the callback should be called. For instance on_train_batch_end() is called for every batch at the end of the training procedure, and on_epoch_end() is called at the end of every epoch.
The returned value of luz_callback() is a function that initializes an instance of the callback. Callbacks can have initialization parameters, like the name of a file where you want to log the results. In that case, you can pass an initialize method when creating the callback definition, and save these parameters to the self object. In the above example, the callback has a message parameter that is printed at the end of each epoch.
Once a callback is defined it can be passed to the fit function via the callbacks parameter:
fitted <- net %>%
  setup(...) %>%
  fit(..., callbacks = list(
    print_callback(message = "Done!")
  ))
Callbacks can be called in many different positions of the training loop, including combinations of them. Here's an overview of possible callback breakpoints:
Start Fit
   - on_fit_begin
   Start Epoch Loop
      - on_epoch_begin
      Start Train
         - on_train_begin
         Start Batch Loop
            - on_train_batch_begin
            Start Default Training Step
               - on_train_batch_after_pred
               - on_train_batch_after_loss
               - on_train_batch_before_backward
               - on_train_batch_before_step
               - on_train_batch_after_step
            End Default Training Step:
            - on_train_batch_end
         End Batch Loop
         - on_train_end
      End Train
      Start Valid
         - on_valid_begin
         Start Batch Loop
            - on_valid_batch_begin
            Start Default Validation Step
               - on_valid_batch_after_pred
               - on_valid_batch_after_loss
            End Default Validation Step
            - on_valid_batch_end
         End Batch Loop
         - on_valid_end
      End Valid
      - on_epoch_end
   End Epoch Loop
   - on_fit_end
End Fit
Every step marked with on_* is a point in the training procedure that is available for callbacks to be called.
The other important part of callbacks is the ctx (context) object. See help("ctx") for details.
By default, callbacks are called in the same order as they were passed to fit (or predict or evaluate), but you can provide a weight attribute that controls the order in which they are called. For example, if one callback has weight = 10 and another has weight = 1, then the first one is called after the second one. Callbacks that don't specify a weight attribute are considered weight = 0. A few built-in callbacks in luz already provide a weight value. For example, luz_callback_early_stopping() has a weight of Inf, since in general we want to run it as the last thing in the loop.
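For instance, a minimal sketch of a callback carrying a weight attribute (the callback itself is invented for this example):

run_late <- luz_callback(
  name = "run_late",
  weight = 10, # called after callbacks with lower weights
  on_epoch_end = function() {
    cat("Runs after weight-0 callbacks at epoch end.\n")
  }
)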
Value
A luz_callback that can be passed to fit.luz_module_generator().
Prediction callbacks
You can also use callbacks when using predict(). In this case the supported callback methods are detailed below:
Start predict
   - on_predict_begin
   Start prediction loop
      - on_predict_batch_begin
      - on_predict_batch_end
   End prediction loop
   - on_predict_end
End predict
Evaluate callbacks
Callbacks can also be used with evaluate(); in this case, the callbacks that are used are equivalent to those of the validation loop when using fit():
Start Valid
   - on_valid_begin
   Start Batch Loop
      - on_valid_batch_begin
      Start Default Validation Step
         - on_valid_batch_after_pred
         - on_valid_batch_after_loss
      End Default Validation Step
      - on_valid_batch_end
   End Batch Loop
   - on_valid_end
End Valid
See Also
Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
print_callback <- luz_callback(
  name = "print_callback",
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat("Done!\n")
  }
)
Resume training callback
Description
This callback allows you to resume training a model.
Usage
luz_callback_auto_resume(path = "./state.pt")
Arguments
path | Path to save state files for the model. |
Details
When using it, model weights and optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script will restart the model training from the epoch right after the last epoch that was serialized.
Customizing serialization
By default the model, optimizer state and records are serialized. Callbacks can be used to customize serialization by implementing the state_dict() and load_state_dict() methods. If those methods are implemented, then state_dict() is called at the end of each epoch and load_state_dict() is called when the model is resumed.
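As a sketch of this interface, the callback below (invented for this example) keeps a simple counter that survives resumes:

counter_callback <- luz_callback(
  name = "counter_callback",
  initialize = function() {
    self$count <- 0
  },
  on_epoch_end = function() {
    self$count <- self$count + 1
  },
  # Called at the end of each epoch; the returned list is serialized.
  state_dict = function() {
    list(count = self$count)
  },
  # Called with the previously serialized state when training resumes.
  load_state_dict = function(state) {
    self$count <- state$count
  }
)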
Note
In general you will want to add this callback as the last in the callbacks list; this way, the serialized state is likely to contain all possible changes that other callbacks could have made at 'on_epoch_end'. The default weight attribute of this callback is Inf.
Read the checkpointing article in the pkgdown website for more information.
See Also
Other luz_callbacks: luz_callback(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # simulate a failure in the middle of epoch 5 happening only once.
  callback_stop <- luz_callback(
    "interrupt",
    failed = FALSE,
    on_epoch_end = function() {
      if (ctx$epoch == 5 && !self$failed) {
        self$failed <- TRUE
        stop("Error on epoch 5")
      }
    }
  )

  path <- tempfile()
  autoresume <- luz_callback_auto_resume(path = path)
  interrupt <- callback_stop()

  # try once and the model fails
  try({
    results <- model %>% fit(
      list(x, y),
      callbacks = list(autoresume, interrupt),
      verbose = FALSE
    )
  })

  # model resumes and completes
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )

  get_metrics(results)
}
CSV logger callback
Description
Logs metrics obtained during training to a file on disk. The file will have 1 line for each epoch/validation.
Usage
luz_callback_csv_logger(path)
Arguments
path | path to a file on disk. |
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Early stopping callback
Description
Stops training when a monitored metric stops improving
Usage
luz_callback_early_stopping(
  monitor = "valid_loss",
  min_delta = 0,
  patience = 0,
  mode = "min",
  baseline = NULL
)
Arguments
monitor | A string in the format <set>_<metric>, where <set> is 'train' or 'valid' and <metric> is the abbreviation of a tracked metric, e.g. 'valid_loss'. |
min_delta | Minimum improvement to reset the patience counter. |
patience | Number of epochs without improvement until training is stopped. |
mode | Specifies the direction that is considered an improvement. By default'min' is used. Can also be 'max' (higher is better) and 'zero'(closer to zero is better). |
baseline | An initial value that will be used as the best seen value at the beginning. The model will stop training if no value better than the baseline is found in the first patience epochs. |
Value
A luz_callback that does early stopping.
Note
This callback adds an on_early_stopping breakpoint that can be used to call callbacks as soon as the model stops training.
If verbose = TRUE in fit.luz_module_generator(), a message is printed when early stopping.
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
cb <- luz_callback_early_stopping()
Gradient clipping callback
Description
By adding the GradientClip callback, the gradient norm_type (default: 2) norm is clipped to at most max_norm (default: 1) using torch::nn_utils_clip_grad_norm_(), which can avoid loss divergence.
Usage
luz_callback_gradient_clip(max_norm = 1, norm_type = 2)
Arguments
max_norm | (float or int): max norm of the gradients |
norm_type | (float or int): type of the used p-norm. Can be 'inf' for infinity norm. |
References
See the FastAI documentation for the GradientClip callback.
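Examples
A minimal usage sketch: create the callback, then pass it to fit() alongside any other callbacks.

clip <- luz_callback_gradient_clip(max_norm = 1, norm_type = 2)
# fitted <- model %>% fit(data, callbacks = list(clip))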
Interrupt callback
Description
Adds a handler that allows interrupting the training loop using ctrl + C. Also registers an on_interrupt breakpoint so users can register callbacks to be run on training loop interruption.
Usage
luz_callback_interrupt()
Value
A luz_callback
Note
In general you don't need to use this callback yourself because it's always included by default in fit.luz_module_generator().
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
interrupt_callback <- luz_callback_interrupt()
Keep the best model
Description
Each epoch, if there's improvement in the monitored metric, we serialize the model weights to a temp file. When training is done, we reload weights from the best model.
Usage
luz_callback_keep_best_model(
  monitor = "valid_loss",
  mode = "min",
  min_delta = 0
)
Arguments
monitor | A string in the format <set>_<metric>, where <set> is 'train' or 'valid' and <metric> is the abbreviation of a tracked metric, e.g. 'valid_loss'. |
mode | Specifies the direction that is considered an improvement. By default'min' is used. Can also be 'max' (higher is better) and 'zero'(closer to zero is better). |
min_delta | Minimum improvement to reset the patience counter. |
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
cb <- luz_callback_keep_best_model()
Learning rate scheduler callback
Description
Initializes and runs torch::lr_scheduler()s.
Usage
luz_callback_lr_scheduler(
  lr_scheduler,
  ...,
  call_on = "on_epoch_end",
  opt_name = NULL
)
Arguments
lr_scheduler | A torch::lr_scheduler() generator, e.g. torch::lr_step, that will be initialized with the optimizer. |
... | Additional arguments passed to the lr_scheduler initialization, together with the optimizer. |
call_on | The callback breakpoint at which scheduler$step() is called. Default: 'on_epoch_end'. |
opt_name | name of the optimizer that will be affected by this callback. Should match the name given in the module's set_optimizers method. |
Value
A luz_callback() generator.
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
if (torch::torch_is_installed()) {
  cb <- luz_callback_lr_scheduler(torch::lr_step, step_size = 30)
}
Metrics callback
Description
Tracks metrics passed to setup() during training and validation.
Usage
luz_callback_metrics()
Details
This callback takes care of 2 ctx attributes:
ctx$metrics: stores the current metric objects that are initialized once per epoch, and are further update()d and compute()d every batch. You will rarely need to work with these metrics.
ctx$records$metrics: stores metrics per training/validation and epoch. The structure is very similar to ctx$losses.
Value
A luz_callback
Note
In general you won't need to explicitly use the metrics callback as it's used by default in fit.luz_module_generator().
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Automatic Mixed Precision callback
Description
This callback will enable torch::local_autocast() during the model forward pass and loss computation. It will then disable autocast and scale the loss before backward() and opt$step(). See the torch documentation on automatic mixed precision for more information.
Usage
luz_callback_mixed_precision(...)
Arguments
... | Passed to torch::cuda_amp_grad_scaler(). |
Value
A luz_callback
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
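Examples
A minimal usage sketch (mixed precision only has an effect on CUDA devices):

mp <- luz_callback_mixed_precision()
# fitted <- model %>% fit(data, callbacks = list(mp))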
Mixup callback
Description
Implementation of 'mixup: Beyond Empirical Risk Minimization'. As of today, tested only for categorical data, where targets are expected to be integers, not one-hot encoded vectors. This callback is supposed to be used together with nn_mixup_loss().
Usage
luz_callback_mixup(alpha = 0.4, ..., run_valid = FALSE, auto_loss = FALSE)
Arguments
alpha | parameter for the beta distribution used to sample mixing coefficients |
... | currently unused. Just to force named arguments. |
run_valid | Should it run during validation? |
auto_loss | Should it automatically modify the loss function? This will wrap the loss function to create the mixup loss. If FALSE, use nn_mixup_loss() as the loss function instead. |
Details
Overall, we follow the fastai implementation described here. Namely,
We work with a single dataloader only, randomly mixing two observations from the same batch.
We linearly combine losses computed for both targets:
loss(output, new_target) = weight * loss(output, target1) + (1 - weight) * loss(output, target2)
We draw different mixing coefficients for every pair.
We replace weight with weight = max(weight, 1 - weight) to avoid duplicates.
Value
A luz_callback
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
if (torch::torch_is_installed()) {
  mixup_callback <- luz_callback_mixup()
}
Checkpoints model weights
Description
This saves checkpoints of the model according to the specified metric and behavior.
Usage
luz_callback_model_checkpoint(
  path,
  monitor = "valid_loss",
  save_best_only = FALSE,
  mode = "min",
  min_delta = 0
)
Arguments
path | Path to save the model on disk. The path is interpolated with glue syntax; for example you can use {epoch} and {monitor} placeholders (see the examples). |
monitor | A string in the format <set>_<metric>, where <set> is 'train' or 'valid' and <metric> is the abbreviation of a tracked metric, e.g. 'valid_loss'. |
save_best_only | if TRUE, models are only saved if they have an improvement over a previously saved model. |
mode | Specifies the direction that is considered an improvement. By default'min' is used. Can also be 'max' (higher is better) and 'zero'(closer to zero is better). |
min_delta | Minimum difference to consider as improvement. Only used when save_best_only = TRUE. |
Note
mode and min_delta are only used when save_best_only = TRUE. save_best_only will overwrite the saved models if the path parameter doesn't differentiate by epochs.
Read the checkpointing article in the pkgdown website for more information.
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
luz_callback_model_checkpoint(path = "path/to/dir")
luz_callback_model_checkpoint(path = "path/to/dir/epoch-{epoch:02d}/model.pt")
luz_callback_model_checkpoint(path = "path/to/dir/epoch-{epoch:02d}/model-{monitor:.2f}.pt")
Profile callback
Description
Computes the times for high-level operations in the training loops.
Usage
luz_callback_profile()
Details
Records are saved in ctx$records$profile. Times are stored as seconds. Data is stored in the following structure:
fit: time for the entire fit procedure.
epoch: times per epoch.
Value
A luz_callback
Note
In general you don't need to use this callback yourself because it's always included by default in fit.luz_module_generator().
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Examples
profile_callback <- luz_callback_profile()
Progress callback
Description
Responsible for printing progress during training.
Usage
luz_callback_progress()
Value
A luz_callback
Note
In general you don't need to use this callback yourself because it's always included by default in fit.luz_module_generator().
Printing can be disabled by passing verbose = FALSE to fit.luz_module_generator().
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()
Allow resuming model training from a specific checkpoint
Description
Allow resuming model training from a specific checkpoint
Usage
luz_callback_resume_from_checkpoint(
  path,
  ...,
  restore_model_state = TRUE,
  restore_records = FALSE,
  restore_optimizer_state = FALSE,
  restore_callbacks_state = FALSE
)
Arguments
path | Path to the checkpoint that you want to resume. |
... | currently unused. |
restore_model_state | Whether to restore the model state from the checkpoint. |
restore_records | Whether to restore records from the checkpoint. |
restore_optimizer_state | Whether to restore the optimizer state from the checkpoint. |
restore_callbacks_state | Whether to restore the callbacks state from the checkpoint. |
Note
Read the checkpointing article in the pkgdown website for moreinformation.
See Also
luz_callback_model_checkpoint()
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_train_valid()
tfevents callback
Description
Logs metrics and other model information in the tfevents file format. Assuming tensorboard is installed, results can be visualized with:
Usage
luz_callback_tfevents(logdir = "logs", histograms = FALSE, ...)
Arguments
logdir | A directory where logs will be written to. |
histograms | A boolean specifying if histograms of model weights should be logged. It can also be a character vector specifying the name of the parameters that should be logged (names are the same as names(model$parameters)). |
... | Currently not used. For future expansion. |
Details
tensorboard --logdir=logs
Examples
if (torch::torch_is_installed()) {
  library(torch)

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 1e-4)

  tmp <- tempfile()

  model %>% fit(list(x, y), valid_data = 0.2, callbacks = list(
    luz_callback_tfevents(tmp, histograms = TRUE)
  ))
}
Train-eval callback
Description
Switches important flags for training and evaluation modes.
Usage
luz_callback_train_valid()
Details
It takes care of the three ctx attributes:
ctx$model: responsible for calling ctx$model$train() and ctx$model$eval(), when appropriate.
ctx$training: sets this flag to TRUE when training and FALSE when in validation mode.
ctx$loss: resets the loss attribute to list() when finished training or validating.
Value
A luz_callback
Note
In general you won't need to explicitly use the train_valid callback as it's used by default in fit.luz_module_generator().
See Also
Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint()
Load trained model
Description
Loads a fitted model. See documentation in luz_save().
Usage
luz_load(path)
Arguments
path | path in file system to the object. |
See Also
Other luz_save: luz_save()
Loads a checkpoint
Description
Works with checkpoints created typically with luz_callback_model_checkpoint().
Usage
luz_load_checkpoint(obj, path, ...)
Arguments
obj | Object to which we want to load the checkpoint. |
path | Path of the checkpoint on disk. |
... | Currently unused; exists to allow future extensions. |
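Examples
A usage sketch, assuming fitted is a luz fitted module and that a checkpoint was previously written by luz_callback_model_checkpoint() (the path shown is illustrative):

# luz_load_checkpoint(fitted, "path/to/dir/epoch-05/model.pt")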
Loads model weights into a fitted object.
Description
This can be useful when you have saved model checkpoints during training and want to reload the best checkpoint in the end.
Usage
luz_load_model_weights(obj, path, ...)

luz_save_model_weights(obj, path)

Arguments
obj | luz object to which you want to copy the new weights. |
path | path to the saved model on disk. |
... | other arguments passed to |
Value
Returns NULL invisibly.
Warning
luz_save_model_weights operates in place, i.e. it modifies the model object to contain the new weights.
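Examples
A sketch of the round trip, assuming fitted is a luz fitted model:

# tmp <- tempfile(fileext = ".pt")
# luz_save_model_weights(fitted, tmp)  # save only the weights
# luz_load_model_weights(fitted, tmp)  # copy them back into the object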
Creates a new luz metric
Description
Creates a new luz metric
Usage
luz_metric(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)
Arguments
name | string naming the new metric. |
... | named list of public methods. You should implement at least initialize, update and compute. |
private | An optional list of private members, which can be functionsand non-functions. |
active | An optional list of active binding functions. |
parent_env | An environment to use as the parent of newly-created objects. |
inherit | A R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in parent_env. |
Details
In order to implement a new luz_metric we need to implement 3 methods:
initialize: defines the metric initial state. This function is called for each epoch for both training and validation loops.
update: updates the metric internal state. This function is called at every training and validation step with the predictions obtained by the model and the target values obtained from the dataloader.
compute: uses the internal state to compute metric values. This function is called whenever we need to obtain the current metric value. E.g., it's called every training step for metrics displayed in the progress bar, but only called once per epoch to record its value when the progress bar is not displayed.
Optionally, you can implement an abbrev field that gives the metric an abbreviation that will be used when displaying metric information in the console or tracking record. If no abbrev is passed, the class name will be used.
Let's take a look at the implementation of luz_metric_accuracy so you can see how to implement a new one:
luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct / self$total
  }
)
Note: it's good practice that the compute method returns regular R values instead of torch tensors, as other parts of luz will expect that.
Value
Returns a new luz metric.
See Also
Other luz_metrics: luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()
Examples
luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct / self$total
  }
)
Accuracy
Description
Computes accuracy for multi-class classification problems.
Usage
luz_metric_accuracy()
Details
This metric expects to take logits or probabilities at every update. It will then take the columnwise argmax and compare to the target.
Value
Returns a new luz metric.
See Also
Other luz_metrics: luz_metric(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_accuracy()
  metric <- metric$new()
  metric$update(torch_randn(100, 10), torch::torch_randint(1, 10, size = 100))
  metric$compute()
}
Binary accuracy
Description
Computes the accuracy for binary classification problems where the model returns probabilities. Commonly used when the loss is torch::nn_bce_loss().
Usage
luz_metric_binary_accuracy(threshold = 0.5)
Arguments
threshold | value used to classify observations between 0 and 1. |
Value
Returns a new luz metric.
See Also
Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_binary_accuracy(threshold = 0.5)
  metric <- metric$new()
  metric$update(torch_rand(100), torch::torch_randint(0, 1, size = 100))
  metric$compute()
}
Binary accuracy with logits
Description
Computes accuracy for binary classification problems where the model returns logits. Commonly used together with torch::nn_bce_with_logits_loss().
Usage
luz_metric_binary_accuracy_with_logits(threshold = 0.5)
Arguments
threshold | value used to classify observations between 0 and 1. |
Details
Probabilities are generated using torch::nnf_sigmoid() and threshold is used to classify between 0 or 1.
Value
Returns a new luz metric.
See Also
Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_binary_accuracy_with_logits(threshold = 0.5)
  metric <- metric$new()
  metric$update(torch_randn(100), torch::torch_randint(0, 1, size = 100))
  metric$compute()
}
Computes the area under the ROC
Description
To avoid storing all predictions and targets for an epoch, we compute confusion matrices across a range of pre-established thresholds.
Usage
luz_metric_binary_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE
)
Arguments
num_thresholds | Number of thresholds used to compute confusion matrices. In that case, thresholds are created by getting num_thresholds values linearly spaced in the unit interval. |
thresholds | (optional) If thresholds are passed, then those are used to compute the confusion matrices and num_thresholds is ignored. |
from_logits | Boolean indicating if predictions are logits; in that case we use sigmoid to put them in the unit interval. |
See Also
Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  actual <- c(1, 1, 1, 0, 0, 0)
  predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)

  y_true <- torch_tensor(actual)
  y_pred <- torch_tensor(predicted)

  m <- luz_metric_binary_auroc(thresholds = predicted)
  m <- m$new()

  m$update(y_pred[1:2], y_true[1:2])
  m$update(y_pred[3:4], y_true[3:4])
  m$update(y_pred[5:6], y_true[5:6])

  m$compute()
}
Mean absolute error
Description
Computes the mean absolute error.
Usage
luz_metric_mae()
Value
Returns a new luz metric.
See Also
Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  metric <- luz_metric_mae()
  metric <- metric$new()
  metric$update(torch_randn(100), torch_randn(100))
  metric$compute()
}
Mean squared error
Description
Computes the mean squared error.
Usage
luz_metric_mse()
Value
A luz_metric object.
See Also
Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_multiclass_auroc(), luz_metric_rmse()
Computes the multi-class AUROC
Description
The same definition as Keras is used by default. This is equivalent to the 'micro' method in scikit-learn too. See the scikit-learn docs.
Usage
luz_metric_multiclass_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE,
  average = c("micro", "macro", "weighted", "none")
)
Arguments
num_thresholds | Number of thresholds used to compute confusion matrices. In that case, thresholds are created by getting num_thresholds values linearly spaced in the unit interval. |
thresholds | (optional) If thresholds are passed, then those are used to compute the confusion matrices and num_thresholds is ignored. |
from_logits | If TRUE, a softmax is applied to the predictions before computing the metric. |
average | The averaging method: one of 'micro', 'macro', 'weighted' or 'none'. |
Details
Note that class imbalance can affect this metric, unlike the AUC for binary classification.
Currently the AUC is approximated using the 'interpolation' method described in Keras.
See Also
Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
  library(torch)
  actual <- c(1, 1, 1, 0, 0, 0) + 1L
  predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
  predicted <- cbind(1 - predicted, predicted)

  y_true <- torch_tensor(as.integer(actual))
  y_pred <- torch_tensor(predicted)

  m <- luz_metric_multiclass_auroc(
    thresholds = as.numeric(predicted),
    average = "micro"
  )
  m <- m$new()

  m$update(y_pred[1:2, ], y_true[1:2])
  m$update(y_pred[3:4, ], y_true[3:4])
  m$update(y_pred[5:6, ], y_true[5:6])

  m$compute()
}
Root mean squared error
Description
Computes the root mean squared error.
Usage
luz_metric_rmse()
Value
Returns a new luz metric.
See Also
Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc()
Creates a metric set
Description
A metric set can be used to specify metrics that are only evaluated during training, validation or both.
Usage
luz_metric_set(metrics = NULL, train_metrics = NULL, valid_metrics = NULL)
Arguments
metrics | A list of luz_metrics that are meant to be used in both training and validation. |
train_metrics | A list of luz_metrics that are only used during training. |
valid_metrics | A list of luz_metrics that are only used for validation. |
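Examples
A sketch of a metric set passed to setup() (the metric choices are illustrative):

metrics <- luz_metric_set(
  metrics = list(luz_metric_mse()),       # tracked in training and validation
  train_metrics = list(luz_metric_mae()), # training only
  valid_metrics = list(luz_metric_rmse()) # validation only
)
# model %>% setup(..., metrics = metrics)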
Saves luz objects to disk
Description
Allows saving luz fitted models to the disk. Objects can be loaded back with luz_load().
Usage
luz_save(obj, path, ...)
Arguments
obj | an object of class 'luz_module_fitted' as returned by fit.luz_module_generator(). |
path | path in file system to the object. |
... | currently unused. |
Warning
The ctx is naively serialized, i.e. we only use saveRDS() to serialize it. Don't expect luz_save to work correctly if you have unserializable objects in the ctx, like torch_tensors and external pointers in general.
Note
Objects are saved as plain .rds files but obj$model is serialized with torch_save before saving it.
See Also
Other luz_save: luz_load()
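Examples
A usage sketch, assuming fitted is the result of fit():

# tmp <- tempfile(fileext = ".rds")
# luz_save(fitted, tmp)
# fitted2 <- luz_load(tmp)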
Loss to be used with luz_callback_mixup().
Description
In the training phase, computes individual losses with regard to two targets, weights them item-wise, and averages the linear combinations to yield the mean batch loss. For validation and testing, defers to the passed-in loss.
Usage
nn_mixup_loss(loss)
Arguments
loss | the underlying loss |
Details
It should be used together with luz_callback_mixup().
See Also
luz_callback_mixup()
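Examples
A usage sketch wrapping a base loss for mixup training; the choice of cross-entropy here is illustrative:

mixup_loss <- nn_mixup_loss(torch::nn_cross_entropy_loss())
# model %>%
#   setup(loss = mixup_loss, optimizer = torch::optim_adam) %>%
#   fit(data, callbacks = list(luz_callback_mixup()))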
Mixup logic
Description
Logic underlying luz_callback_mixup().
Usage
nnf_mixup(x, y, weight)
Arguments
x | an input batch |
y | a target batch |
weight | weighting coefficient to be used by mixup. |
Details
Based on the passed-in input and target batches, as well as applicable mixing weights, we return new tensors intended to replace the current batch. The new input batch is a weighted linear combination of input batch items, while the new target batch bundles the original targets, as well as the mixing weights, in a nested list.
Value
A list of:
x, the new, mixed-up input batch
y, a list of:
   ys, a list of:
      y1, the original target y1
      y2, the mixed-in target y2
   weight, the mixing weights
See Also
nn_mixup_loss(), luz_callback_mixup()
Examples
if (torch::torch_is_installed()) {
  batch_x <- torch::torch_randn(c(10, 768))
  batch_y <- torch::torch_randn(10)
  weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
  nnf_mixup(batch_x, batch_y, weight)
}
Create predictions for a fitted model
Description
Create predictions for a fitted model
Usage
## S3 method for class 'luz_module_fitted'
predict(
  object,
  newdata,
  ...,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)
Arguments
object | (fitted model) the fitted model object returned from fit.luz_module_generator(). |
newdata | (dataloader, dataset, list or array) returning a list with at least 1 element. The other elements aren't used. |
... | Currently unused. |
callbacks | (list, optional) A list of callbacks defined with luz_callback() that will be called during the prediction procedure. |
accelerator | (accelerator, optional) An optional accelerator() object used to configure device placement. |
verbose | (logical, optional) An optional boolean value indicating if the procedure should emit output to the console. By default, it will produce output if interactive() is TRUE. |
dataloader_options | Options used when creating a dataloader. See torch::dataloader(). |
See Also
Other training: evaluate(), fit.luz_module_generator(), setup()
Objects exported from other packages
Description
These objects are imported from other packages. Follow the links below to see their documentation.
- generics
Set hyper-parameter of a module
Description
This function is used to define hyper-parameters before calling fit for luz_modules.
Usage
set_hparams(module, ...)
Arguments
module | An nn_module that has been setup(). |
... | The parameters set here will be used to initialize the module, i.e. they are passed to the module's initialize method. |
Value
The same luz module
See Also
Other set_hparam: set_opt_hparams()
Set optimizer hyper-parameters
Description
This function is used to define hyper-parameters for the optimizer initialization method.
Usage
set_opt_hparams(module, ...)
Arguments
module | An nn_module that has been setup(). |
... | The parameters passed here will be used to initialize the optimizers. For example, if your optimizer is torch::optim_adam you can pass the lr argument, e.g. set_opt_hparams(lr = 0.01). |
Value
The same luz module
See Also
Other set_hparam: set_hparams()
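Examples
A combined sketch showing both helpers in a typical pipeline (mirroring the fit examples above):

if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  model <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>% # passed to nn_linear
    set_opt_hparams(lr = 0.01)                          # passed to optim_adam
}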
Sets up an nn_module to use with luz
Description
The setup function is used to set important attributes and methods for nn_modules to be used with luz.
Usage
setup(module, loss = NULL, optimizer = NULL, metrics = NULL, backward = NULL)
Arguments
module | (nn_module) The nn_module generator that you want to train. |
loss | (function, optional) A loss function with the signature function(input, target) used for training. It's optional if the module implements its own loss method. |
optimizer | (torch_optimizer, optional) A torch optimizer generator, e.g. torch::optim_adam, used to train the model. |
metrics | (list, optional) A list of luz metrics to be tracked during training. A luz_metric_set() can also be passed. |
backward | (function, optional) A custom function used to perform the backward pass of the loss. By default, loss$backward() is called. |
Details
It makes sure the module has all the necessary ingredients in order to be fitted.
Value
A luz module that can be trained with fit().
Note
It also adds a device active field that can be used to query the current module device within methods, e.g. with self$device. This is useful when ctx() is not available, e.g. when calling methods from outside the luz wrappers. Users can override the default by implementing a device active method in the input module.
See Also
Other training: evaluate(), fit.luz_module_generator(), predict.luz_module_fitted()
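Examples
A minimal end-to-end sketch combining setup(), set_hparams() and fit() on random data, mirroring the examples above:

if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)

  fitted <- nn_linear %>%
    setup(
      loss = nnf_mse_loss,
      optimizer = optim_adam,
      metrics = list(luz_metric_mse())
    ) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    fit(list(x, y), epochs = 1, verbose = FALSE)

  preds <- predict(fitted, x)
}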