TransformerDecoder
- class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)
TransformerDecoder is a stack of N decoder layers.
This TransformerDecoder layer implements the original architecture described in the Attention Is All You Need paper. The intent of this layer is as a reference implementation for foundational understanding, and thus it contains only limited features relative to newer Transformer architectures. Given the fast pace of innovation in transformer-like architectures, we recommend exploring this tutorial to build efficient layers from building blocks in core, or using higher-level libraries from the PyTorch Ecosystem.
Warning
All layers in the TransformerDecoder are initialized with the same parameters. It is recommended to manually initialize the layers after creating the TransformerDecoder instance.
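The warning above can be illustrated with a minimal sketch. The choice of Xavier-uniform initialization, the small `d_model`, and the use of `linear1.weight` as the probe are assumptions made for illustration, not a recommendation from the documentation.

```python
import torch
from torch import nn

decoder_layer = nn.TransformerDecoderLayer(d_model=32, nhead=4)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)

# Right after construction, each stacked layer is a copy of decoder_layer,
# so corresponding weights start out identical.
same_before = torch.equal(
    decoder.layers[0].linear1.weight, decoder.layers[1].linear1.weight
)

# Re-initialize every weight matrix independently.
# Xavier uniform is an assumed choice here; any per-layer scheme would do.
torch.manual_seed(0)
for p in decoder.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

same_after = torch.equal(
    decoder.layers[0].linear1.weight, decoder.layers[1].linear1.weight
)
print(same_before, same_after)
```

One-dimensional parameters (biases and LayerNorm weights) are left untouched in this sketch.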
- Parameters
decoder_layer (TransformerDecoderLayer) – an instance of the TransformerDecoderLayer() class (required).
num_layers (int) – the number of sub-decoder-layers in the decoder (required).
norm (Optional[Module]) – the layer normalization component (optional).
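As a hedged sketch of the `norm` argument, the snippet below passes an `nn.LayerNorm` that is applied to the output of the last layer. The sizes are arbitrary values chosen for illustration.

```python
import torch
from torch import nn

d_model = 64
layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4)

# norm is applied once, to the output of the final decoder layer.
decoder = nn.TransformerDecoder(layer, num_layers=2, norm=nn.LayerNorm(d_model))
decoder.eval()  # disable dropout so the forward pass is deterministic

tgt = torch.rand(5, 3, d_model)     # (target length, batch, features)
memory = torch.rand(7, 3, d_model)  # (source length, batch, features)
with torch.no_grad():
    out = decoder(tgt, memory)
print(out.shape)
```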
Examples
>>> import torch
>>> from torch import nn
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)
Pass the inputs (and masks) through each decoder layer in turn.
- Parameters
tgt (Tensor) – the sequence to the decoder (required).
memory (Tensor) – the sequence from the last layer of the encoder (required).
tgt_mask (Optional[Tensor]) – the mask for the tgt sequence (optional).
memory_mask (Optional[Tensor]) – the mask for the memory sequence (optional).
tgt_key_padding_mask (Optional[Tensor]) – the mask for the tgt keys per batch (optional).
memory_key_padding_mask (Optional[Tensor]) – the mask for the memory keys per batch (optional).
tgt_is_causal (Optional[bool]) – If specified, applies a causal mask as tgt mask. Default: None; try to detect a causal mask. Warning: tgt_is_causal provides a hint that tgt_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
memory_is_causal (bool) – If specified, applies a causal mask as memory mask. Default: False. Warning: memory_is_causal provides a hint that memory_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
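The mask parameters above can be combined as in the following sketch. It assumes the default sequence-first layout, builds the causal mask with `nn.Transformer.generate_square_subsequent_mask`, and leaves `tgt_is_causal` at its default of `None` so that detection is left to the module. The tensor sizes and the padded positions are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.TransformerDecoderLayer(d_model=16, nhead=2)
decoder = nn.TransformerDecoder(layer, num_layers=2)
decoder.eval()  # disable dropout so two forward passes are comparable

T, S, N = 5, 7, 3  # target length, source length, batch size
tgt = torch.rand(T, N, 16)
memory = torch.rand(S, N, 16)

# Causal mask of shape (T, T): -inf above the diagonal, 0 elsewhere.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(T)

# Boolean padding mask of shape (N, S); True marks memory positions to ignore.
memory_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
memory_key_padding_mask[0, -2:] = True  # pretend the first sequence is padded

with torch.no_grad():
    out = decoder(
        tgt, memory,
        tgt_mask=tgt_mask,
        memory_key_padding_mask=memory_key_padding_mask,
    )

    # Changing only the last target position must not affect earlier outputs.
    tgt_changed = tgt.clone()
    tgt_changed[-1] = torch.rand(N, 16)
    out_changed = decoder(
        tgt_changed, memory,
        tgt_mask=tgt_mask,
        memory_key_padding_mask=memory_key_padding_mask,
    )

prefix_unchanged = torch.allclose(out[:-1], out_changed[:-1], atol=1e-5)
print(out.shape, prefix_unchanged)
```

The second forward pass is only a sanity check that the causal mask prevents earlier positions from attending to later ones.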
- Return type
Tensor
- Shape:
see the docs in Transformer.
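As a quick illustration of the shapes involved, the sketch below builds the decoder from a layer constructed with `batch_first=True`, so inputs and output are laid out as (batch, sequence, feature). With the default `batch_first=False`, the first two dimensions are swapped, as in the example above. The sizes are arbitrary.

```python
import torch
from torch import nn

# batch_first is set on the layer; the stacked decoder follows its layout.
layer = nn.TransformerDecoderLayer(d_model=16, nhead=2, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2)

tgt = torch.rand(3, 5, 16)     # (N, T, E): batch, target length, features
memory = torch.rand(3, 7, 16)  # (N, S, E): batch, source length, features
out = decoder(tgt, memory)

# The output has the same shape as tgt.
print(out.shape)
```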