TransformerEncoder#
- class torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True)[source]#
TransformerEncoder is a stack of N encoder layers.
This TransformerEncoder layer implements the original architecture described in the Attention Is All You Need paper. The intent of this layer is as a reference implementation for foundational understanding and thus it contains only limited features relative to newer Transformer architectures. Given the fast pace of innovation in transformer-like architectures, we recommend exploring this tutorial to build efficient layers from building blocks in core or using higher level libraries from the PyTorch Ecosystem.
Warning
All layers in the TransformerEncoder are initialized with the same parameters. It is recommended to manually initialize the layers after creating the TransformerEncoder instance.
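One way to address this warning is to re-initialize each layer's parameters after construction. A minimal sketch (the initialization scheme here, Xavier-uniform weights and zero biases, is an illustrative choice, not something the docs prescribe):

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# All 6 layers start as copies of encoder_layer; re-initialize each one
# so they no longer share identical starting weights.
for layer in encoder.layers:
    for p in layer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)  # weight matrices
        else:
            nn.init.zeros_(p)           # biases and LayerNorm parameters
```

After this loop, corresponding weights in different layers are independent random draws rather than identical copies.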
- Parameters
encoder_layer (TransformerEncoderLayer) – an instance of the TransformerEncoderLayer() class (required).
num_layers (int) – the number of sub-encoder-layers in the encoder (required).
norm (Optional[Module]) – the layer normalization component (optional).
enable_nested_tensor (bool) – if True, input will automatically convert to nested tensor (and convert back on output). This will improve the overall performance of TransformerEncoder when padding rate is high. Default: True (enabled).
Examples
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
- forward(src, mask=None, src_key_padding_mask=None, is_causal=None)[source]#
Pass the input through the encoder layers in turn.
- Parameters
src (Tensor) – the sequence to the encoder (required).
mask (Optional[Tensor]) – the mask for the src sequence (optional).
src_key_padding_mask (Optional[Tensor]) – the mask for the src keys per batch (optional).
is_causal (Optional[bool]) – If specified, applies a causal mask as mask. Default: None; try to detect a causal mask. Warning: is_causal provides a hint that mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
- Return type
Tensor
- Shape:
see the docs in Transformer.
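To show how mask and is_causal interact, a short sketch that builds the standard upper-triangular causal mask and passes is_causal=True as the hint that the mask really is causal (the sequence-first shape follows the default batch_first=False convention; enable_nested_tensor is disabled here only to keep the example minimal):

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2,
                                enable_nested_tensor=False)

seq_len = 10
src = torch.rand(seq_len, 32, 512)     # (seq, batch, d_model)

# Upper-triangular mask: position i may only attend to positions <= i.
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

# is_causal=True is a hint that causal_mask is in fact causal; passing
# is_causal=True with a non-causal mask can produce incorrect results.
out = encoder(src, mask=causal_mask, is_causal=True)
```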