Class Module

Inheritance Relationships

Base Type

  • public std::enable_shared_from_this<Module>

Derived Types

Class Documentation

class Module : public std::enable_shared_from_this<Module>

The base class for all modules in PyTorch.

A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module's constructor.

A distinction is made between three kinds of persistent data that may be associated with a Module:

  1. Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),

  2. Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),

  3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.

The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type of a Module's registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.

Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the public API of Module and are typically invoked from within a concrete Module's constructor.

Note

The design and implementation of this class is largely based on the Python API. You may want to consult the Python documentation for torch.nn.Module for further clarification on certain methods or behavior.

Subclassed by torch::nn::Cloneable< T > for every concrete module implementation T, including SoftshrinkImpl, PReLUImpl, LogSoftmaxImpl, L1LossImpl, SequentialImpl, HardshrinkImpl, GLUImpl, RReLUImpl, ParameterDictImpl, IdentityImpl, FoldImpl, EmbeddingBagImpl, BilinearImpl, TripletMarginWithDistanceLossImpl, SoftminImpl, SmoothL1LossImpl, MultiLabelMarginLossImpl, LeakyReLUImpl, FunctionalImpl, ELUImpl, TanhshrinkImpl, PairwiseDistanceImpl, LogSigmoidImpl, HardtanhImpl, FractionalMaxPool2dImpl, FlattenImpl, CrossMapLRN2dImpl, TransformerEncoderLayerImpl, ThresholdImpl, SoftsignImpl, MultiMarginLossImpl, FractionalMaxPool3dImpl, CTCLossImpl, UnfoldImpl, SiLUImpl, ParameterListImpl, MultiheadAttentionImpl, CELUImpl, UpsampleImpl, TransformerImpl, SELUImpl, PixelUnshuffleImpl, LinearImpl, HingeEmbeddingLossImpl, EmbeddingImpl, MultiLabelSoftMarginLossImpl, CrossEntropyLossImpl, TripletMarginLossImpl, TransformerDecoderLayerImpl, SoftMarginLossImpl, LocalResponseNormImpl, BCELossImpl, LayerNormImpl, AdaptiveLogSoftmaxWithLossImpl, ReLUImpl, ModuleListImpl, HuberLossImpl, GELUImpl, SoftmaxImpl, Softmax2dImpl, SoftplusImpl, SigmoidImpl, PoissonNLLLossImpl, ModuleDictImpl, MishImpl, UnflattenImpl, ReLU6Impl, MSELossImpl, CosineSimilarityImpl, CosineEmbeddingLossImpl, TransformerDecoderImpl, TanhImpl, NLLLossImpl, MarginRankingLossImpl, BCEWithLogitsLossImpl, TransformerEncoderImpl, PixelShuffleImpl, KLDivLossImpl, GroupNormImpl, and the generic Derived.

Public Types

using ModuleApplyFunction = std::function<void(Module&)>
using ConstModuleApplyFunction = std::function<void(const Module&)>
using NamedModuleApplyFunction = std::function<void(const std::string&, Module&)>
using ConstNamedModuleApplyFunction = std::function<void(const std::string&, const Module&)>
using ModulePointerApplyFunction = std::function<void(const std::shared_ptr<Module>&)>
using NamedModulePointerApplyFunction = std::function<void(const std::string&, const std::shared_ptr<Module>&)>

Public Functions

explicit Module(std::string name)

Tells the base Module about the name of the submodule.

Module()

Constructs the module without immediate knowledge of the submodule’s name.

The name of the submodule is inferred via RTTI (if possible) the first time name() is invoked.

Module(const Module&) = default
Module& operator=(const Module&) = default
Module(Module&&) noexcept = default
Module& operator=(Module&&) noexcept = default
virtual ~Module() = default
const std::string& name() const noexcept

Returns the name of the Module.

A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class' constructor.

virtual std::shared_ptr<Module> clone(const std::optional<Device>& device = std::nullopt) const

Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.

Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.

Attention

Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class' API solely for an easier-to-use polymorphic interface.

void apply(const ModuleApplyFunction& function)

Applies the function to the Module and recursively to every submodule.

The function must accept a Module&.

