torch.optim#
Created On: Jun 13, 2025 | Last Updated On: Aug 24, 2025
torch.optim is a package implementing various optimization algorithms.
Most commonly used methods are already supported, and the interface is general enough that more sophisticated ones can also be easily integrated in the future.
How to use an optimizer#
To use torch.optim you have to construct an optimizer object that will hold the current state and will update the parameters based on the computed gradients.
Constructing it#
To construct an Optimizer you have to give it an iterable containing the parameters (all should be Parameters) or named parameters (tuples of (str, Parameter)) to optimize. Then, you can specify optimizer-specific options such as the learning rate, weight decay, etc.
Example:
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = optim.Adam([var1, var2], lr=0.0001)
Named parameters example:
    optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
    optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)
Per-parameter options#
Optimizers also support specifying per-parameter options. To do this, instead of passing an iterable of Parameters, pass in an iterable of dicts. Each of them will define a separate parameter group, and should contain a params key, containing a list of parameters belonging to it. Other keys should match the keyword arguments accepted by the optimizers, and will be used as optimization options for this group.
For example, this is very useful when one wants to specify per-layer learning rates:
    optim.SGD([
        {'params': model.base.parameters(), 'lr': 1e-2},
        {'params': model.classifier.parameters()}
    ], lr=1e-3, momentum=0.9)

    optim.SGD([
        {'params': model.base.named_parameters(), 'lr': 1e-2},
        {'params': model.classifier.named_parameters()}
    ], lr=1e-3, momentum=0.9)

This means that model.base’s parameters will use a learning rate of 1e-2, whereas model.classifier’s parameters will stick to the default learning rate of 1e-3. Finally, a momentum of 0.9 will be used for all parameters.
Note
You can still pass options as keyword arguments. They will be used as defaults, in the groups that didn’t override them. This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.
Also consider the following example related to the distinct penalization of parameters. Remember that parameters() returns an iterable that contains all learnable parameters, including biases and other parameters that may prefer distinct penalization. To address this, one can specify individual penalization weights for each parameter group:

    bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
    others = [p for name, p in self.named_parameters() if 'bias' not in name]

    optim.SGD([
        {'params': others},
        {'params': bias_params, 'weight_decay': 0}
    ], weight_decay=1e-2, lr=1e-2)

In this manner, bias terms are isolated from non-bias terms, and a weight_decay of 0 is set specifically for the bias terms, so as to avoid any penalization for this group.
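As a minimal sketch of how group-level overrides and constructor defaults combine, the snippet below repeats the bias/non-bias split on a small stand-in model and inspects `optimizer.param_groups`. The `nn.Sequential` layout is an assumption made purely for illustration.

```python
import torch
from torch import nn, optim

# Hypothetical two-layer model, used only to have some weights and biases.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

bias_params = [p for name, p in model.named_parameters() if "bias" in name]
others = [p for name, p in model.named_parameters() if "bias" not in name]

optimizer = optim.SGD(
    [
        {"params": others},                          # inherits weight_decay=1e-2
        {"params": bias_params, "weight_decay": 0},  # overrides the default
    ],
    weight_decay=1e-2,
    lr=1e-2,
)

# Each group is a dict holding its parameters plus the resolved options.
for index, group in enumerate(optimizer.param_groups):
    print(index, len(group["params"]), group["lr"], group["weight_decay"])
```

Options that a group does not set are filled in from the keyword arguments, so both groups here end up with the same learning rate but different weight decay.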
Taking an optimization step#
All optimizers implement a step() method that updates the parameters. It can be used in two ways:
optimizer.step()#
This is a simplified version supported by most optimizers. The function can be called once the gradients are computed using e.g. backward().
Example:
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
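For readers who want something runnable end to end, here is a small sketch of the same loop on synthetic data. The linear target `y = 2x + 1`, the batch size, and the hyperparameters are all assumptions chosen for illustration.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Synthetic regression data (assumed for this sketch): y = 2x + 1.
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1
dataset = [(x[i:i + 16], y[i:i + 16]) for i in range(0, 64, 16)]

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

with torch.no_grad():
    initial_loss = loss_fn(model(x), y).item()

for epoch in range(50):
    for input, target in dataset:
        optimizer.zero_grad()           # clear gradients from the previous step
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()                 # populate .grad on every parameter
        optimizer.step()                # apply the update

with torch.no_grad():
    final_loss = loss_fn(model(x), y).item()

print(initial_loss, final_loss)
```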
optimizer.step(closure)#
Some optimization algorithms such as Conjugate Gradient and LBFGS need to reevaluate the function multiple times, so you have to pass in a closure that allows them to recompute your model. The closure should clear the gradients, compute the loss, and return it.
Example:
    for input, target in dataset:
        def closure():
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            return loss
        optimizer.step(closure)
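The closure form can be tried on a toy problem. The sketch below minimizes a one-dimensional quadratic with LBFGS; the objective `(x - 3)^2` and the number of `step` calls are arbitrary choices for illustration.

```python
import torch
from torch import optim

# Toy objective (assumed for this sketch): f(x) = (x - 3)^2, minimized at x = 3.
x = torch.zeros(1, requires_grad=True)
optimizer = optim.LBFGS([x], lr=1.0)

def closure():
    # LBFGS may call this several times per step to re-evaluate the loss.
    optimizer.zero_grad()
    loss = ((x - 3.0) ** 2).sum()
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)

print(x.item())
```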
Base class#
- class torch.optim.Optimizer(params, defaults)[source]#
Base class for all optimizers.
Warning
Parameters need to be specified as collections that have a deterministic ordering that is consistent between runs. Examples of objects that don’t satisfy those properties are sets and iterators over values of dictionaries.
- Parameters
params (iterable) – an iterable of torch.Tensors or dicts. Specifies what Tensors should be optimized.
defaults (dict[str, Any]) – a dict containing default values of optimization options (used when a parameter group doesn’t specify them).
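Optimizer.add_param_group | Add a param group to the Optimizer's param_groups. |
Optimizer.load_state_dict | Load the optimizer state. |
Optimizer.register_load_state_dict_pre_hook | Register a load_state_dict pre-hook which will be called before load_state_dict() is called. |
Optimizer.register_load_state_dict_post_hook | Register a load_state_dict post-hook which will be called after load_state_dict() is called. |
Optimizer.state_dict | Return the state of the optimizer as a dict. |
Optimizer.register_state_dict_pre_hook | Register a state dict pre-hook which will be called before state_dict() is called. |
Optimizer.register_state_dict_post_hook | Register a state dict post-hook which will be called after state_dict() is called. |
Optimizer.step | Perform a single optimization step to update parameter. |
Optimizer.register_step_pre_hook | Register an optimizer step pre hook which will be called before optimizer step. |
Optimizer.register_step_post_hook | Register an optimizer step post hook which will be called after optimizer step. |
Optimizer.zero_grad | Reset the gradients of all optimized torch.Tensors. |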
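A short sketch of how a few of these methods fit together: saving state with `state_dict()`, restoring it into a fresh optimizer with `load_state_dict()`, and extending the optimizer with `add_param_group()`. The model and the extra parameter are placeholders invented for this example.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(3, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Take one step so that per-parameter state (the momentum buffers) exists.
model(torch.randn(4, 3)).sum().backward()
optimizer.step()

saved = optimizer.state_dict()
print(sorted(saved.keys()))

# A fresh optimizer over the same parameters picks up the saved options and state,
# including the saved learning rate, which replaces the 0.5 given here.
restored = optim.SGD(model.parameters(), lr=0.5)
restored.load_state_dict(saved)

# add_param_group extends an existing optimizer with further parameters.
extra = nn.Parameter(torch.zeros(2))  # hypothetical extra parameter
restored.add_param_group({"params": [extra], "lr": 0.001})
```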
Algorithms#
Adadelta | Implements Adadelta algorithm. |
Adafactor | Implements Adafactor algorithm. |
Adagrad | Implements Adagrad algorithm. |
Adam | Implements Adam algorithm. |
AdamW | Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance. |
SparseAdam | SparseAdam implements a masked version of the Adam algorithm suitable for sparse gradients. |
Adamax | Implements Adamax algorithm (a variant of Adam based on infinity norm). |
ASGD | Implements Averaged Stochastic Gradient Descent. |
LBFGS | Implements L-BFGS algorithm. |
Muon | Implements Muon algorithm. |
NAdam | Implements NAdam algorithm. |
RAdam | Implements RAdam algorithm. |
RMSprop | Implements RMSprop algorithm. |
Rprop | Implements the resilient backpropagation algorithm. |
SGD | Implements stochastic gradient descent (optionally with momentum). |
Many of our algorithms have various implementations optimized for performance, readability and/or generality, so we attempt to default to the generally fastest implementation for the current device if no particular implementation has been specified by the user.
We have 3 major categories of implementations: for-loop, foreach (multi-tensor), and fused. The most straightforward implementations are for-loops over the parameters with big chunks of computation. For-looping is usually slower than our foreach implementations, which combine parameters into a multi-tensor and run the big chunks of computation all at once, thereby saving many sequential kernel calls. A few of our optimizers have even faster fused implementations, which fuse the big chunks of computation into one kernel. We can think of foreach implementations as fusing horizontally and fused implementations as fusing vertically on top of that.
In general, the performance ordering of the 3 implementations is fused > foreach > for-loop. So when applicable, we default to foreach over for-loop. Applicable means the foreach implementation is available, the user has not specified any implementation-specific kwargs (e.g., fused, foreach, differentiable), and all tensors are native. Note that while fused should be even faster than foreach, the implementations are newer and we would like to give them more bake-in time before flipping the switch everywhere. We summarize the stability status for each implementation in the second table below; you are welcome to try them out, though!
Below is a table showing the available and default implementations of each algorithm:
Algorithm | Default | Has foreach? | Has fused? |
|---|---|---|---|
Adadelta | foreach | yes | no |
Adafactor | for-loop | no | no |
Adagrad | foreach | yes | yes (cpu only) |
Adam | foreach | yes | yes |
AdamW | foreach | yes | yes |
SparseAdam | for-loop | no | no |
Adamax | foreach | yes | no |
ASGD | foreach | yes | no |
LBFGS | for-loop | no | no |
Muon | for-loop | no | no |
NAdam | foreach | yes | no |
RAdam | foreach | yes | no |
RMSprop | foreach | yes | no |
Rprop | foreach | yes | no |
SGD | foreach | yes | yes |
The table below shows the stability status of the fused implementations:
Algorithm | CPU | CUDA | MPS |
|---|---|---|---|
Adadelta | unsupported | unsupported | unsupported |
Adafactor | unsupported | unsupported | unsupported |
Adagrad | beta | unsupported | unsupported |
Adam | beta | stable | beta |
AdamW | beta | stable | beta |
SparseAdam | unsupported | unsupported | unsupported |
Adamax | unsupported | unsupported | unsupported |
ASGD | unsupported | unsupported | unsupported |
LBFGS | unsupported | unsupported | unsupported |
Muon | unsupported | unsupported | unsupported |
NAdam | unsupported | unsupported | unsupported |
RAdam | unsupported | unsupported | unsupported |
RMSprop | unsupported | unsupported | unsupported |
Rprop | unsupported | unsupported | unsupported |
SGD | beta | beta | beta |
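The implementation can also be chosen explicitly through the `foreach` keyword. The sketch below, which assumes a tiny CPU model and an arbitrary loss, runs Adam with the for-loop and the foreach implementations from the same starting point and compares the resulting weights; the two should agree up to floating-point noise.

```python
import torch
from torch import nn, optim

def run(foreach):
    # Same seed, so both runs start from identical weights and data.
    torch.manual_seed(0)
    model = nn.Linear(4, 4)
    optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
    data = torch.randn(8, 4)
    for _ in range(5):
        optimizer.zero_grad()
        model(data).pow(2).mean().backward()
        optimizer.step()
    return model.weight.detach().clone()

w_loop = run(foreach=False)    # for-loop implementation
w_foreach = run(foreach=True)  # multi-tensor implementation

print(torch.allclose(w_loop, w_foreach, atol=1e-5))
```

Passing `fused=True` works the same way for the optimizers and devices listed as supported in the tables above.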
How to adjust learning rate#
torch.optim.lr_scheduler.LRScheduler provides several methods to adjust the learning rate based on the number of epochs. torch.optim.lr_scheduler.ReduceLROnPlateau allows dynamic learning rate reducing based on some validation measurements.
Learning rate scheduling should be applied after optimizer’s update; e.g., you should write your code this way:
Example:
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = ExponentialLR(optimizer, gamma=0.9)

    for epoch in range(20):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()
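To see the schedule itself, one can record `scheduler.get_last_lr()` per epoch. This is a minimal sketch with a placeholder model and a single dummy batch per epoch; only the ordering of `optimizer.step()` and `scheduler.step()` mirrors the example above.

```python
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

model = nn.Linear(2, 1)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

lrs = []
for epoch in range(3):
    lrs.append(scheduler.get_last_lr()[0])  # rate in effect during this epoch
    optimizer.zero_grad()
    model(torch.randn(4, 2)).sum().backward()
    optimizer.step()    # parameter update first
    scheduler.step()    # then advance the schedule

print(lrs)
```

Each epoch multiplies the previous rate by `gamma`, so the recorded values follow `0.01 * 0.9 ** epoch`.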
Most learning rate schedulers can be called back-to-back (also referred to as chaining schedulers). The result is that each scheduler is applied one after the other on the learning rate obtained by the one preceding it.
Example:
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler1 = ExponentialLR(optimizer, gamma=0.9)
    scheduler2 = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    for epoch in range(20):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        scheduler1.step()
        scheduler2.step()
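The effect of chaining is easiest to check numerically. In the sketch below the milestones are moved to epochs 2 and 4 (an assumption, so that they are reached within a short run) and the observed learning rates are compared with the product of the two decay factors computed by hand.

```python
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

model = nn.Linear(2, 1)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[2, 4], gamma=0.1)

lr_history = []
for epoch in range(5):
    optimizer.step()   # training for the epoch would happen here
    scheduler1.step()  # multiplies the current rate by 0.9
    scheduler2.step()  # multiplies by 0.1 when a milestone is reached
    lr_history.append(optimizer.param_groups[0]["lr"])

# Hand-computed reference: exponential decay times one factor of 0.1 per milestone passed.
expected = [
    0.01 * 0.9 ** n * 0.1 ** sum(m <= n for m in (2, 4))
    for n in range(1, 6)
]
print(lr_history)
print(expected)
```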
In many places in the documentation, we will use the following template to refer to scheduler algorithms.
>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()
Warning
Prior to PyTorch 1.1.0, the learning rate scheduler was expected to be called before the optimizer’s update; 1.1.0 changed this behavior in a BC-breaking way. If you use the learning rate scheduler (calling scheduler.step()) before the optimizer’s update (calling optimizer.step()), this will skip the first value of the learning rate schedule. If you are unable to reproduce results after upgrading to PyTorch 1.1.0, please check if you are calling scheduler.step() at the wrong time.
lr_scheduler.LRScheduler | Adjusts the learning rate during optimization. |
lr_scheduler.LambdaLR | Sets the initial learning rate. |
lr_scheduler.MultiplicativeLR | Multiply the learning rate of each parameter group by the factor given in the specified function. |
lr_scheduler.StepLR | Decays the learning rate of each parameter group by gamma every step_size epochs. |
lr_scheduler.MultiStepLR | Decays the learning rate of each parameter group by gamma once the number of epochs reaches one of the milestones. |
lr_scheduler.ConstantLR | Multiply the learning rate of each parameter group by a small constant factor. |
lr_scheduler.LinearLR | Decays the learning rate of each parameter group by linearly changing small multiplicative factor. |
lr_scheduler.ExponentialLR | Decays the learning rate of each parameter group by gamma every epoch. |
lr_scheduler.PolynomialLR | Decays the learning rate of each parameter group using a polynomial function in the given total_iters. |
lr_scheduler.CosineAnnealingLR | Set the learning rate of each parameter group using a cosine annealing schedule. |
lr_scheduler.ChainedScheduler | Chains a list of learning rate schedulers. |
lr_scheduler.SequentialLR | Contains a list of schedulers expected to be called sequentially during the optimization process. |
lr_scheduler.ReduceLROnPlateau | Reduce learning rate when a metric has stopped improving. |
lr_scheduler.CyclicLR | Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). |
lr_scheduler.OneCycleLR | Sets the learning rate of each parameter group according to the 1cycle learning rate policy. |
lr_scheduler.CosineAnnealingWarmRestarts | Set the learning rate of each parameter group using a cosine annealing schedule. |
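Unlike the epoch-based schedulers, ReduceLROnPlateau takes the monitored metric as an argument to `step()`. The sketch below feeds it a made-up sequence of validation losses (one improvement followed by a plateau) and records the learning rate after each call; the numbers are invented for illustration.

```python
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(2, 1)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)

# Made-up validation losses: the metric improves once and then stalls.
val_losses = [1.0, 0.5, 0.5, 0.5, 0.5]
lrs = []
for val_loss in val_losses:
    optimizer.step()          # training for the epoch would happen here
    scheduler.step(val_loss)  # pass the monitored metric explicitly
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

With `patience=2`, the rate is only reduced once the metric has failed to improve for more than two consecutive epochs, which here happens on the last call.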
How to utilize named parameters to load optimizer state dict#
The function load_state_dict() stores the optional param_names content from the loaded state dict if present. However, the process of loading the optimizer state is not affected, as the order of the parameters matters to maintain compatibility (in case of different ordering). To utilize the loaded parameter names from the loaded state dict, a custom register_load_state_dict_pre_hook needs to be implemented according to the desired behavior.
This can be useful, for instance, when the model architecture changes, but the weights and optimizer states need to remain unchanged. The following example demonstrates how to implement this customization.
Example:
    import torch
    from torch import nn, optim

    class OneLayerModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3, 4)

        def forward(self, x):
            return self.fc(x)

    model = OneLayerModel()
    optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
    # training..
    torch.save(optimizer.state_dict(), PATH)

Let’s say that model implements an expert (MoE), and we want to duplicate it and resume training for two experts, both initialized the same way as the fc layer. For the following model2 we create two layers identical to fc and resume training by loading the model weights and optimizer states from model into both fc1 and fc2 of model2 (and adjust them accordingly):

    class TwoLayerModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(3, 4)
            self.fc2 = nn.Linear(3, 4)

        def forward(self, x):
            return (self.fc1(x) + self.fc2(x)) / 2

    model2 = TwoLayerModel()
    # adapt and load model weights..
    optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

To load the state dict for optimizer2 with the state dict of the previous optimizer such that both fc1 and fc2 will be initialized with a copy of fc optimizer states (to resume training for each layer from fc), we can use the following hook:
    from copy import deepcopy

    def adapt_state_dict_ids(optimizer, state_dict):
        adapted_state_dict = deepcopy(optimizer.state_dict())
        # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
        for k, v in state_dict['param_groups'][0].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][0][k] = v

        lookup_dict = {
            'fc1.weight': 'fc.weight',
            'fc1.bias': 'fc.bias',
            'fc2.weight': 'fc.weight',
            'fc2.bias': 'fc.bias'
        }
        clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][0]['params'],
                optimizer.state_dict()['param_groups'][0]['param_names']):
            name_in_loaded = lookup_dict[param_name]
            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

        return adapted_state_dict

    optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
    optimizer2.load_state_dict(torch.load(PATH))  # The previous optimizer saved state_dict

This ensures that the adapted state_dict with the correct states for the layers of model2 will be used when the optimizer state is loaded. Note that this code is designed specifically for this example (e.g., assuming a single parameter group), and other cases might require different adaptations.
The following example shows how to handle missing parameters in a loaded state dict when the model structure changes. Model_bypass adds a new bypass layer, which is not present in the original Model1. To resume training, a custom adapt_state_dict_missing_param hook is used to adapt the optimizer’s state_dict, ensuring existing parameters are mapped correctly, while missing ones (like the bypass layer) remain unchanged (as initialized in this example). This approach enables smooth loading and resuming of the optimizer state despite model changes. The new bypass layer will be trained from scratch:
    class Model1(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 5)

        def forward(self, x):
            return self.fc(x) + x

    model = Model1()
    optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
    # training..
    torch.save(optimizer.state_dict(), PATH)

    class Model_bypass(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 5)
            self.bypass = nn.Linear(5, 5, bias=False)
            torch.nn.init.eye_(self.bypass.weight)

        def forward(self, x):
            return self.fc(x) + self.bypass(x)

    model2 = Model_bypass()
    optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

    def adapt_state_dict_missing_param(optimizer, state_dict):
        adapted_state_dict = deepcopy(optimizer.state_dict())
        # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
        for k, v in state_dict['param_groups'][0].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][0][k] = v

        lookup_dict = {
            'fc.weight': 'fc.weight',
            'fc.bias': 'fc.bias',
            'bypass.weight': None,
        }

        clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][0]['params'],
                optimizer.state_dict()['param_groups'][0]['param_names']):
            name_in_loaded = lookup_dict[param_name]
            if name_in_loaded in state_dict['param_groups'][0]['param_names']:
                index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
                id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
                # Copy the state of the corresponding parameter
                if id_in_loaded in state_dict['state']:
                    adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

        return adapted_state_dict

    # Register the hook defined above (not the one from the previous example).
    optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_missing_param)
    optimizer2.load_state_dict(torch.load(PATH))  # The previous optimizer saved state_dict
As a third example, instead of loading a state according to the order of parameters (the default approach),this hook can be used to load according to the parameters’ names:
    def names_matching(optimizer, state_dict):
        assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
        adapted_state_dict = deepcopy(optimizer.state_dict())
        for g_ind in range(len(state_dict['param_groups'])):
            assert len(state_dict['param_groups'][g_ind]['params']) == len(
                optimizer.state_dict()['param_groups'][g_ind]['params'])

            for k, v in state_dict['param_groups'][g_ind].items():
                if k not in ['params', 'param_names']:
                    adapted_state_dict['param_groups'][g_ind][k] = v

            for param_id, param_name in zip(
                    optimizer.state_dict()['param_groups'][g_ind]['params'],
                    optimizer.state_dict()['param_groups'][g_ind]['param_names']):
                index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
                id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
                # Copy the state of the corresponding parameter
                if id_in_loaded in state_dict['state']:
                    adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])

        return adapted_state_dict
Weight Averaging (SWA and EMA)#
torch.optim.swa_utils.AveragedModel implements Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA), torch.optim.swa_utils.SWALR implements the SWA learning rate scheduler, and torch.optim.swa_utils.update_bn() is a utility function used to update SWA/EMA batch normalization statistics at the end of training.
SWA has been proposed in Averaging Weights Leads to Wider Optima and Better Generalization.
EMA is a widely known technique to reduce the training time by reducing the number of weight updates needed. It is a variation of Polyak averaging, but using exponential weights instead of equal weights across iterations.
Constructing averaged models#
The AveragedModel class serves to compute the weights of the SWA or EMA model.
You can create an SWA averaged model by running:
>>> averaged_model = AveragedModel(model)
EMA models are constructed by specifying the multi_avg_fn argument as follows:
>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to torch.optim.swa_utils.get_ema_multi_avg_fn(), the default is 0.999. Decay value should be close to 1.0, as smaller values can cause optimization convergence issues.
torch.optim.swa_utils.get_ema_multi_avg_fn() returns a function that applies the following EMA equation to the weights:
$$W^{\text{EMA}}_{t+1} = \alpha W^{\text{EMA}}_{t} + (1 - \alpha) W^{\text{model}}_{t}$$
where $\alpha$ is the EMA decay.
Here the model model can be an arbitrary torch.nn.Module object. averaged_model will keep track of the running averages of the parameters of the model. To update these averages, you should use the update_parameters() function after the optimizer.step():
>>> averaged_model.update_parameters(model)
For SWA and EMA, this call is usually done right after the optimizer step(). In the case of SWA, this is usually skipped for some number of steps at the beginning of the training.
Custom averaging strategies#
By default, torch.optim.swa_utils.AveragedModel computes a running equal average of the parameters that you provide, but you can also use custom averaging functions with the avg_fn or multi_avg_fn parameters:
- avg_fn allows defining a function operating on each parameter tuple (averaged parameter, model parameter) and should return the new averaged parameter.
- multi_avg_fn allows defining more efficient operations acting on a tuple of parameter lists, (averaged parameter list, model parameter list), at the same time, for example using the torch._foreach* functions. This function must update the averaged parameters in-place.
In the following example ema_model computes an exponential moving average using the avg_fn parameter:
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
In the following example ema_model computes an exponential moving average using the more efficient multi_avg_fn parameter:
>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))
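To make the "running equal average" concrete, the sketch below writes it out as a custom `avg_fn` and compares it with the default behaviour. The helper name `running_mean` and the weight values 1, 2, 3 are assumptions for illustration; the direct writes to `model.weight` stand in for optimizer steps.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

def running_mean(averaged_param, model_param, num_averaged):
    # Incremental form of the arithmetic mean over all snapshots seen so far.
    return averaged_param + (model_param - averaged_param) / (num_averaged + 1)

model = nn.Linear(1, 1, bias=False)
default_avg = AveragedModel(model)                      # built-in equal averaging
custom_avg = AveragedModel(model, avg_fn=running_mean)  # same rule, spelled out

for value in (1.0, 2.0, 3.0):
    with torch.no_grad():
        model.weight.fill_(value)  # stand-in for an optimizer step
    default_avg.update_parameters(model)
    custom_avg.update_parameters(model)

default_weight = default_avg.module.weight.item()
custom_weight = custom_avg.module.weight.item()
print(default_weight, custom_weight)
```

Both averaged models should end up at the plain mean of the three snapshots.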
SWA learning rate schedules#
Typically, in SWA the learning rate is set to a high constant value. SWALR is a learning rate scheduler that anneals the learning rate to a fixed value, and then keeps it constant. For example, the following code creates a scheduler that linearly anneals the learning rate from its initial value to 0.05 in 5 epochs within each parameter group:
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
You can also use cosine annealing to a fixed value instead of linear annealing by setting anneal_strategy="cos".
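The annealing behaviour can be observed by recording the learning rate after each scheduler step. In this sketch the starting rate of 0.5 and the placeholder model are assumptions; the SWALR arguments match the linear example above.

```python
import torch
from torch import nn, optim
from torch.optim.swa_utils import SWALR

model = nn.Linear(2, 1)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.5)
swa_scheduler = SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

lrs = [optimizer.param_groups[0]["lr"]]  # rate before any annealing
for epoch in range(7):
    optimizer.step()      # training for the epoch would happen here
    swa_scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

The recorded rates fall from the initial value to `swa_lr` over the first five steps and stay there afterwards.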
Taking care of batch normalization#
update_bn() is a utility function that computes the batchnorm statistics for the SWA model on a given dataloader loader at the end of training:
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
update_bn() applies the swa_model to every element in the dataloader and computes the activation statistics for each batch normalization layer in the model.
Warning
update_bn() assumes that each batch in the dataloader loader is either a tensor or a list of tensors where the first element is the tensor that the network swa_model should be applied to. If your dataloader has a different structure, you can update the batch normalization statistics of the swa_model by doing a forward pass with the swa_model on each element of the dataset.
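A small sketch of what `update_bn()` does to the running statistics. The one-layer model and the synthetic data are assumptions, and a plain list of tensors stands in for a DataLoader, since the function only iterates over the batches.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn

torch.manual_seed(0)
data = torch.randn(40, 3) * 2 + 5
# Four equally sized batches; each element is a tensor, as update_bn expects.
loader = [data[i:i + 10] for i in range(0, 40, 10)]

model = nn.Sequential(nn.BatchNorm1d(3))  # minimal model with one BatchNorm layer
swa_model = AveragedModel(model)

update_bn(loader, swa_model)

bn = swa_model.module[0]
print(bn.running_mean)
print(data.mean(dim=0))
```

With equally sized batches, the recomputed `running_mean` should match the per-feature mean of the whole dataset.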
Putting it all together: SWA#
In the example below, swa_model is the SWA model that accumulates the averages of the weights. We train the model for a total of 300 epochs and we switch to the SWA learning rate schedule and start to collect SWA averages of the parameters at epoch 160:
>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)
Putting it all together: EMA#
In the example below, ema_model is the EMA model that accumulates the exponentially-decayed averages of the weights with a decay rate of 0.999. We train the model for a total of 300 epochs and start to collect EMA averages immediately.
>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>           ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)
swa_utils.AveragedModel | Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). |
swa_utils.SWALR | Anneals the learning rate in each parameter group to a fixed value. |
- torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[source]#
Get the function applying exponential moving average (EMA) across multiple params.
- torch.optim.swa_utils.update_bn(loader, model, device=None)[source]#
Update BatchNorm running_mean, running_var buffers in the model.
It performs one pass over data in loader to estimate the activation statistics for BatchNorm layers in the model.
- Parameters
loader (torch.utils.data.DataLoader) – dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data.
model (torch.nn.Module) – model for which we seek to update BatchNorm statistics.
device (torch.device, optional) – If set, data will be transferred to device before being passed into model.
Example
>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)
Note
The update_bn utility assumes that each data batch in loader is either a tensor or a list or tuple of tensors; in the latter case it is assumed that model.forward() should be called on the first element of the list or tuple corresponding to the data batch.