Class Module#
Defined in File module.h
Inheritance Relationships#
Base Type#
public std::enable_shared_from_this<Module>
Derived Types#
publictorch::nn::Cloneable<SoftshrinkImpl>(Template Class Cloneable)publictorch::nn::Cloneable<PReLUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<LogSoftmaxImpl>(Template Class Cloneable)publictorch::nn::Cloneable<L1LossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SequentialImpl>(Template Class Cloneable)publictorch::nn::Cloneable<HardshrinkImpl>(Template Class Cloneable)publictorch::nn::Cloneable<GLUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<RReLUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ParameterDictImpl>(Template Class Cloneable)publictorch::nn::Cloneable<IdentityImpl>(Template Class Cloneable)publictorch::nn::Cloneable<FoldImpl>(Template Class Cloneable)publictorch::nn::Cloneable<EmbeddingBagImpl>(Template Class Cloneable)publictorch::nn::Cloneable<BilinearImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TripletMarginWithDistanceLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SoftminImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SmoothL1LossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<MultiLabelMarginLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<LeakyReLUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<FunctionalImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ELUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TanhshrinkImpl>(Template Class Cloneable)publictorch::nn::Cloneable<PairwiseDistanceImpl>(Template Class Cloneable)publictorch::nn::Cloneable<LogSigmoidImpl>(Template Class Cloneable)publictorch::nn::Cloneable<HardtanhImpl>(Template Class Cloneable)publictorch::nn::Cloneable<FractionalMaxPool2dImpl>(Template Class Cloneable)publictorch::nn::Cloneable<FlattenImpl>(Template Class Cloneable)publictorch::nn::Cloneable<CrossMapLRN2dImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TransformerEncoderLayerImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ThresholdImpl>(Template Class 
Cloneable)publictorch::nn::Cloneable<SoftsignImpl>(Template Class Cloneable)publictorch::nn::Cloneable<MultiMarginLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<FractionalMaxPool3dImpl>(Template Class Cloneable)publictorch::nn::Cloneable<CTCLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<UnfoldImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SiLUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ParameterListImpl>(Template Class Cloneable)publictorch::nn::Cloneable<MultiheadAttentionImpl>(Template Class Cloneable)publictorch::nn::Cloneable<CELUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<UpsampleImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TransformerImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SELUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<PixelUnshuffleImpl>(Template Class Cloneable)publictorch::nn::Cloneable<LinearImpl>(Template Class Cloneable)publictorch::nn::Cloneable<HingeEmbeddingLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<EmbeddingImpl>(Template Class Cloneable)publictorch::nn::Cloneable<MultiLabelSoftMarginLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<CrossEntropyLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TripletMarginLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TransformerDecoderLayerImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SoftMarginLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<LocalResponseNormImpl>(Template Class Cloneable)publictorch::nn::Cloneable<BCELossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<LayerNormImpl>(Template Class Cloneable)publictorch::nn::Cloneable<AdaptiveLogSoftmaxWithLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ReLUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ModuleListImpl>(Template Class Cloneable)publictorch::nn::Cloneable<HuberLossImpl>(Template Class 
Cloneable)publictorch::nn::Cloneable<GELUImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SoftmaxImpl>(Template Class Cloneable)publictorch::nn::Cloneable<Softmax2dImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SoftplusImpl>(Template Class Cloneable)publictorch::nn::Cloneable<SigmoidImpl>(Template Class Cloneable)publictorch::nn::Cloneable<PoissonNLLLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ModuleDictImpl>(Template Class Cloneable)publictorch::nn::Cloneable<MishImpl>(Template Class Cloneable)publictorch::nn::Cloneable<UnflattenImpl>(Template Class Cloneable)publictorch::nn::Cloneable<ReLU6Impl>(Template Class Cloneable)publictorch::nn::Cloneable<MSELossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<CosineSimilarityImpl>(Template Class Cloneable)publictorch::nn::Cloneable<CosineEmbeddingLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TransformerDecoderImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TanhImpl>(Template Class Cloneable)publictorch::nn::Cloneable<NLLLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<MarginRankingLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<BCEWithLogitsLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<TransformerEncoderImpl>(Template Class Cloneable)publictorch::nn::Cloneable<PixelShuffleImpl>(Template Class Cloneable)publictorch::nn::Cloneable<KLDivLossImpl>(Template Class Cloneable)publictorch::nn::Cloneable<GroupNormImpl>(Template Class Cloneable)publictorch::nn::Cloneable<Derived>(Template Class Cloneable)
Class Documentation#
- class Module : public std::enable_shared_from_this<Module>#
The base class for all modules in PyTorch.
A
Moduleis an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. AModulemay contain furtherModules (“submodules”), each with their own implementation, persistent data and further submodules.Modules can thus be said to form a recursive tree structure. AModuleis registered as a submodule to anotherModuleby callingregister_module(), typically from within a parent module’s constructor.A distinction is made between three kinds of persistent data that may be associated with a
Module:Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the
weightof aLinearmodule),Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g.
meanandvariancein theBatchNormmodule),Any additional state, not necessarily tensors, required for the implementation or configuration of a
Module.
The first two kinds of state are special in that they may be registered with the
Modulesystem to allow convenient access and batch configuration. For example, registered parameters in anyModulemay be iterated over via theparameters()accessor. Further, changing the data type of aModule’s registered parameters can be done conveniently viaModule::to(), e.g.module->to(torch::kCUDA)to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during aclone()operation, which performs a deepcopy of a cloneableModulehierarchy.Parameters are registered with a
Moduleviaregister_parameter. Buffers are registered separately viaregister_buffer. These methods are part of the public API ofModuleand are typically invoked from within a concreteModules constructor.Note
The design and implementation of this class is largely based on the Python API. You may want to consult the Python documentation for
torch.nn.Modulefor further clarification on certainmethods or behavior.Subclassed bytorch::nn::Cloneable< SoftshrinkImpl >,torch::nn::Cloneable< PReLUImpl >,torch::nn::Cloneable< LogSoftmaxImpl >,torch::nn::Cloneable< L1LossImpl >,torch::nn::Cloneable< SequentialImpl >,torch::nn::Cloneable< HardshrinkImpl >,torch::nn::Cloneable< GLUImpl >,torch::nn::Cloneable< RReLUImpl >,torch::nn::Cloneable< ParameterDictImpl >,torch::nn::Cloneable< IdentityImpl >,torch::nn::Cloneable< FoldImpl >,torch::nn::Cloneable< EmbeddingBagImpl >,torch::nn::Cloneable< BilinearImpl >,torch::nn::Cloneable< TripletMarginWithDistanceLossImpl >,torch::nn::Cloneable< SoftminImpl >,torch::nn::Cloneable< SmoothL1LossImpl >,torch::nn::Cloneable< MultiLabelMarginLossImpl >,torch::nn::Cloneable< LeakyReLUImpl >,torch::nn::Cloneable< FunctionalImpl >,torch::nn::Cloneable< ELUImpl >,torch::nn::Cloneable< TanhshrinkImpl >,torch::nn::Cloneable< PairwiseDistanceImpl >,torch::nn::Cloneable< LogSigmoidImpl >,torch::nn::Cloneable< HardtanhImpl >,torch::nn::Cloneable< FractionalMaxPool2dImpl >,torch::nn::Cloneable< FlattenImpl >,torch::nn::Cloneable< CrossMapLRN2dImpl >,torch::nn::Cloneable< TransformerEncoderLayerImpl >,torch::nn::Cloneable< ThresholdImpl >,torch::nn::Cloneable< SoftsignImpl >,torch::nn::Cloneable< MultiMarginLossImpl >,torch::nn::Cloneable< FractionalMaxPool3dImpl >,torch::nn::Cloneable< CTCLossImpl >,torch::nn::Cloneable< UnfoldImpl >,torch::nn::Cloneable< SiLUImpl >,torch::nn::Cloneable< ParameterListImpl >,torch::nn::Cloneable< MultiheadAttentionImpl >,torch::nn::Cloneable< CELUImpl >,torch::nn::Cloneable< UpsampleImpl >,torch::nn::Cloneable< TransformerImpl >,torch::nn::Cloneable< SELUImpl >,torch::nn::Cloneable< PixelUnshuffleImpl >,torch::nn::Cloneable< LinearImpl >,torch::nn::Cloneable< HingeEmbeddingLossImpl >,torch::nn::Cloneable< EmbeddingImpl >,torch::nn::Cloneable< MultiLabelSoftMarginLossImpl >,torch::nn::Cloneable< CrossEntropyLossImpl >,torch::nn::Cloneable< 
TripletMarginLossImpl >,torch::nn::Cloneable< TransformerDecoderLayerImpl >,torch::nn::Cloneable< SoftMarginLossImpl >,torch::nn::Cloneable< LocalResponseNormImpl >,torch::nn::Cloneable< BCELossImpl >,torch::nn::Cloneable< LayerNormImpl >,torch::nn::Cloneable< AdaptiveLogSoftmaxWithLossImpl >,torch::nn::Cloneable< ReLUImpl >,torch::nn::Cloneable< ModuleListImpl >,torch::nn::Cloneable< HuberLossImpl >,torch::nn::Cloneable< GELUImpl >,torch::nn::Cloneable< SoftmaxImpl >,torch::nn::Cloneable< Softmax2dImpl >,torch::nn::Cloneable< SoftplusImpl >,torch::nn::Cloneable< SigmoidImpl >,torch::nn::Cloneable< PoissonNLLLossImpl >,torch::nn::Cloneable< ModuleDictImpl >,torch::nn::Cloneable< MishImpl >,torch::nn::Cloneable< UnflattenImpl >,torch::nn::Cloneable< ReLU6Impl >,torch::nn::Cloneable< MSELossImpl >,torch::nn::Cloneable< CosineSimilarityImpl >,torch::nn::Cloneable< CosineEmbeddingLossImpl >,torch::nn::Cloneable< TransformerDecoderImpl >,torch::nn::Cloneable< TanhImpl >,torch::nn::Cloneable< NLLLossImpl >,torch::nn::Cloneable< MarginRankingLossImpl >,torch::nn::Cloneable< BCEWithLogitsLossImpl >,torch::nn::Cloneable< TransformerEncoderImpl >,torch::nn::Cloneable< PixelShuffleImpl >,torch::nn::Cloneable< KLDivLossImpl >,torch::nn::Cloneable< GroupNormImpl >,torch::nn::Cloneable< Derived >
Public Types
Public Functions
- Module()#
Constructs the module without immediate knowledge of the submodule’s name.
The name of the submodule is inferred via RTTI (if possible) the first time
.name()is invoked.
- virtual ~Module() = default#
- const std::string &name() const noexcept#
Returns the name of the
Module.A
Modulehas an associatedname, which is a string representation of the kind of concreteModuleit represents, such as"Linear"for theLinearmodule. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name yourModules by passing the string name to theModulebase class’ constructor.
- virtual std::shared_ptr<Module> clone(const std::optional<Device> &device = std::nullopt) const#
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.
Attention
Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class' API solely for an easier-to-use polymorphic interface.
- voidapply(constModuleApplyFunction&function)#
Applies the
functionto theModuleand recursively to every submodule.The function must accept a
Module&.
- voidapply(constConstModuleApplyFunction&function)const#
Applies the
functionto theModuleand recursively to every submodule.The function must accept a
constModule&.
- voidapply(constNamedModuleApplyFunction&function,conststd::string&name_prefix=std::string())#
Applies the
functionto theModuleand recursively to every submodule.The function must accept a
conststd::string&for the key of the module, and aModule&. The key of the module itself is the empty string. Ifname_prefixis given, it is prepended to every key as<name_prefix>.<key>(and justname_prefixfor the module itself).
- voidapply(constConstNamedModuleApplyFunction&function,conststd::string&name_prefix=std::string())const#
Applies the
functionto theModuleand recursively to every submodule.The function must accept a
conststd::string&for the key of the module, and aconstModule&. The key of the module itself is the empty string. Ifname_prefixis given, it is prepended to every key as<name_prefix>.<key>(and justname_prefixfor the module itself).
- voidapply(constModulePointerApplyFunction&function)const#
Applies the
functionto theModuleand recursively to every submodule.The function must accept a
conststd::shared_ptr<Module>&.
- voidapply(constNamedModulePointerApplyFunction&function,conststd::string&name_prefix=std::string())const#
Applies the
functionto theModuleand recursively to every submodule.The function must accept a
conststd::string&for the key of the module, and aconststd::shared_ptr<Module>&. The key of the module itself is the empty string. Ifname_prefixis given, it is prepended to every key as<name_prefix>.<key>(and justname_prefixfor the module itself).
- std::vector<Tensor>parameters(boolrecurse=true)const#
Returns the parameters of this
Moduleand ifrecurseis true, also recursively of every submodule.
- OrderedDict<std::string,Tensor>named_parameters(boolrecurse=true)const#
Returns an
OrderedDictwith the parameters of thisModulealong with their keys, and ifrecurseis true also recursively of every submodule.
- std::vector<Tensor>buffers(boolrecurse=true)const#
Returns the buffers of this
Moduleand ifrecurseis true, also recursively of every submodule.
- OrderedDict<std::string,Tensor>named_buffers(boolrecurse=true)const#
Returns an
OrderedDictwith the buffers of thisModulealong with their keys, and ifrecurseis true also recursively of every submodule.
- std::vector<std::shared_ptr<Module>>modules(boolinclude_self=true)const#
Returns the submodules of this
Module(the entire submodule hierarchy) and ifinclude_selfis true, also inserts ashared_ptrto this module in the first position.Warning
Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
- OrderedDict<std::string,std::shared_ptr<Module>>named_modules(conststd::string&name_prefix=std::string(),boolinclude_self=true)const#
Returns an
OrderedDictof the submodules of thisModule(the entire submodule hierarchy) and their keys, and ifinclude_selfis true, also inserts ashared_ptrto this module in the first position.If
name_prefixis given, it is prepended to every key as<name_prefix>.<key>(and justname_prefixfor the module itself).Warning
Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
- OrderedDict<std::string,std::shared_ptr<Module>>named_children()const#
Returns an
OrderedDictof the direct submodules of thisModuleand their keys.
- virtualvoidtrain(boolon=true)#
Enables “training” mode.
- voideval()#
Calls train(false) to enable “eval” mode.
Do not override this method, override
train()instead.
- virtualboolis_training()constnoexcept#
True if the module is in training mode.
Every
Modulehas a boolean associated with it that determines whether theModuleis currently intraining mode (set via.train()) or inevaluation (inference) mode (set via.eval()). This property is exposed viais_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See theBatchNormorDropoutmodules for examples ofModules that use different code paths depending on this property.
- virtualvoidto(torch::Devicedevice,torch::Dtypedtype,boolnon_blocking=false)#
Recursively casts all parameters to the given
dtypeanddevice.If
non_blockingis true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
- virtualvoidto(torch::Dtypedtype,boolnon_blocking=false)#
Recursively casts all parameters to the given dtype.
If
non_blockingis true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
- virtualvoidto(torch::Devicedevice,boolnon_blocking=false)#
Recursively moves all parameters to the given device.
If
non_blockingis true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
- virtualvoidzero_grad(boolset_to_none=true)#
Recursively zeros out the
gradvalue of each registered parameter.
- template<typenameModuleType>
ModuleType::ContainedType*as()noexcept# Attempts to cast this
Moduleto the givenModuleType.This method is useful when calling
apply().voidinitialize_weights(nn::Module&module){torch::NoGradGuardno_grad;if(auto*linear=module.as<nn::Linear>()){linear->weight.normal_(0.0,0.02);}}MyModulemodule;module->apply(initialize_weights);
- template<typenameModuleType>
constModuleType::ContainedType*as()constnoexcept# Attempts to cast this
Moduleto the givenModuleType.This method is useful when calling
apply().
- template<typenameModuleType,typename=torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType*as()noexcept# Attempts to cast this
Moduleto the givenModuleType.This method is useful when calling
apply().voidinitialize_weights(nn::Module&module){torch::NoGradGuardno_grad;if(auto*linear=module.as<nn::Linear>()){linear->weight.normal_(0.0,0.02);}}MyModulemodule;module.apply(initialize_weights);
- template<typenameModuleType,typename=torch::detail::disable_if_module_holder_t<ModuleType>>
constModuleType*as()constnoexcept# Attempts to cast this
Moduleto the givenModuleType.This method is useful when calling
apply().voidinitialize_weights(nn::Module&module){torch::NoGradGuardno_grad;if(auto*linear=module.as<nn::Linear>()){linear->weight.normal_(0.0,0.02);}}MyModulemodule;module.apply(initialize_weights);
- virtualvoidsave(serialize::OutputArchive&archive)const#
Serializes the
Moduleinto the givenOutputArchive.If the
Modulecontains unserializable submodules (e.g.nn::Functional), those submodules are skipped when serializing.
- virtualvoidload(serialize::InputArchive&archive)#
Deserializes the
Modulefrom the givenInputArchive.If the
Modulecontains unserializable submodules (e.g.nn::Functional), we don’t check the existence of those submodules in theInputArchivewhen deserializing.
- virtualvoidpretty_print(std::ostream&stream)const#
Streams a pretty representation of the
Moduleinto the givenstream.By default, this representation will be the name of the module (taken from
name()), followed by a recursive pretty print of all of theModule’s submodules.Override this method to change the pretty print. The input
streamshould be returned from the method, to allow easy chaining.
- Tensor &register_parameter(std::string name, Tensor tensor, bool requires_grad = true)#
Registers a parameter with this
Module.A parameter should be any gradient-recording tensor used in the implementation of your
Module. Registering it makes it available to methods such asparameters(),clone()orto().Note that registering an undefined Tensor (e.g.
module.register_parameter("param",Tensor())) is allowed, and is equivalent tomodule.register_parameter("param",None)in Python API.MyModule::MyModule(){weight_=register_parameter("weight",torch::randn({A,B}));}
- Tensor &register_buffer(std::string name, Tensor tensor)#
Registers a buffer with this
Module.A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as
buffers(), clone() or to(). MyModule::MyModule() { mean_ = register_buffer("mean", torch::empty({num_features_})); }
- template<typenameModuleType>
std::shared_ptr<ModuleType>register_module(std::stringname,std::shared_ptr<ModuleType>module)# Registers a submodule with this
Module.Registering a module makes it available to methods such as
modules(),clone()orto().MyModule::MyModule(){submodule_=register_module("linear",torch::nn::Linear(3,4));}
- template<typenameModuleType>
std::shared_ptr<ModuleType>register_module(std::stringname,ModuleHolder<ModuleType>module_holder)# Registers a submodule with this
Module.This method deals with
ModuleHolders.Registering a module makes it available to methods such as
modules(),clone()orto().MyModule::MyModule(){submodule_=register_module("linear",torch::nn::Linear(3,4));}
- template<typenameModuleType>
std::shared_ptr<ModuleType>replace_module(conststd::string&name,std::shared_ptr<ModuleType>module)# Replaces a registered submodule with this
Module. This takes care of the registration; if you used submodule members, you should assign the result back, e.g. module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4));. It only works when a module of the given name is already registered.
This is useful for replacing a module after initialization, e.g. for finetuning.
- template<typenameModuleType>
std::shared_ptr<ModuleType>replace_module(conststd::string&name,ModuleHolder<ModuleType>module_holder)# Replaces a registered submodule with this
Module.This method deals with
ModuleHolders. This takes care of the registration; if you used submodule members, you should assign the result back, e.g. module->submodule_ = module->replace_module("linear", linear_holder);. It only works when a module of the given name is already registered.
This is useful for replacing a module after initialization, e.g. for finetuning.
Protected Functions
- inlinevirtualbool_forward_has_default_args()#
The following three functions allow a module with default arguments in its forward method to be used in a Sequential module.
You should NEVER override these functions manually. Instead, you should use the
FORWARD_HAS_DEFAULT_ARGSmacro.
- inlinevirtualunsignedint_forward_num_required_args()#
Protected Attributes
- OrderedDict<std::string,Tensor>parameters_#
The registered parameters of this
Module. In order to access parameters_ in ParameterDict and ParameterList