Title: Higher Level 'API' for 'torch'
Version: 0.5.1
Description: A high level interface for 'torch' providing utilities to reduce the amount of code needed for common tasks, abstract away torch details and make the same code work on both the 'CPU' and 'GPU'. It's flexible enough to support expressing a large range of models. It's heavily inspired by 'fastai' by Howard et al. (2020) <doi:10.48550/arXiv.2002.04688>, 'Keras' by Chollet et al. (2015) and 'PyTorch Lightning' by Falcon et al. (2019) <doi:10.5281/zenodo.3828935>.
License: MIT + file LICENSE
URL: https://mlverse.github.io/luz/, https://github.com/mlverse/luz
Encoding: UTF-8
RoxygenNote: 7.3.2
Imports: torch (≥ 0.11.9000), magrittr, zeallot, rlang (≥ 1.0.0), coro, glue, progress, R6, generics, purrr, fs, prettyunits, cli
Suggests: knitr, rmarkdown, testthat (≥ 3.0.0), covr, Metrics, withr, vdiffr, ggplot2 (≥ 3.0.0), dplyr, torchvision, tfevents (≥ 0.0.2), tidyr
VignetteBuilder: knitr
Config/testthat/edition: 3
Collate: 'accelerator.R' 'as_dataloader.R' 'utils.R' 'callbacks.R' 'callbacks-amp.R' 'callbacks-interrupt.R' 'callbacks-mixup.R' 'callbacks-monitor-metrics.R' 'callbacks-profile.R' 'callbacks-resume.R' 'callbacks-tfevents.R' 'context.R' 'losses.R' 'lr-finder.R' 'metrics.R' 'metrics-auc.R' 'module-plot.R' 'module-print.R' 'module.R' 'reexports.R' 'serialization.R'
BugReports: https://github.com/mlverse/luz/issues
NeedsCompilation: no
Packaged: 2025-10-30 11:37:48 UTC; dfalbel
Author: Daniel Falbel [aut, cre, cph], Christophe Regouby [ctb], RStudio [cph]
Maintainer: Daniel Falbel <daniel@rstudio.com>
Repository: CRAN
Date/Publication: 2025-10-30 12:30:02 UTC

Pipe operator

Description

See magrittr::%>% for details.

Usage

lhs %>% rhs

Create an accelerator

Description

Create an accelerator

Usage

accelerator(
  device_placement = TRUE,
  cpu = FALSE,
  cuda_index = torch::cuda_current_device()
)

Arguments

device_placement

(logical) whether the accelerator object should handle device placement. Default: TRUE

cpu

(logical) whether the training procedure should run on the CPU. Default: FALSE

cuda_index

(integer) index of the CUDA device to use if multiple GPUs are available. Default: the result of torch::cuda_current_device().
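As a usage sketch, assuming luz and a working torch backend are installed, the code below passes accelerator(cpu = TRUE) to fit() so that training runs on the CPU even when a GPU is present. The toy tensors, layer sizes and epoch count are illustrative choices, not requirements. The guard skips everything when luz or the torch backend is missing.

```r
# Skip everything if luz or the torch backend (libtorch) is unavailable.
has_luz <- requireNamespace("luz", quietly = TRUE) && torch::torch_is_installed()

if (has_luz) {
  library(torch)
  library(luz)

  # Toy regression data; any numeric tensors of matching sizes would do.
  x <- torch_randn(64, 10)
  y <- torch_randn(64, 1)

  fitted <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    fit(
      list(x, y),
      epochs = 2,
      # Run the training procedure on the CPU regardless of available GPUs.
      accelerator = accelerator(cpu = TRUE),
      verbose = FALSE
    )
}
```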


Creates a dataloader from its input

Description

as_dataloader is used internally by luz to convert input data and valid_data as passed to fit.luz_module_generator() to a torch::dataloader.

Usage

as_dataloader(x, ...)

## S3 method for class 'dataset'
as_dataloader(x, ..., batch_size = 32)

## S3 method for class 'iterable_dataset'
as_dataloader(x, ..., batch_size = 32)

## S3 method for class 'list'
as_dataloader(x, ...)

## S3 method for class 'dataloader'
as_dataloader(x, ...)

## S3 method for class 'matrix'
as_dataloader(x, ...)

## S3 method for class 'numeric'
as_dataloader(x, ...)

## S3 method for class 'array'
as_dataloader(x, ...)

## S3 method for class 'torch_tensor'
as_dataloader(x, ...)

Arguments

x

the input object.

...

Passed to torch::dataloader().

batch_size

(int, optional): how many samples per batch to load (default: 32).

Details

as_dataloader methods should have sensible defaults for batch_size, parallel workers, etc.

It allows users to quickly experiment with fit.luz_module_generator() without having to create a torch::dataset and a torch::dataloader in simple experiments.

Methods (by class)

Overriding

You can implement your own as_dataloader S3 method if you want your data structure to be automatically supported by luz's fit.luz_module_generator(). The method must satisfy the following conditions:

It's better to avoid implementing as_dataloader methods for common S3 classes like data.frames. In this case, it's better to assign a different class to the inputs and implement as_dataloader for it.
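As a sketch of that recommendation, the code below wraps a data.frame in a hypothetical my_table class and implements as_dataloader() for it. The class name, its constructor and the column handling are illustrative assumptions. The sketch assumes luz and a working torch backend are available, and it does not check every condition luz may impose on such methods. Extra arguments in ... are forwarded to torch::dataloader(). Here as_dataloader() is called directly. If the method is visible to S3 dispatch (for example, defined at top level or registered by a package), fit() should also be able to accept a my_table object.

```r
has_luz <- requireNamespace("luz", quietly = TRUE) && torch::torch_is_installed()

if (has_luz) {
  library(torch)
  library(luz)

  # Hypothetical class: tag the data with its own class instead of
  # registering a method for the common "data.frame" class.
  my_table <- function(df, target) {
    structure(list(df = df, target = target), class = "my_table")
  }

  # as_dataloader() method for the hypothetical "my_table" class.
  as_dataloader.my_table <- function(x, ..., batch_size = 32) {
    predictors <- setdiff(names(x$df), x$target)
    ds <- tensor_dataset(
      torch_tensor(as.matrix(x$df[, predictors])),               # input
      torch_tensor(as.matrix(x$df[, x$target, drop = FALSE]))    # target
    )
    dataloader(ds, batch_size = batch_size, ...)
  }

  tbl <- my_table(
    data.frame(a = rnorm(100), b = rnorm(100), y = rnorm(100)),
    target = "y"
  )
  dl <- as_dataloader(tbl, batch_size = 25)  # 100 rows / 25 per batch
}
```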


Context object

Description

Context object storing information about the model training context. See also ctx.

Public fields

buffers

This is a list of buffers that callbacks can use to write temporary information into ctx.

Active bindings

records

stores information about values logged with self$log.

device

allows querying the current accelerator device

callbacks

list of callbacks that will be called.

iter

current iteration

batch

the current batch data. a list with input data and targets.

input

a shortcut for ctx$batch[[1]]

target

a shortcut for ctx$batch[[2]]

min_epochs

the minimum number of epochs that the model will run for.

max_epochs

the maximum number of epochs that the model will run.

hparams

a list of hyperparameters that were used to initialize ctx$model.

opt_hparams

a list of hyperparameters used to initialize the ctx$optimizers.

train_data

a dataloader that is used for training the model

valid_data

a dataloader used during model validation

accelerator

an accelerator() used to move data, the model, etc. to the correct device.

optimizers

a named list of optimizers that will be used during model training.

verbose

bool indicating whether the process is in verbose mode or not.

handlers

List of error handlers that can be used. See rlang::try_fetch() for more info.

epoch_handlers

List of error handlers used inside the epoch loop. See rlang::try_fetch() for more info.

training

A bool indicating if the model is in training or validation mode.

model

The model being trained.

pred

Last predicted values.

opt

Current optimizer.

opt_name

Current optimizer name.

data

Current dataloader in use.

loss_fn

Loss function used to train the model

loss

Last computed loss values. Detached from the graph.

loss_grad

Last computed loss value, not detached, so you can apply additional transformations.

epoch

Current epoch.

metrics

List of metrics that are tracked by the process.

step_opt

Defines how step is called for the optimizer. It must be a function taking an optimizer as argument.

Methods

Public methods


Method new()

Initializes the context object with minimal necessary information.

Usage
context$new(verbose, accelerator, callbacks, training)
Arguments
verbose

Whether the context should be in verbose mode or not.

accelerator

A luz accelerator() that configures device placement and other settings.

callbacks

A list of callbacks used by the model. See luz_callback().

training

A boolean that indicates if the context is in training mode or not.


Method log()

Allows logging arbitrary information in the ctx.

Usage
context$log(what, set, value, index = NULL, append = TRUE)
Arguments
what

(string) What you are logging.

set

(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.

value

Arbitrary value to log.

index

Index at which this value should be logged. If NULL the value is added to the end of the list, otherwise the index is used.

append

If TRUE and a value at the corresponding index already exists, then value is appended to the current value. If FALSE the existing value is overwritten by the new value.
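To make the index and append rules concrete, here is a plain-list model of them. The log_value() helper and the nested records[[what]][[set]] layout are hypothetical. This is not luz's implementation, and luz may combine appended values differently from c().

```r
# Hypothetical helper: records[[what]][[set]] is a list indexed by position.
log_value <- function(records, what, set, value, index = NULL, append = TRUE) {
  if (is.null(records[[what]])) records[[what]] <- list()
  current <- records[[what]][[set]]
  if (is.null(current)) current <- list()

  # With no index, the value goes at the end of the list.
  if (is.null(index)) index <- length(current) + 1

  if (append && index <= length(current) && !is.null(current[[index]])) {
    # A value already exists at this index: append the new one to it.
    current[[index]] <- c(current[[index]], value)
  } else {
    # Empty slot, or append = FALSE: store (or overwrite) the value.
    current[[index]] <- value
  }
  records[[what]][[set]] <- current
  records
}

records <- list()
records <- log_value(records, "metrics", "train", 0.9)             # slot 1
records <- log_value(records, "metrics", "train", 0.8)             # slot 2
records <- log_value(records, "metrics", "train", 0.7, index = 2)  # slot 2 is now c(0.8, 0.7)
records <- log_value(records, "metrics", "train", 0.5, index = 1, append = FALSE)  # slot 1 overwritten
```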


Method log_metric()

Log a metric by its name and value. Metric values are indexed by epoch.

Usage
context$log_metric(name, value)
Arguments
name

name of the metric

value

Arbitrary value to log.


Method get_log()

Get a specific value from the log.

Usage
context$get_log(what, set, index = NULL)
Arguments
what

(string) What you are logging.

set

(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.

index

Index that this value should be logged. If NULL the value is added to the end of list, otherwise the index is used.


Method get_metrics()

Get all metrics given an epoch and set.

Usage
context$get_metrics(set, epoch = NULL)
Arguments
set

(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.

epoch

The epoch you want to extract metrics from.


Method get_metric()

Get the value of a metric given its name, epoch and set.

Usage
context$get_metric(name, set, epoch = NULL)
Arguments
name

name of the metric

set

(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.

epoch

The epoch you want to extract metrics from.


Method get_formatted_metrics()

Get formatted metric values.

Usage
context$get_formatted_metrics(set, epoch = NULL)
Arguments
set

(string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info.

epoch

The epoch you want to extract metrics from.


Method get_metrics_df()

Get a data.frame containing all metrics.

Usage
context$get_metrics_df()

Method set_verbose()

Allows setting the verbose attribute.

Usage
context$set_verbose(verbose = NULL)
Arguments
verbose

boolean. If TRUE verbose mode is used. If FALSE non-verbose mode is used. If NULL, the result of interactive() is used.


Method clean()

Removes unnecessary information from the context object.

Usage
context$clean()

Method call_callbacks()

Call the selected callbacks, where name is the type of callback to call, e.g. 'on_epoch_begin'.

Usage
context$call_callbacks(name)
Arguments
name

name of the callback breakpoint to call, e.g. 'on_epoch_begin'.


Method state_dict()

Returns a list containing minimal information from the context. Used to create the returned values.

Usage
context$state_dict()

Method unsafe_set_records()

Replaces the records stored in the context with a new set. Are you sure you know what you are doing?

Usage
context$unsafe_set_records(records)
Arguments
records

New set of records to be set.


Method clone()

The objects of this class are cloneable with this method.

Usage
context$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.


Context object

Description

Context objects used in luz to share information between model methods, metrics and callbacks.

Details

The ctx object is used in luz to share information between the training loop and callbacks, model methods, and metrics. The table below describes information available in the ctx by default. Other callbacks could potentially modify these attributes or add new ones.

Attribute        Description
verbose          The value (TRUE or FALSE) attributed to the verbose argument in fit.
accelerator      Accelerator object used to query the correct device to place models, data, etc. It assumes the value passed to the accelerator parameter in fit.
model            Initialized nn_module object that will be trained during the fit procedure.
optimizers       A named list of optimizers used during training.
data             The currently in-use dataloader. When training it's ctx$train_data, when doing validation it's ctx$valid_data. It can also be the prediction dataset when in predict.
train_data       Dataloader passed to the data argument in fit. Modified to yield data in the selected device.
valid_data       Dataloader passed to the valid_data argument in fit. Modified to yield data in the selected device.
min_epochs       Minimum number of epochs the model will be trained for.
max_epochs       Maximum number of epochs the model will be trained for.
epoch            Current training epoch.
iter             Current training iteration. It's reset every epoch and when going from training to validation.
training         Whether the model is in training or validation mode. See also help("luz_callback_train_valid").
callbacks        List of callbacks that will be called during the training procedure. It's the union of the list passed to the callbacks parameter and the default callbacks.
step             Closure that will be used to do one step of the model. It's used for both training and validation. Takes no argument, but can access the ctx object.
call_callbacks   Call callbacks by name. For example call_callbacks("on_train_begin") will call all callbacks that provide methods for this point.
batch            Last batch obtained by the dataloader. A batch is a list() with 2 elements, one that is used as input and the other as target.
input            First element of the last batch obtained by the current dataloader.
target           Second element of the last batch obtained by the current dataloader.
pred             Last predictions obtained by ctx$model$forward. Note: can potentially be modified by previously run callbacks. Also note that this might not be available if you used a custom training step.
loss_fn          The active loss function that will be minimized during training.
loss             Last computed loss from the model. Note: this might not be available if you modified the training or validation step.
opt              Current optimizer, i.e. the optimizer that will be used to do the next step to update parameters.
opt_nm           Current optimizer name. By default it's opt, but can change if your model uses more than one optimizer depending on the set of parameters being optimized.
metrics          list() with current metric objects that are updated at every on_train_batch_end() or on_valid_batch_end(). See also help("luz_callback_metrics").
records          list() recording metric values for training and validation for each epoch. See also help("luz_callback_metrics"). Also records profiling metrics. See help("luz_callback_profile") for more information.
handlers         A named list() of handlers that is passed to rlang::with_handlers() during the training loop and can be used to handle errors or conditions that might be raised by other callbacks.
epoch_handlers   A named list of handlers that is used with rlang::with_handlers(). Those handlers are used inside the epochs loop, thus you can handle epoch-specific conditions that won't necessarily end training.

Context attributes

See Also

Context object: context


Evaluates a fitted model on a dataset

Description

Evaluates a fitted model on a dataset

Usage

evaluate(
  object,
  data,
  ...,
  metrics = NULL,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)

Arguments

object

A fitted model to evaluate.

data

(dataloader, dataset or list) A dataloader created with torch::dataloader() used for evaluating the model, or a dataset created with torch::dataset() or a list. Dataloaders and datasets must return a list with at most 2 items. The first item will be used as input for the module and the second will be used as a target for the loss function.

...

Currently unused.

metrics

A list of luz metrics to be tracked during evaluation. If NULL (default) then the same metrics that were used during training are tracked.

callbacks

(list, optional) A list of callbacks defined with luz_callback() that will be called during the evaluation procedure. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.

accelerator

(accelerator, optional) An optional accelerator() object used to configure device placement of the components like torch::nn_modules, optimizers and batches of data.

verbose

(logical, optional) An optional boolean value indicating if the evaluation procedure should emit output to the console. By default, it will produce output if interactive() is TRUE, otherwise it won't print to the console.

dataloader_options

Options used when creating a dataloader. See torch::dataloader(). shuffle=TRUE by default for the training data and batch_size=32 by default. It will error if not NULL and data is already a dataloader.

Details

Once a model has been trained you might want to evaluate its performance on a different dataset. For that reason, luz provides the ?evaluate function that takes a fitted model and a dataset and computes the metrics attached to the model.

Evaluate returns a luz_module_evaluation object that you can query for metrics using the get_metrics function or simply print to see the results.

For example:

evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)
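The snippet above assumes an existing fitted model and a valid_dl dataloader. A self-contained variant follows. It assumes luz and a working torch backend are installed, trains on the first 150 rows of a toy dataset and evaluates on the remaining 50. The data, the split and the choice of luz_metric_mae() are illustrative.

```r
has_luz <- requireNamespace("luz", quietly = TRUE) && torch::torch_is_installed()

if (has_luz) {
  library(torch)
  library(luz)

  x <- torch_randn(200, 10)
  y <- torch_randn(200, 1)

  # Train on the first 150 rows, tracking MAE as an extra metric.
  fitted <- nn_linear %>%
    setup(
      loss = nnf_mse_loss,
      optimizer = optim_adam,
      metrics = list(luz_metric_mae())
    ) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    fit(list(x[1:150, ], y[1:150, ]), epochs = 2, verbose = FALSE)

  # Evaluate on the 50 held-out rows; metrics = NULL reuses the training metrics.
  evaluation <- fitted %>% evaluate(data = list(x[151:200, ], y[151:200, ]))
  metrics <- get_metrics(evaluation)
  print(evaluation)
}
```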

See Also

Other training: fit.luz_module_generator(), predict.luz_module_fitted(), setup()


Fit an nn_module

Description

Fit an nn_module

Usage

## S3 method for class 'luz_module_generator'
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)

Arguments

object

An nn_module that has been setup().

data

(dataloader, dataset or list) A dataloader created with torch::dataloader() used for training the model, or a dataset created with torch::dataset() or a list. Dataloaders and datasets must return a list with at most 2 items. The first item will be used as input for the module and the second will be used as a target for the loss function.

epochs

(int) The maximum number of epochs for training the model. If a single value is provided, this is taken to be the max_epochs and min_epochs is set to 0. If a vector of two numbers is provided, the first value is min_epochs and the second value is max_epochs. The minimum and maximum number of epochs are included in the context object as ctx$min_epochs and ctx$max_epochs, respectively.

callbacks

(list, optional) A list of callbacks defined with luz_callback() that will be called during the training procedure. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.

valid_data

(dataloader, dataset, list or scalar value; optional) A dataloader created with torch::dataloader() or a dataset created with torch::dataset() that will be used during the validation procedure. They must return a list with (input, target). If data is a torch dataset or a list, then you can also supply a numeric value between 0 and 1, and in this case a random sample with size corresponding to that proportion from data will be used for validation.

accelerator

(accelerator, optional) An optional accelerator() object used to configure device placement of the components like torch::nn_modules, optimizers and batches of data.

verbose

(logical, optional) An optional boolean value indicating if the fitting procedure should emit output to the console during training. By default, it will produce output if interactive() is TRUE, otherwise it won't print to the console.

...

Currently unused.

dataloader_options

Options used when creating a dataloader. See torch::dataloader(). shuffle=TRUE by default for the training data and batch_size=32 by default. It will error if not NULL and data is already a dataloader.
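Two small plain-R helpers summarise the rules for epochs and for a numeric valid_data. Both parse_epochs() and split_indices() are hypothetical illustrations of the documented behaviour, not luz internals. In particular, rounding the validation size down with floor() is an assumption.

```r
# Hypothetical helper mirroring the documented `epochs` rule.
parse_epochs <- function(epochs) {
  if (length(epochs) == 1) {
    list(min_epochs = 0, max_epochs = epochs)
  } else if (length(epochs) == 2) {
    list(min_epochs = epochs[[1]], max_epochs = epochs[[2]])
  } else {
    stop("`epochs` must be a single number or a vector of two numbers.")
  }
}

# Hypothetical helper: hold out a random proportion `valid` of n observations.
split_indices <- function(n, valid) {
  stopifnot(valid > 0, valid < 1)
  valid_idx <- sample.int(n, size = floor(n * valid))  # rounding rule assumed
  list(train = setdiff(seq_len(n), valid_idx), valid = valid_idx)
}

parse_epochs(10)        # min_epochs = 0, max_epochs = 10
parse_epochs(c(5, 20))  # min_epochs = 5, max_epochs = 20
```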

Value

A fitted object that can be saved with luz_save() and can be printed with print() and plotted with plot().

See Also

predict.luz_module_fitted() for how to create predictions. setup() to find out how to create modules that can be trained with fit.

Other training: evaluate(), predict.luz_module_fitted(), setup()


Get metrics from the object

Description

Get metrics from the object

Usage

get_metrics(object, ...)

## S3 method for class 'luz_module_fitted'
get_metrics(object, ...)

Arguments

object

The object to query for metrics.

...

Currently unused.

Value

A data.frame containing the metric values.

Methods (by class)


Learning Rate Finder

Description

Learning Rate Finder

Usage

lr_finder(
  object,
  data,
  steps = 100,
  start_lr = 1e-07,
  end_lr = 0.1,
  log_spaced_intervals = TRUE,
  ...,
  verbose = NULL
)

Arguments

object

An nn_module that has been setup().

data

(dataloader) A dataloader created with torch::dataloader() used for learning rate finding.

steps

(integer) The number of steps to iterate over in the learning rate finder. Default: 100.

start_lr

(float) The smallest learning rate. Default: 1e-7.

end_lr

(float) The highest learning rate. Default: 1e-1.

log_spaced_intervals

(logical) Whether to divide the range between start_lr and end_lr into log-spaced intervals (alternative: uniform intervals). Default: TRUE

...

Other arguments passed to fit.

verbose

Whether to show a progress bar during the process.
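To see what log_spaced_intervals changes, the code below builds both candidate grids between the default start_lr and end_lr in plain R. It illustrates log-spaced versus uniform spacing under the stated defaults. It is not copied from luz, and lr_finder() may construct its schedule differently, for example in how it treats the endpoints.

```r
start_lr <- 1e-7
end_lr <- 1e-1
steps <- 100

# Log-spaced grid: equal steps on the log scale, so consecutive learning
# rates differ by a constant multiplicative factor.
log_grid <- exp(seq(log(start_lr), log(end_lr), length.out = steps))

# Uniform grid: equal steps on the linear scale.
uniform_grid <- seq(start_lr, end_lr, length.out = steps)

# Count how many grid points fall below 1e-4 under each scheme; the uniform
# grid spends almost all of its steps on large learning rates.
sum(log_grid < 1e-4)
sum(uniform_grid < 1e-4)
```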

Value

A data frame with two columns: learning rate and loss.

Examples

if (torch::torch_is_installed()) {
library(torch)

ds <- torch::tensor_dataset(x = torch_randn(100, 10), y = torch_randn(100, 1))
dl <- torch::dataloader(ds, batch_size = 32)

model <- torch::nn_linear
model <- model %>% setup(
  loss = torch::nn_mse_loss(),
  optimizer = torch::optim_adam
) %>%
  set_hparams(in_features = 10, out_features = 1)

records <- lr_finder(model, dl, verbose = FALSE)
plot(records)
}

Create a new callback

Description

Create a new callback

Usage

luz_callback(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)

Arguments

name

name of the callback

...

Public methods of the callback. The name of the methods is used to know how they should be called. See the details section.

private

An optional list of private members, which can be functions and non-functions.

active

An optional list of active binding functions.

parent_env

An environment to use as the parent of newly-created objects.

inherit

An R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in parent_env each time an object is instantiated.

Details

Let’s implement a callback that prints ‘Iteration n’ (where n is the iteration number) for every batch in the training set and ‘Done’ when an epoch is finished. For that task we use the luz_callback function:

print_callback <- luz_callback(
  name = "print_callback",
  initialize = function(message) {
    self$message <- message
  },
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat(self$message, "\n")
  }
)

luz_callback() takes named functions as ... arguments, where the name indicates the moment at which the callback should be called. For instance on_train_batch_end() is called at the end of every training batch, and on_epoch_end() is called at the end of every epoch.

The returned value of luz_callback() is a function that initializes an instance of the callback. Callbacks can have initialization parameters, like the name of a file where you want to log the results. In that case, you can pass an initialize method when creating the callback definition, and save these parameters to the self object. In the above example, the callback has a message parameter that is printed at the end of each epoch.

Once a callback is defined it can be passed to the fit function via the callbacks parameter:

fitted <- net %>%
  setup(...) %>%
  fit(..., callbacks = list(
    print_callback(message = "Done!")
  ))

Callbacks can be called in many different positions of the training loop, including combinations of them. Here’s an overview of possible callback breakpoints:

Start Fit
   - on_fit_begin
  Start Epoch Loop
     - on_epoch_begin
    Start Train
       - on_train_begin
      Start Batch Loop
         - on_train_batch_begin
          Start Default Training Step
            - on_train_batch_after_pred
            - on_train_batch_after_loss
            - on_train_batch_before_backward
            - on_train_batch_before_step
            - on_train_batch_after_step
          End Default Training Step
         - on_train_batch_end
      End Batch Loop
       - on_train_end
    End Train
    Start Valid
       - on_valid_begin
      Start Batch Loop
         - on_valid_batch_begin
          Start Default Validation Step
            - on_valid_batch_after_pred
            - on_valid_batch_after_loss
          End Default Validation Step
         - on_valid_batch_end
      End Batch Loop
       - on_valid_end
    End Valid
      - on_epoch_end
  End Epoch Loop
   - on_fit_end
End Fit

Every step marked with on_* is a point in the training procedure that is available for callbacks to be called.

The other important part of callbacks is the ctx (context) object. See help("ctx") for details.

By default, callbacks are called in the same order as they were passed to fit (or predict or evaluate), but you can provide a weight attribute that will control the order in which it will be called. For example, if one callback has weight = 10 and another has weight = 1, then the first one is called after the second one. Callbacks that don’t specify a weight attribute are considered weight = 0. A few built-in callbacks in luz already provide a weight value. For example, the ?luz_callback_early_stopping has a weight of Inf, since in general we want to run it as the last thing in the loop.
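The ordering rule can be mimicked in plain R. Treat a missing weight as 0 and sort with order(), which leaves ties in their original order, so equally weighted callbacks keep the order in which they were passed. Apart from early stopping (weight Inf, as stated above), the callback names and weights are hypothetical. This illustrates the rule rather than reproducing luz's own sorting code.

```r
# Hypothetical callbacks in the order they were passed to fit();
# `weight` is NULL when the callback does not set one.
callbacks <- list(
  list(name = "early_stopping", weight = Inf),
  list(name = "logger_a", weight = NULL),
  list(name = "custom_heavy", weight = 10),
  list(name = "custom_light", weight = 1),
  list(name = "logger_b", weight = NULL)
)

# A missing weight counts as 0.
weights <- vapply(callbacks, function(cb) {
  if (is.null(cb$weight)) 0 else cb$weight
}, numeric(1))

# order() keeps tied elements in their original order.
call_order <- vapply(callbacks[order(weights)], function(cb) cb$name, character(1))
call_order
# "logger_a" "logger_b" "custom_light" "custom_heavy" "early_stopping"
```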

Value

A luz_callback that can be passed to fit.luz_module_generator().

Prediction callbacks

You can also use callbacks when using predict(). In this case the supported callback methods are detailed below:

Start predict
 - on_predict_begin
 Start prediction loop
  - on_predict_batch_begin
  - on_predict_batch_end
 End prediction loop
 - on_predict_end
End predict

Evaluate callbacks

Callbacks can also be used with evaluate(). In this case, the callbacks that are used are equivalent to those of the validation loop when using fit():

Start Valid
 - on_valid_begin
 Start Batch Loop
  - on_valid_batch_begin
  Start Default Validation Step
   - on_valid_batch_after_pred
   - on_valid_batch_after_loss
  End Default Validation Step
  - on_valid_batch_end
 End Batch Loop
 - on_valid_end
End Valid

See Also

Other luz_callbacks: luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

print_callback <- luz_callback(
  name = "print_callback",
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat("Done!\n")
  }
)

Resume training callback

Description

This callback allows you to resume training a model.

Usage

luz_callback_auto_resume(path = "./state.pt")

Arguments

path

Path to save state files for the model.

Details

When using it, model weights and optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script will restart the model training from the epoch right after the last epoch that was serialized.

Customizing serialization

By default model, optimizer state and records are serialized. Callbacks can be used to customize serialization by implementing the state_dict() and load_state_dict() methods. If those methods are implemented, then state_dict() is called at the end of each epoch and load_state_dict() is called when the model is resumed.

Note

In general you will want to add this callback as the last in the callbacks list; this way, the serialized state is likely to contain all possible changes that other callbacks could have made at 'on_epoch_end'. The default weight attribute of this callback is Inf.

Read the checkpointing article on the pkgdown website for more information.

See Also

Other luz_callbacks: luz_callback(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

if (torch::torch_is_installed()) {
library(torch)
library(luz)

x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)

model <- nn_linear %>%
  setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 0.01)

# simulate a failure at the end of epoch 5 happening only once.
callback_stop <- luz_callback(
  "interrupt",
  failed = FALSE,
  on_epoch_end = function() {
    if (ctx$epoch == 5 && !self$failed) {
      self$failed <- TRUE
      stop("Error on epoch 5")
    }
  }
)

path <- tempfile()
autoresume <- luz_callback_auto_resume(path = path)
interrupt <- callback_stop()

# try once and the model fails
try({
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )
})

# model resumes and completes
results <- model %>% fit(
  list(x, y),
  callbacks = list(autoresume, interrupt),
  verbose = FALSE
)

get_metrics(results)
}

CSV logger callback

Description

Logs metrics obtained during training to a file on disk. The file will have 1 line for each epoch/validation.

Usage

luz_callback_csv_logger(path)

Arguments

path

path to a file on disk.
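A minimal usage sketch, assuming luz and a working torch backend are installed, writes the metrics log to a temporary file and reads it back. The toy data, the 20% validation split and the epoch count are illustrative. The code makes no assumption about the exact column layout of the file.

```r
has_luz <- requireNamespace("luz", quietly = TRUE) && torch::torch_is_installed()

if (has_luz) {
  library(torch)
  library(luz)

  x <- torch_randn(100, 10)
  y <- torch_randn(100, 1)
  log_path <- tempfile(fileext = ".csv")  # temporary file for the metrics log

  fitted <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    fit(
      list(x, y),
      epochs = 3,
      valid_data = 0.2,
      callbacks = list(luz_callback_csv_logger(log_path)),
      verbose = FALSE
    )

  # Inspect whatever the callback wrote.
  log_df <- read.csv(log_path)
  print(log_df)
}
```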

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()


Early stopping callback

Description

Stops training when a monitored metric stops improving

Usage

luz_callback_early_stopping(
  monitor = "valid_loss",
  min_delta = 0,
  patience = 0,
  mode = "min",
  baseline = NULL
)

Arguments

monitor

A string in the format <set>_<metric> where <set> can be 'train' or 'valid' and <metric> can be the abbreviation of any metric that you are tracking during training. The metric name is case insensitive.

min_delta

Minimum improvement to reset the patience counter.

patience

Number of epochs without improvement before training is stopped.

mode

Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better).

baseline

An initial value that will be used as the best seen value in the beginning. Model will stop training if no better than baseline value is found in the first patience epochs.
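One way to read monitor, min_delta, patience, mode and baseline together is the Keras-style rule sketched below in plain R. A value counts as an improvement only if it beats the best value seen so far by more than min_delta, and training stops once patience epochs pass without one. The early_stop_epoch() helper is hypothetical, and luz's exact counting (for example how patience = 0 is treated) may differ.

```r
# Hypothetical helper: returns the epoch at which training would stop, or NA
# if the monitored values never trigger early stopping.
early_stop_epoch <- function(values, min_delta = 0, patience = 0,
                             mode = c("min", "max", "zero"), baseline = NULL) {
  mode <- match.arg(mode)
  # Map every mode onto "smaller is better".
  to_score <- function(v) switch(mode, min = v, max = -v, zero = abs(v))
  best <- if (is.null(baseline)) Inf else to_score(baseline)
  wait <- 0
  for (epoch in seq_along(values)) {
    score <- to_score(values[epoch])
    if (score < best - min_delta) {
      # Improvement: remember it and reset the patience counter.
      best <- score
      wait <- 0
    } else {
      # No improvement: stop once `patience` epochs have passed without one.
      wait <- wait + 1
      if (wait >= patience) return(epoch)
    }
  }
  NA_integer_
}

early_stop_epoch(c(1.0, 0.8, 0.79, 0.85, 0.9), patience = 1)  # returns 4
```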

Value

A luz_callback that does early stopping.

Note

This callback adds an on_early_stopping callback that can be used to call callbacks as soon as the model stops training.

If verbose=TRUE in fit.luz_module_generator(), a message is printed when early stopping.

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

cb <- luz_callback_early_stopping()

Gradient clipping callback

Description

By adding the GradientClip callback, the gradient norm_type (default: 2) norm is clipped to at most max_norm (default: 1) using torch::nn_utils_clip_grad_norm_(), which can avoid loss divergence.

Usage

luz_callback_gradient_clip(max_norm = 1, norm_type = 2)

Arguments

max_norm

(float or int): max norm of the gradients

norm_type

(float or int): type of the used p-norm. Can be Inf for infinity norm.
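The clipping rule can be illustrated in plain R on gradients stored as numeric vectors. The sketch follows the usual torch::nn_utils_clip_grad_norm_() recipe. It computes one total norm over all parameters, then rescales every gradient by max_norm / (total_norm + 1e-6), but only when that factor is below 1. The clip_grad_norm() helper is hypothetical, and the 1e-6 epsilon is assumed to match torch's implementation.

```r
# Hypothetical helper: `grads` is a list of numeric vectors, one per parameter.
clip_grad_norm <- function(grads, max_norm = 1, norm_type = 2) {
  flat <- unlist(grads)
  # Total norm over all parameters, as if they were one long vector.
  total_norm <- if (is.infinite(norm_type)) {
    max(abs(flat))
  } else {
    sum(abs(flat)^norm_type)^(1 / norm_type)
  }
  clip_coef <- max_norm / (total_norm + 1e-6)
  # Only rescale when the norm exceeds max_norm; small gradients are untouched.
  if (clip_coef < 1) {
    grads <- lapply(grads, function(g) g * clip_coef)
  }
  grads
}

clipped <- clip_grad_norm(list(c(3, 0), 4))  # total 2-norm is 5
sqrt(sum(unlist(clipped)^2))                 # close to max_norm = 1
```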

References

See FastAI documentation for the GradientClip callback.


Interrupt callback

Description

Adds a handler that allows interrupting the training loop using ctrl + C. Also registers an on_interrupt breakpoint so users can register callbacks to be run on training loop interruption.

Usage

luz_callback_interrupt()

Value

A luz_callback

Note

In general you don't need to use this callback yourself because it's always included by default in fit.luz_module_generator().

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

interrupt_callback <- luz_callback_interrupt()

Keep the best model

Description

Each epoch, if there's improvement in the monitored metric we serialize the model weights to a temp file. When training is done, we reload weights from the best model.

Usage

luz_callback_keep_best_model(
  monitor = "valid_loss",
  mode = "min",
  min_delta = 0
)

Arguments

monitor

A string in the format <set>_<metric> where <set> can be 'train' or 'valid' and <metric> can be the abbreviation of any metric that you are tracking during training. The metric name is case insensitive.

mode

Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better).

min_delta

Minimum change in the monitored metric for it to count as an improvement over the best value seen so far.
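A usage sketch, assuming luz and a working torch backend are installed, pairs this callback with early stopping. Training halts after a few epochs without a lower validation loss, and the weights from the best validation epoch are kept. The toy data, the 20% validation split, the 20-epoch cap and patience = 3 are illustrative choices.

```r
has_luz <- requireNamespace("luz", quietly = TRUE) && torch::torch_is_installed()

if (has_luz) {
  library(torch)
  library(luz)

  x <- torch_randn(200, 10)
  y <- torch_randn(200, 1)

  fitted <- nn_linear %>%
    setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    fit(
      list(x, y),
      epochs = 20,
      valid_data = 0.2,  # hold out 20% of the rows for validation
      callbacks = list(
        # Stop after 3 epochs without a lower validation loss ...
        luz_callback_early_stopping(monitor = "valid_loss", patience = 3),
        # ... and restore the weights from the best validation epoch.
        luz_callback_keep_best_model(monitor = "valid_loss")
      ),
      verbose = FALSE
    )

  metrics <- get_metrics(fitted)
}
```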

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

cb <- luz_callback_keep_best_model()

Learning rate scheduler callback

Description

Initializes and runs torch::lr_scheduler() objects.

Usage

luz_callback_lr_scheduler(
  lr_scheduler,
  ...,
  call_on = "on_epoch_end",
  opt_name = NULL
)

Arguments

lr_scheduler

A torch::lr_scheduler() that will be initialized with the optimizer and the ... parameters.

...

Additional arguments passed to lr_scheduler together with the optimizers.

call_on

The callback breakpoint at which scheduler$step() is called. Default is 'on_epoch_end'. See luz_callback() for more information.

opt_name

Name of the optimizer that will be affected by this callback. Should match the name given in set_optimizers. If your module has a single optimizer, opt_name is not used.

Value

A luz_callback() generator.

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

if (torch::torch_is_installed()) {
cb <- luz_callback_lr_scheduler(torch::lr_step, step_size = 30)
}
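With the default call_on = 'on_epoch_end', the scheduler is stepped once per epoch. torch::lr_step decays the learning rate by gamma (0.1 by default) every step_size steps, so the schedule produced by the example above can be sketched in base R. step_lr() is a hypothetical helper; it assumes a single parameter group, no other scheduler, and a starting learning rate of 1e-3 chosen only for illustration.

```r
# Hypothetical helper: learning rate in effect during `epoch` (1-based) when a
# step-decay scheduler is stepped once at the end of every epoch.
step_lr <- function(epoch, lr0, step_size, gamma = 0.1) {
  n_steps <- epoch - 1                 # scheduler steps taken before this epoch
  lr0 * gamma^(n_steps %/% step_size)  # one decay per completed block of step_size steps
}

# step_size = 30 as in the example above, starting from lr = 1e-3
sapply(c(1, 30, 31, 60, 61), step_lr, lr0 = 1e-3, step_size = 30)
```

Epochs 1 to 30 use 1e-3, epochs 31 to 60 use 1e-4, and epoch 61 onwards uses 1e-5.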

Metrics callback

Description

Tracks metrics passed to setup() during training and validation.

Usage

luz_callback_metrics()

Details

This callback takes care of 2 ctx attributes:

Value

A luz_callback

Note

In general, you won't need to use the metrics callback explicitly, as it's used by default in fit.luz_module_generator().

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()


Automatic Mixed Precision callback

Description

This callback enables torch::local_autocast() during the model forward pass and the loss computation. It then disables autocast and scales the loss before backward() and opt$step(). See here for more information.

Usage

luz_callback_mixed_precision(...)

Arguments

...

Passed to torch::cuda_amp_grad_scaler().

Value

A luz_callback

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()


Mixup callback

Description

Implementation of 'mixup: Beyond Empirical Risk Minimization'. As of today, it has been tested only on categorical data, where targets are expected to be integers, not one-hot encoded vectors. This callback is meant to be used together with nn_mixup_loss().

Usage

luz_callback_mixup(alpha = 0.4, ..., run_valid = FALSE, auto_loss = FALSE)

Arguments

alpha

parameter for the beta distribution used to sample mixing coefficients

...

Currently unused. Only there to force named arguments.

run_valid

Whether mixup should also be applied during validation.

auto_loss

Whether to automatically modify the loss function. If TRUE, the loss function is wrapped to create the mixup loss; make sure that your loss function does not apply reductions. If run_valid = FALSE, the loss will be mean-reduced during validation.
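The alpha argument controls how strongly samples are blended. The base-R sketch below, independent of luz, draws mixing coefficients from Beta(alpha, alpha) and measures the share that falls close to 0 or 1, i.e. coefficients that barely blend two samples. share_near_extremes() is a hypothetical helper and the 0.1/0.9 cut-offs are arbitrary choices for illustration.

```r
# Hypothetical helper: share of Beta(alpha, alpha) draws that are close to
# 0 or 1, i.e. mixing coefficients that barely blend two samples.
share_near_extremes <- function(alpha, n = 10000) {
  lambda <- rbeta(n, shape1 = alpha, shape2 = alpha)
  mean(lambda < 0.1 | lambda > 0.9)
}

set.seed(1)
share_near_extremes(0.4)  # the default alpha of luz_callback_mixup()
share_near_extremes(4)    # a larger alpha concentrates draws around 0.5
```

Small values of alpha push many coefficients toward 0 or 1, so most mixed samples stay close to one of the two originals, while larger values blend them more evenly.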

Details

Overall, we follow the fastai implementation described here. Namely,

Value

A luz_callback

See Also

nn_mixup_loss(), nnf_mixup()

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

if (torch::torch_is_installed()) {
mixup_callback <- luz_callback_mixup()
}

Checkpoints model weights

Description

Saves checkpoints of the model according to the specified metric and behavior.

Usage

luz_callback_model_checkpoint(
  path,
  monitor = "valid_loss",
  save_best_only = FALSE,
  mode = "min",
  min_delta = 0
)

Arguments

path

Path to save the model on disk. The path is interpolated with glue, so you can use any attribute within the ctx by using '{ctx$epoch}'. In particular, the epoch and monitor quantities are already in the environment. If the specified path is a path to a directory (ends with / or \), then models are saved with the name given by epoch-{epoch:02d}-{self$monitor}-{monitor:.3f}.pt. See more in the examples. You can use sprintf()-style format specifiers to quickly format quantities, for example: '{epoch:02d}'.

monitor

A string in the format <set>_<metric>, where <set> can be 'train' or 'valid' and <metric> can be the abbreviation of any metric that you are tracking during training. The metric name is case-insensitive.

save_best_only

If TRUE, models are only saved if they improve over the previously saved model.

mode

Specifies the direction that is considered an improvement. By default, 'min' is used (lower is better). Can also be 'max' (higher is better) or 'zero' (closer to zero is better).

min_delta

Minimum difference to consider as an improvement. Only used when save_best_only = TRUE.

Note

mode and min_delta are only used when save_best_only = TRUE. save_best_only will overwrite the saved models if the path parameter doesn't differentiate by epoch.

Read the checkpointing article on the pkgdown website for more information.

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

luz_callback_model_checkpoint(path = "path/to/dir")
luz_callback_model_checkpoint(path = "path/to/dir/epoch-{epoch:02d}/model.pt")
luz_callback_model_checkpoint(path = "path/to/dir/epoch-{epoch:02d}/model-{monitor:.2f}.pt")
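When path is a directory, as in the first example, the file name follows the template epoch-{epoch:02d}-{self$monitor}-{monitor:.3f}.pt. The sketch below reproduces that formatting with plain sprintf(). default_checkpoint_name() is a hypothetical helper, and the epoch, monitor name and monitor value are made up for illustration.

```r
# Hypothetical helper reproducing the default file-name template
# epoch-{epoch:02d}-{self$monitor}-{monitor:.3f}.pt with sprintf()
default_checkpoint_name <- function(epoch, monitor_name, monitor_value) {
  sprintf("epoch-%02d-%s-%.3f.pt", epoch, monitor_name, monitor_value)
}

# Made-up values: epoch 3 with a validation loss of 0.1234
file.path("path/to/dir", default_checkpoint_name(3L, "valid_loss", 0.1234))
# "path/to/dir/epoch-03-valid_loss-0.123.pt"
```

Because the epoch is part of the default name, each epoch writes a separate file; a fixed name without {epoch} would be overwritten on every save.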

Profile callback

Description

Computes the times for high-level operations in the training loops.

Usage

luz_callback_profile()

Details

Records are saved in ctx$records$profile. Times are stored in seconds. Data is stored in the following structure:

Value

A luz_callback

Note

In general, you don't need to use this callback yourself because it's always included by default in fit.luz_module_generator().

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

profile_callback <- luz_callback_profile()

Progress callback

Description

Responsible for printing progress during training.

Usage

luz_callback_progress()

Value

A luz_callback

Note

In general, you don't need to use this callback yourself because it's always included by default in fit.luz_module_generator().

Printing can be disabled by passing verbose = FALSE to fit.luz_module_generator().

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()


Allow resuming model training from a specific checkpoint

Description

Allow resuming model training from a specific checkpoint.

Usage

luz_callback_resume_from_checkpoint(
  path,
  ...,
  restore_model_state = TRUE,
  restore_records = FALSE,
  restore_optimizer_state = FALSE,
  restore_callbacks_state = FALSE
)

Arguments

path

Path to the checkpoint that you want to resume.

...

currently unused.

restore_model_state

Whether to restore the model state from the checkpoint.

restore_records

Whether to restore records from the checkpoint.

restore_optimizer_state

Whether to restore the optimizer state from the checkpoint.

restore_callbacks_state

Whether to restore the callbacks state from the checkpoint.

Note

Read the checkpointing article on the pkgdown website for more information.

See Also

luz_callback_model_checkpoint()

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_train_valid()


tfevents callback

Description

Logs metrics and other model information in the tfevents file format. Assuming tensorboard is installed, the results can be visualized with the command shown in Details.

Usage

luz_callback_tfevents(logdir = "logs", histograms = FALSE, ...)

Arguments

logdir

A directory where logs will be written.

histograms

A boolean specifying if histograms of model weights should be logged. It can also be a character vector specifying the names of the parameters that should be logged (names are the same as names(model$parameters)).

...

Currently not used. For future expansion.

Details

tensorboard --logdir=logs

Examples

if (torch::torch_is_installed()) {
library(torch)
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 1e-4)
tmp <- tempfile()
model %>% fit(list(x, y), valid_data = 0.2, callbacks = list(
  luz_callback_tfevents(tmp, histograms = TRUE)
))
}

Train-eval callback

Description

Switches important flags for training and evaluation modes.

Usage

luz_callback_train_valid()

Details

It takes care of three ctx attributes:

Value

A luz_callback

Note

In general, you won't need to use the train_valid callback explicitly, as it's used by default in fit.luz_module_generator().

See Also

Other luz_callbacks: luz_callback(), luz_callback_auto_resume(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint()


Load trained model

Description

Loads a fitted model. See the documentation in luz_save().

Usage

luz_load(path)

Arguments

path

path in file system to the object.

See Also

Other luz_save: luz_save()


Loads a checkpoint

Description

Works with checkpoints typically created with luz_callback_model_checkpoint().

Usage

luz_load_checkpoint(obj, path, ...)

Arguments

obj

Object to which we want to load the checkpoint.

path

Path of the checkpoint on disk.

...

Currently unused. Present to allow future extensions.


Loads model weights into a fitted object.

Description

This can be useful when you have saved model checkpoints during training and want to reload the best checkpoint at the end.

Usage

luz_load_model_weights(obj, path, ...)

luz_save_model_weights(obj, path)

Arguments

obj

luz object to which you want to copy the new weights.

path

Path to the saved model on disk.

...

Other arguments passed to torch::torch_load().

Value

Returns NULL invisibly.

Warning

luz_load_model_weights operates in place, i.e., it modifies the model object to contain the new weights.


Creates a new luz metric

Description

Creates a new luz metric

Usage

luz_metric(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)

Arguments

name

string naming the new metric.

...

Named list of public methods. You should implement at least initialize, update and compute. See the Details section for more information.

private

An optional list of private members, which can be functions and non-functions.

active

An optional list of active binding functions.

parent_env

An environment to use as the parent of newly-created objects.

inherit

An R6ClassGenerator object to inherit from; in other words, a superclass. This is captured as an unevaluated expression which is evaluated in parent_env each time an object is instantiated.

Details

In order to implement a new luz_metric we need to implement 3 methods:

Optionally, you can implement an abbrev field that gives the metric an abbreviation that will be used when displaying metric information in the console or tracking record. If no abbrev is passed, the class name will be used.

Let’s take a look at the implementation of luz_metric_accuracy so you can see how to implement a new one:

luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)

Note: It’s good practice for the compute method to return regular R values instead of torch tensors, since other parts of luz expect that.
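To make the initialize/update/compute contract concrete without torch, here is a base-R analogue of the accuracy metric above. It is an illustration only: luz metrics are R6 classes operating on tensors, whereas new_accuracy_metric() is a hypothetical closure working on plain matrices, with max.col() standing in for torch::torch_argmax(preds, dim = 2).

```r
# Hypothetical base-R analogue of luz_metric_accuracy (illustration only)
new_accuracy_metric <- function() {
  # "initialize": reset the internal state
  correct <- 0
  total <- 0
  list(
    # "update": called once per batch; preds is an n x k score matrix
    update = function(preds, target) {
      pred <- max.col(preds, ties.method = "first")  # index of the largest score in each row
      correct <<- correct + sum(pred == target)
      total <<- total + length(pred)
      invisible(NULL)
    },
    # "compute": returns a plain R number, not a tensor
    compute = function() correct / total
  )
}

m <- new_accuracy_metric()
m$update(rbind(c(0.1, 0.9), c(0.8, 0.2)), target = c(2, 2))  # 1 of 2 correct
m$update(rbind(c(0.3, 0.7)), target = 2)                      # 1 of 1 correct
m$compute()  # 2/3
```

Keeping running counts instead of per-batch accuracies means the final value is the accuracy over all observations seen in the epoch, regardless of batch sizes.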

Value

Returns new luz metric.

See Also

Other luz_metrics: luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()

Examples

luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)

Accuracy

Description

Computes accuracy for multi-class classification problems.

Usage

luz_metric_accuracy()

Details

This metric expects to take logits or probabilities at everyupdate. It will then take the columnwise argmax and compareto the target.

Value

Returns new luz metric.

See Also

Other luz_metrics: luz_metric(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()

Examples

if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_accuracy()
metric <- metric$new()
metric$update(torch_randn(100, 10), torch::torch_randint(1, 10, size = 100))
metric$compute()
}

Binary accuracy

Description

Computes the accuracy for binary classification problems where the model returns probabilities. Commonly used when the loss is torch::nn_bce_loss().

Usage

luz_metric_binary_accuracy(threshold = 0.5)

Arguments

threshold

Value used to classify observations as 0 or 1.

Value

Returns new luz metric.

See Also

Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()

Examples

if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_binary_accuracy(threshold = 0.5)
metric <- metric$new()
# torch_randint() excludes `high`, so (0, 2) draws both 0 and 1
metric$update(torch_rand(100), torch::torch_randint(0, 2, size = 100))
metric$compute()
}

Binary accuracy with logits

Description

Computes accuracy for binary classification problems where the model returns logits. Commonly used together with torch::nn_bce_with_logits_loss().

Usage

luz_metric_binary_accuracy_with_logits(threshold = 0.5)

Arguments

threshold

Value used to classify observations as 0 or 1.

Details

Probabilities are generated using torch::nnf_sigmoid(), and threshold is used to classify observations as 0 or 1.

Value

Returns new luz metric.

See Also

Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()

Examples

if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_binary_accuracy_with_logits(threshold = 0.5)
metric <- metric$new()
# torch_randint() excludes `high`, so (0, 2) draws both 0 and 1
metric$update(torch_randn(100), torch::torch_randint(0, 2, size = 100))
metric$compute()
}

Computes the area under the ROC curve

Description

To avoid storing all predictions and targets for an epoch, we compute confusion matrices across a range of pre-established thresholds.

Usage

luz_metric_binary_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE
)

Arguments

num_thresholds

Number of thresholds used to compute confusion matrices. In that case, thresholds are created by taking num_thresholds linearly spaced values in the unit interval.

thresholds

(optional) If thresholds are passed, they are used to compute the confusion matrices and num_thresholds is ignored.

from_logits

Boolean indicating whether predictions are logits; in that case, we use a sigmoid to put them in the unit interval.

See Also

Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()

Examples

if (torch::torch_is_installed()) {
library(torch)
actual <- c(1, 1, 1, 0, 0, 0)
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
y_true <- torch_tensor(actual)
y_pred <- torch_tensor(predicted)
m <- luz_metric_binary_auroc(thresholds = predicted)
m <- m$new()
m$update(y_pred[1:2], y_true[1:2])
m$update(y_pred[3:4], y_true[3:4])
m$update(y_pred[5:6], y_true[5:6])
m$compute()
}
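The idea of keeping only per-threshold confusion counts, rather than all predictions, can be sketched in base R on the same toy data. This is not luz's implementation: make_streaming_auroc() is hypothetical, it counts a prediction as positive when it is >= the threshold, pads the thresholds with -Inf and Inf so the curve reaches (1, 1) and (0, 0), and integrates the ROC curve with the trapezoidal rule.

```r
# Hypothetical streaming AUROC (not luz's implementation): only per-threshold
# confusion counts are stored, never the predictions themselves.
make_streaming_auroc <- function(thresholds) {
  # Pad with -Inf/Inf so the ROC curve includes (1, 1) and (0, 0);
  # decreasing thresholds give a non-decreasing false positive rate.
  th <- sort(unique(c(-Inf, thresholds, Inf)), decreasing = TRUE)
  tp <- numeric(length(th))
  fp <- numeric(length(th))
  n_pos <- 0
  n_neg <- 0
  list(
    update = function(preds, target) {
      for (k in seq_along(th)) {
        predicted_pos <- preds >= th[k]
        tp[k] <<- tp[k] + sum(predicted_pos & target == 1)
        fp[k] <<- fp[k] + sum(predicted_pos & target == 0)
      }
      n_pos <<- n_pos + sum(target == 1)
      n_neg <<- n_neg + sum(target == 0)
      invisible(NULL)
    },
    compute = function() {
      tpr <- tp / n_pos
      fpr <- fp / n_neg
      # Trapezoidal rule over the ROC curve
      sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
    }
  )
}

actual <- c(1, 1, 1, 0, 0, 0)
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
m <- make_streaming_auroc(thresholds = predicted)
m$update(predicted[1:2], actual[1:2])
m$update(predicted[3:4], actual[3:4])
m$update(predicted[5:6], actual[5:6])
m$compute()  # 8/9, the exact AUC for this toy data
```

Memory use depends only on the number of thresholds, not on the number of observations, which is the trade-off described above; with fewer thresholds the result becomes a coarser approximation.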

Mean absolute error

Description

Computes the mean absolute error.

Usage

luz_metric_mae()

Value

Returns new luz metric.

See Also

Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mse(), luz_metric_multiclass_auroc(), luz_metric_rmse()

Examples

if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_mae()
metric <- metric$new()
metric$update(torch_randn(100), torch_randn(100))
metric$compute()
}

Mean squared error

Description

Computes the mean squared error

Usage

luz_metric_mse()

Value

A luz_metric object.

See Also

Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_multiclass_auroc(), luz_metric_rmse()


Computes the multi-class AUROC

Description

The same definition as Keras is used by default. This is also equivalent to the 'micro' method in scikit-learn. See docs.

Usage

luz_metric_multiclass_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE,
  average = c("micro", "macro", "weighted", "none")
)

Arguments

num_thresholds

Number of thresholds used to compute confusion matrices. In that case, thresholds are created by taking num_thresholds linearly spaced values in the unit interval.

thresholds

(optional) If thresholds are passed, they are used to compute the confusion matrices and num_thresholds is ignored.

from_logits

If TRUE, torch::nnf_softmax() is applied to the predictions before computing the metric.

average

The averaging method:

  • 'micro': Stacks all classes and computes the AUROC as if it were a binary classification problem.

  • 'macro': Finds the AUROC for each class and computes their mean.

  • 'weighted': Finds the AUROC for each class and computes their mean, weighted by the number of instances of each class.

  • 'none': Returns the AUROC for each class in a list.
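The averaging options can be compared on small data with a base-R sketch. Unlike luz, which approximates the AUC from confusion matrices at fixed thresholds, the hypothetical helpers exact_auc() and multiclass_auroc() below compute the exact AUC with the Mann-Whitney rank statistic, so values on real data may differ slightly from luz_metric_multiclass_auroc(). The toy data mirrors the example further down, with classes coded 1 and 2.

```r
# Exact binary AUROC via the Mann-Whitney rank statistic (ties count 1/2)
exact_auc <- function(scores, labels) {
  r <- rank(scores)
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical multi-class AUROC with the averaging options described above
multiclass_auroc <- function(probs, y,
                             average = c("micro", "macro", "weighted", "none")) {
  average <- match.arg(average)
  k <- ncol(probs)
  onehot <- outer(y, seq_len(k), "==") * 1  # n x k indicator matrix
  if (average == "micro") {
    # Stack all classes into one binary problem
    return(exact_auc(as.vector(probs), as.vector(onehot)))
  }
  per_class <- vapply(seq_len(k),
                      function(j) exact_auc(probs[, j], onehot[, j]),
                      numeric(1))
  switch(average,
    macro = mean(per_class),
    weighted = weighted.mean(per_class, w = colSums(onehot)),
    none = as.list(per_class)
  )
}

predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
probs <- cbind(1 - predicted, predicted)
y <- c(2, 2, 2, 1, 1, 1)
multiclass_auroc(probs, y, "micro")  # 32.5 / 36
multiclass_auroc(probs, y, "macro")  # 8 / 9
```

Here both classes have an AUROC of 8/9, so 'macro' and 'weighted' agree, while 'micro' differs because stacking compares scores across classes as well.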

Details

Note that, unlike the AUC for binary classification, this metric can be affected by class imbalance.

Currently, the AUC is approximated using the 'interpolation' method described in Keras.

See Also

Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_rmse()

Examples

if (torch::torch_is_installed()) {
library(torch)
actual <- c(1, 1, 1, 0, 0, 0) + 1L
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
predicted <- cbind(1-predicted, predicted)
y_true <- torch_tensor(as.integer(actual))
y_pred <- torch_tensor(predicted)
m <- luz_metric_multiclass_auroc(thresholds = as.numeric(predicted),
                                 average = "micro")
m <- m$new()
m$update(y_pred[1:2,], y_true[1:2])
m$update(y_pred[3:4,], y_true[3:4])
m$update(y_pred[5:6,], y_true[5:6])
m$compute()
}

Root mean squared error

Description

Computes the root mean squared error.
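Because metrics are updated once per batch and computed at the end of the epoch, the way partial results are accumulated matters. The base-R sketch below, assuming the metric accumulates squared errors and observation counts, shows why taking the square root only in compute() differs from averaging per-batch RMSEs when batches have different sizes. new_rmse_metric() is a hypothetical closure mimicking the update/compute pattern of luz_metric(), not luz's own code.

```r
# Hypothetical streaming RMSE (not luz's source): accumulate the sum of
# squared errors and the number of observations; take the root in compute().
new_rmse_metric <- function() {
  sse <- 0
  n <- 0
  list(
    update = function(preds, target) {
      sse <<- sse + sum((preds - target)^2)
      n <<- n + length(target)
      invisible(NULL)
    },
    compute = function() sqrt(sse / n)
  )
}

# Two made-up batches of different sizes
m <- new_rmse_metric()
m$update(c(1, 2, 3), c(1, 2, 5))  # squared errors 0, 0, 4
m$update(0, 3)                    # squared error 9
m$compute()                       # sqrt(13 / 4), the RMSE over all 4 points

# Averaging per-batch RMSEs instead gives a different number
mean(c(sqrt(4 / 3), sqrt(9 / 1)))
```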

Usage

luz_metric_rmse()

Value

Returns new luz metric.

See Also

Other luz_metrics: luz_metric(), luz_metric_accuracy(), luz_metric_binary_accuracy(), luz_metric_binary_accuracy_with_logits(), luz_metric_binary_auroc(), luz_metric_mae(), luz_metric_mse(), luz_metric_multiclass_auroc()


Creates a metric set

Description

A metric set can be used to specify metrics that are only evaluated during training, validation or both.

Usage

luz_metric_set(metrics = NULL, train_metrics = NULL, valid_metrics = NULL)

Arguments

metrics

A list of luz_metrics that are meant to be used in both training and validation.

train_metrics

A list of luz_metrics that are only used during training.

valid_metrics

A list of luz_metrics that are only used for validation.


Saves luz objects to disk

Description

Allows saving fitted luz models to disk. Objects can be loaded back with luz_load().

Usage

luz_save(obj, path, ...)

Arguments

obj

An object of class 'luz_module_fitted' as returned by fit.luz_module_generator().

path

path in file system to the object.

...

currently unused.

Warning

The ctx is naively serialized, i.e., we only use saveRDS() to serialize it. Don't expect luz_save to work correctly if you have unserializable objects in the ctx, such as torch_tensors and external pointers in general.

Note

Objects are saved as plain .rds files, but obj$model is serialized with torch_save before saving it.

See Also

Other luz_save: luz_load()


Loss to be used with luz_callback_mixup().

Description

In the training phase, computes individual losses with regard to two targets, weights them item-wise, and averages the linear combinations to yield the mean batch loss. For validation and testing, defers to the passed-in loss.
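The item-wise weighting can be sketched in base R with a per-observation cross-entropy, which plays the role of a loss with reduction = 'none'. This is an illustration under stated assumptions, not luz's code: cross_entropy() and mixup_loss() are hypothetical, and the convention that weight applies to the second (shuffled) target mirrors torch_lerp(start, end, weight) = start + weight * (end - start).

```r
# Hypothetical per-observation cross-entropy (no reduction):
# probs is an n x k matrix of class probabilities, target holds integer classes.
cross_entropy <- function(probs, target) {
  -log(probs[cbind(seq_along(target), target)])
}

# Hypothetical mixup loss: lerp the two per-item losses, then average.
mixup_loss <- function(loss_fn, probs, ys, weight) {
  l1 <- loss_fn(probs, ys[[1]])  # loss w.r.t. the original targets
  l2 <- loss_fn(probs, ys[[2]])  # loss w.r.t. the shuffled targets
  mean(l1 + weight * (l2 - l1))  # item-wise linear combination, then batch mean
}

probs <- rbind(c(0.8, 0.2), c(0.3, 0.7))
ys <- list(c(1, 2), c(2, 1))
mixup_loss(cross_entropy, probs, ys, weight = c(0.1, 0.1))
```

With weight 0 the result reduces to the ordinary mean loss on the original targets, and with weight 1 to the mean loss on the shuffled targets.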

Usage

nn_mixup_loss(loss)

Arguments

loss

The underlying loss nn_module to call. It must support the reduction field. During training, the attribute will be changed to 'none' so we get the loss for individual observations. See, for example, the documentation for the reduction argument in torch::nn_cross_entropy_loss().

Details

It should be used together with luz_callback_mixup().

See Also

luz_callback_mixup()


Mixup logic

Description

Logic underlying luz_callback_mixup().

Usage

nnf_mixup(x, y, weight)

Arguments

x

an input batch

y

a target batch

weight

Weighting coefficient to be used by torch_lerp().

Details

Based on the passed-in input and target batches, as well as applicable mixing weights, we return new tensors intended to replace the current batch. The new input batch is a weighted linear combination of input batch items, while the new target batch bundles the original targets, as well as the mixing weights, in a nested list.
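The same logic can be sketched in base R with a plain matrix. The helpers lerp() and mixup_batch() are hypothetical; lerp() follows the torch_lerp() formula start + weight * (end - start), and the element names of the returned list (ys, weight) are illustrative, not a description of what nnf_mixup() returns.

```r
# Same formula as torch_lerp(start, end, weight)
lerp <- function(start, end, weight) start + weight * (end - start)

# Hypothetical base-R version of the mixup step for an n x p input matrix
mixup_batch <- function(x, y, weight, perm = sample(nrow(x))) {
  list(
    # weight has one value per row; it is recycled down each column
    x = lerp(x, x[perm, , drop = FALSE], weight),
    # original targets, shuffled targets and the weights, nested in a list
    y = list(ys = list(y, y[perm]), weight = weight)
  )
}

x <- matrix(1:6, nrow = 3)  # rows (1, 4), (2, 5), (3, 6)
y <- c(10, 20, 30)
out <- mixup_batch(x, y, weight = rep(0.25, 3), perm = c(3, 1, 2))
out$x[1, ]  # 0.75 * (1, 4) + 0.25 * (3, 6) = (1.5, 4.5)
```

Keeping both target vectors and the weights lets the loss be computed against each target separately and combined afterwards, instead of mixing the integer class labels themselves.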

Value

A list of:

See Also

luz_callback_mixup()

Examples

if (torch::torch_is_installed()) {
batch_x <- torch::torch_randn(c(10, 768))
batch_y <- torch::torch_randn(10)
weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
nnf_mixup(batch_x, batch_y, weight)
}

Create predictions for a fitted model

Description

Create predictions for a fitted model

Usage

## S3 method for class 'luz_module_fitted'
predict(
  object,
  newdata,
  ...,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)

Arguments

object

(fitted model) The fitted model object returned from fit.luz_module_generator().

newdata

(dataloader, dataset, list or array) Returning a list with at least 1 element. Only the first element is used; the others are ignored.

...

Currently unused.

callbacks

(list, optional) A list of callbacks defined with luz_callback() that will be called during the training procedure. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.

accelerator

(accelerator, optional) An optional accelerator() object used to configure device placement of the components like torch::nn_modules, optimizers and batches of data.

verbose

(logical, optional) An optional boolean value indicating if the fitting procedure should emit output to the console during training. By default, it will produce output if interactive() is TRUE; otherwise, it won't print to the console.

dataloader_options

Options used when creating a dataloader. See torch::dataloader(). shuffle = TRUE by default for the training data and batch_size = 32 by default. It will error if not NULL and data is already a dataloader.

See Also

Other training: evaluate(), fit.luz_module_generator(), setup()


Objects exported from other packages

Description

These objects are imported from other packages. Follow the links below to see their documentation.

generics

fit


Set hyper-parameter of a module

Description

This function is used to define hyper-parameters before calling fit for luz_modules.

Usage

set_hparams(module, ...)

Arguments

module

An nn_module that has been setup().

...

The parameters set here will be used to initialize the nn_module, i.e., they are passed unchanged to the initialize method of the base nn_module.

Value

The same luz module

See Also

Other set_hparam: set_opt_hparams()


Set optimizer hyper-parameters

Description

This function is used to define hyper-parameters for the optimizer initialization method.

Usage

set_opt_hparams(module, ...)

Arguments

module

An nn_module that has been setup().

...

The parameters passed here will be used to initialize the optimizers. For example, if your optimizer is optim_adam and you pass lr=0.1, then the optim_adam function is called with optim_adam(parameters, lr=0.1) when fitting the model.

Value

The same luz module

See Also

Other set_hparam: set_hparams()


Sets up an nn_module to use with luz

Description

The setup function is used to set important attributes and methods for nn_modules to be used with luz.

Usage

setup(module, loss = NULL, optimizer = NULL, metrics = NULL, backward = NULL)

Arguments

module

(nn_module) The nn_module that you want to set up.

loss

(function, optional) An optional function with the signature function(input, target). It's only required if your nn_module doesn't implement a method called loss.

optimizer

(torch_optimizer, optional) A function with the signature function(parameters, ...) that is used to initialize an optimizer given the model parameters.

metrics

(list, optional) A list of metrics to be tracked during the training procedure. Sometimes you want some metrics to be evaluated only during training or validation; in this case, you can pass a luz_metric_set() object to specify the metrics used in each stage.

backward

(function) A function that takes the loss scalar value as its parameter. It must call $backward() or torch::autograd_backward(). In general, you don't need to set this parameter unless you need to customize how luz calls backward(), for example, if you need to add additional arguments to the backward call. Note that this becomes a method of the nn_module, and thus can be used by your custom step() if you override it.

Details

It makes sure the module has all the necessary ingredients to be fitted.

Value

A luz module that can be trained with fit().

Note

It also adds a device active field that can be used to query the current module device within methods, e.g. with self$device. This is useful when ctx() is not available, e.g., when calling methods from outside the luz wrappers. Users can override the default by implementing a device active method in the input module.

See Also

Other training: evaluate(), fit.luz_module_generator(), predict.luz_module_fitted()