void apply(const ConstModuleApplyFunction& function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const Module&.

void apply(const NamedModuleApplyFunction& function, const std::string& name_prefix = std::string())

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

void apply(const ConstNamedModuleApplyFunction& function, const std::string& name_prefix = std::string()) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

void apply(const ModulePointerApplyFunction& function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::shared_ptr<Module>&.

void apply(const NamedModulePointerApplyFunction& function, const std::string& name_prefix = std::string()) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

std::vector<Tensor> parameters(bool recurse = true) const

Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.

OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const

Returns an OrderedDict with the parameters of this Module along with their keys, and, if recurse is true, also recursively of every submodule.

std::vector<Tensor> buffers(bool recurse = true) const

Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.

OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const

Returns an OrderedDict with the buffers of this Module along with their keys, and, if recurse is true, also recursively of every submodule.

std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const

Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.

Warning

Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.

OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string& name_prefix = std::string(), bool include_self = true) const

Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and, if include_self is true, also inserts a shared_ptr to this module in the first position.

If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

Warning

Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.

std::vector<std::shared_ptr<Module>> children() const

Returns the direct submodules of this Module.

OrderedDict<std::string, std::shared_ptr<Module>> named_children() const

Returns an OrderedDict of the direct submodules of this Module and their keys.

virtual void train(bool on = true)

Enables "training" mode.

void eval()

Calls train(false) to enable "eval" mode.

Do not override this method; override train() instead.

virtual bool is_training() const noexcept

True if the module is in training mode.

Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.

virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)

Recursively casts all parameters to the given dtype and device.

If non_blocking is true, and the source is in pinned memory while the destination is on the GPU (or vice versa), the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

virtual void to(torch::Dtype dtype, bool non_blocking = false)

Recursively casts all parameters to the given dtype.

If non_blocking is true, and the source is in pinned memory while the destination is on the GPU (or vice versa), the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

virtual void to(torch::Device device, bool non_blocking = false)

Recursively moves all parameters to the given device.

If non_blocking is true, and the source is in pinned memory while the destination is on the GPU (or vice versa), the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

virtual void zero_grad(bool set_to_none = true)

Recursively zeros out the grad value of each registered parameter.

template<typename ModuleType>
ModuleType::ContainedType* as() noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}
MyModule module;
module->apply(initialize_weights);

template<typename ModuleType>
const ModuleType::ContainedType* as() const noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType* as() noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}
MyModule module;
module->apply(initialize_weights);

template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType* as() const noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}
MyModule module;
module->apply(initialize_weights);

virtual void save(serialize::OutputArchive& archive) const

Serializes the Module into the given OutputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.

virtual void load(serialize::InputArchive& archive)

Deserializes the Module from the given InputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), we don't check for the existence of those submodules in the InputArchive when deserializing.

virtual void pretty_print(std::ostream& stream) const

Streams a pretty representation of the Module into the given stream.

By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules.

Override this method to change the pretty print.

virtual bool is_serializable() const

Returns whether the Module is serializable.

Tensor& register_parameter(std::string name, Tensor tensor, bool requires_grad = true)

Registers a parameter with this Module.

A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().

Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.

MyModule::MyModule() {
  weight_ = register_parameter("weight", torch::randn({A, B}));
}

Tensor& register_buffer(std::string name, Tensor tensor)

Registers a buffer with this Module.

A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().

MyModule::MyModule() {
  mean_ = register_buffer("mean", torch::empty({num_features_}));
}

template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module)

Registers a submodule with this Module.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}

template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, ModuleHolder<ModuleType> module_holder)

Registers a submodule with this Module.

This method deals with ModuleHolders.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}

template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string& name, std::shared_ptr<ModuleType> module)

Replaces a registered submodule of this Module.

This takes care of the registration; if you use submodule members, you should assign the result back, e.g. module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4));. It only works when a module with the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.

template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string& name, ModuleHolder<ModuleType> module_holder)

Replaces a registered submodule of this Module.

This method deals with ModuleHolders.

This takes care of the registration; if you use submodule members, you should assign the result back, e.g. module->submodule_ = module->replace_module("linear", linear_holder);. It only works when a module with the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.

void unregister_module(const std::string& name)

Unregisters a submodule from this Module.

If there is no such module with name, an exception is thrown.

Protected Functions

inline virtual bool _forward_has_default_args()

The following three functions allow a module with default arguments in its forward method to be used in a Sequential module.

You should NEVER override these functions manually. Instead, you should use the FORWARD_HAS_DEFAULT_ARGS macro.

inline virtual unsigned int _forward_num_required_args()
inline virtual std::vector<AnyValue> _forward_populate_default_args(std::vector<AnyValue>&& arguments)

Protected Attributes

OrderedDict<std::string, Tensor> parameters_

The registered parameters of this Module.

In order to access parameters_ in ParameterDict and ParameterList.