Conv3d
- class torch.nn.modules.conv.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
Applies a 3D convolution over an input signal composed of several input planes.
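In the simplest case, the output value of the layer with input size $(N, C_{in}, D, H, W)$ and output $(N, C_{out}, D_{out}, H_{out}, W_{out})$ can be precisely described as:
$$\text{out}(N_i, C_{out_j}) = \text{bias}(C_{out_j}) + \sum_{k=0}^{C_{in} - 1} \text{weight}(C_{out_j}, k) \star \text{input}(N_i, k)$$
where $\star$ is the valid 3D cross-correlation operator.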
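In the simplest case (stride 1, no padding, groups 1), each output channel is its bias plus the sum over all input channels of a valid 3D cross-correlation. A minimal sketch of that statement, using only the public torch API: `naive_conv3d` is a hypothetical helper written for illustration, the sizes are arbitrary, and the loops are deliberately slow so that the formula stays visible.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Arbitrary small sizes chosen for this sketch.
N, C_in, C_out = 2, 3, 4
D, H, W = 5, 6, 7

conv = nn.Conv3d(C_in, C_out, kernel_size=3)  # stride 1, no padding, groups 1
x = torch.randn(N, C_in, D, H, W)


def naive_conv3d(inp, weight, bias):
    """Hypothetical helper: evaluate
    out(N_i, C_out_j) = bias(C_out_j) + sum_k weight(C_out_j, k) * input(N_i, k)
    directly, where * is a 'valid' 3D cross-correlation (no padding, kernel not flipped)."""
    n, _, d, h, w = inp.shape
    c_out, _, kd, kh, kw = weight.shape
    d_out, h_out, w_out = d - kd + 1, h - kh + 1, w - kw + 1
    out = torch.empty(n, c_out, d_out, h_out, w_out)
    for i in range(n):
        for j in range(c_out):
            for z in range(d_out):
                for y in range(h_out):
                    for q in range(w_out):
                        # Window spanning all input channels k at this position;
                        # .sum() does both the cross-correlation and the sum over k.
                        window = inp[i, :, z:z + kd, y:y + kh, q:q + kw]
                        out[i, j, z, y, q] = bias[j] + (window * weight[j]).sum()
    return out


with torch.no_grad():
    expected = conv(x)
    manual = naive_conv3d(x, conv.weight, conv.bias)

print(torch.allclose(expected, manual, atol=1e-5))
```

The kernel is applied without flipping, which is why the operation is a cross-correlation rather than a convolution in the strict mathematical sense.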
This module supports TensorFloat32.
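As a sketch, assuming the `torch.backends.cudnn.allow_tf32` flag found in current PyTorch releases, TF32 use in cuDNN convolutions (which covers this module on CUDA) can be inspected and switched off like this:

```python
import torch

# Defaults to True; it only has an effect on GPUs with TF32 support (Ampere or newer).
print("cuDNN TF32 allowed:", torch.backends.cudnn.allow_tf32)

# Request full float32 precision for cuDNN convolutions, usually at some speed cost.
torch.backends.cudnn.allow_tf32 = False
```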
On certain ROCm devices, when using float16 inputs this module will use different precision for backward.
stride controls the stride for the cross-correlation.
padding controls the amount of padding applied to the input. It can be either a string {'valid', 'same'} or a tuple of ints giving the amount of implicit padding applied on both sides.
dilation controls the spacing between the kernel points; this is also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.
groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example:
- At groups=1, all inputs are convolved to all outputs.
- At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
- At groups=in_channels, each input channel is convolved with its own set of filters (of size $\frac{\text{out\_channels}}{\text{in\_channels}}$).
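The groups=2 case can be checked numerically. This sketch (arbitrary sizes) copies the parameters of a grouped layer into two ordinary layers, runs each on its half of the input channels, and compares the concatenated result with the grouped layer's output:

```python
import torch
from torch import nn

torch.manual_seed(0)

# groups=2: weight is (out_channels, in_channels / groups, kD, kH, kW).
grouped = nn.Conv3d(4, 6, kernel_size=3, groups=2)
print(grouped.weight.shape)  # torch.Size([6, 2, 3, 3, 3])

# Two independent layers, each seeing 2 input channels and producing 3 outputs.
left = nn.Conv3d(2, 3, kernel_size=3)
right = nn.Conv3d(2, 3, kernel_size=3)
with torch.no_grad():
    # Output channels 0-2 belong to group 0, output channels 3-5 to group 1.
    left.weight.copy_(grouped.weight[:3])
    left.bias.copy_(grouped.bias[:3])
    right.weight.copy_(grouped.weight[3:])
    right.bias.copy_(grouped.bias[3:])

x = torch.randn(1, 4, 5, 5, 5)
with torch.no_grad():
    y_grouped = grouped(x)
    # Split the input channels, convolve each half, concatenate along channels.
    y_split = torch.cat([left(x[:, :2]), right(x[:, 2:])], dim=1)

print(torch.allclose(y_grouped, y_split, atol=1e-5))
```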
The parameters kernel_size, stride, padding, dilation can either be:
- a single int, in which case the same value is used for the depth, height and width dimensions
- a tuple of three ints, in which case the first int is used for the depth dimension, the second int for the height dimension and the third int for the width dimension
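One way to see how single ints and 3-tuples are interpreted is to inspect the kernel_size, stride, padding and dilation attributes that nn.Conv3d keeps after construction. A short sketch with arbitrary values:

```python
from torch import nn

# A single int is expanded to the same value for depth, height and width.
cube = nn.Conv3d(1, 1, kernel_size=3, stride=2, padding=1, dilation=1)
print(cube.kernel_size, cube.stride, cube.padding, cube.dilation)
# (3, 3, 3) (2, 2, 2) (1, 1, 1) (1, 1, 1)

# A 3-tuple sets (depth, height, width) individually.
box = nn.Conv3d(1, 1, kernel_size=(1, 3, 5))
print(box.kernel_size, box.weight.shape)
# (1, 3, 5) torch.Size([1, 1, 1, 3, 5])
```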
Note
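When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also known as a "depthwise convolution".
In other words, for an input of size $(N, C_{in}, D_{in}, H_{in}, W_{in})$, a depthwise convolution with a depthwise multiplier K can be performed with the arguments $(\text{in\_channels}=C_{in}, \text{out\_channels}=C_{in} \times K, ..., \text{groups}=C_{in})$.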
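A sketch of a depthwise convolution with multiplier K=2 (the channel count and K are arbitrary). It checks the weight shape and that perturbing one input channel only changes the K output channels that belong to its group:

```python
import torch
from torch import nn

torch.manual_seed(0)
C_in, K = 4, 2  # arbitrary channel count and depthwise multiplier

# Depthwise: groups == in_channels and out_channels == K * in_channels.
depthwise = nn.Conv3d(C_in, C_in * K, kernel_size=3, padding=1, groups=C_in)
print(depthwise.weight.shape)  # torch.Size([8, 1, 3, 3, 3])

x = torch.randn(2, C_in, 6, 6, 6)
with torch.no_grad():
    y = depthwise(x)
    x2 = x.clone()
    x2[:, 0] += 1.0  # perturb input channel 0 only
    # Largest change per output channel; output channels 0..K-1 form group 0.
    diff = (depthwise(x2) - y).abs().amax(dim=(0, 2, 3, 4))

print(y.shape)  # torch.Size([2, 8, 6, 6, 6])
print(diff)     # only the first K entries should be nonzero
```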
Note
In some circumstances when given tensors on a CUDA device and using cuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See Reproducibility for more information.
Note
padding='valid' is the same as no padding. padding='same' pads the input so the output has the same shape as the input. However, this mode doesn't support any stride values other than 1.
Note
This module supports complex data types, i.e. complex32, complex64, complex128.
- Parameters
in_channels (int) – Number of channels in the input image
out_channels (int) – Number of channels produced by the convolution
kernel_size (int or tuple) – Size of the convolving kernel
stride (int or tuple, optional) – Stride of the convolution. Default: 1
padding (int, tuple or str, optional) – Padding added to all six sides of the input. Default: 0
dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1
groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True
padding_mode (str, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'
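The padding and padding_mode options can be exercised directly. This sketch (arbitrary sizes) checks the 'valid' and 'same' output shapes, confirms that 'same' combined with stride=2 is rejected with a ValueError when the layer is constructed (the behavior of current releases), and reproduces padding_mode='circular' by wrap-around padding with torch.nn.functional.pad followed by an unpadded torch.nn.functional.conv3d:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
x = torch.randn(1, 2, 7, 8, 9)

# padding as a string: 'valid' means no padding, 'same' keeps D, H, W unchanged.
valid = nn.Conv3d(2, 3, kernel_size=3, padding="valid")
same = nn.Conv3d(2, 3, kernel_size=3, padding="same")
print(valid(x).shape)  # torch.Size([1, 3, 5, 6, 7])
print(same(x).shape)   # torch.Size([1, 3, 7, 8, 9])

# 'same' with a stride other than 1 is rejected at construction time.
try:
    nn.Conv3d(2, 3, kernel_size=3, padding="same", stride=2)
    strided_same_rejected = False
except ValueError:
    strided_same_rejected = True

# padding_mode='circular': wrap-around padding instead of zeros.
circ = nn.Conv3d(2, 3, kernel_size=3, padding=1, padding_mode="circular")
with torch.no_grad():
    y = circ(x)
    # Same result by hand: pad all six sides by 1 with wrap-around, then convolve unpadded.
    # F.pad's tuple runs from the last dim backwards: (W_left, W_right, H_top, H_bottom, D_front, D_back).
    x_wrapped = F.pad(x, (1, 1, 1, 1, 1, 1), mode="circular")
    y_manual = F.conv3d(x_wrapped, circ.weight, circ.bias)

print(torch.allclose(y, y_manual, atol=1e-5))
```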
- Shape:
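Input: $(N, C_{in}, D_{in}, H_{in}, W_{in})$ or $(C_{in}, D_{in}, H_{in}, W_{in})$
Output: $(N, C_{out}, D_{out}, H_{out}, W_{out})$ or $(C_{out}, D_{out}, H_{out}, W_{out})$, where
$$D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor$$
$$H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor$$
$$W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor$$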
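Along each spatial dimension the output length is floor((in + 2·padding − dilation·(kernel_size − 1) − 1) / stride + 1). That rule can be written as a small helper and checked against a real layer; `conv3d_out_size` is a hypothetical name for illustration, and the configuration deliberately uses a different kernel size, stride, padding and dilation in each dimension:

```python
import torch
from torch import nn


def conv3d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Hypothetical helper: floor((size + 2*padding - dilation*(kernel-1) - 1) / stride + 1)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Per-dimension settings, in (depth, height, width) order.
kernel, stride, padding, dilation = (3, 5, 2), (2, 1, 3), (1, 2, 0), (2, 1, 1)
in_dims = (11, 20, 17)

conv = nn.Conv3d(1, 1, kernel, stride=stride, padding=padding, dilation=dilation)
with torch.no_grad():
    out = conv(torch.randn(1, 1, *in_dims))

predicted = tuple(
    conv3d_out_size(s, k, st, p, d)
    for s, k, st, p, d in zip(in_dims, kernel, stride, padding, dilation)
)
print(predicted, tuple(out.shape[2:]))  # (5, 20, 6) (5, 20, 6)
```

Integer floor division (`//`) is used instead of converting to float, which keeps the result exact for the non-negative numerators produced by valid configurations.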
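- Variables
weight (Tensor) – the learnable weights of the module of shape $(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size}[0], \text{kernel\_size}[1], \text{kernel\_size}[2])$. The values of these weights are sampled from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$ where $k = \frac{\text{groups}}{C_{in} \times \prod_{i=0}^{2} \text{kernel\_size}[i]}$
bias (Tensor) – the learnable bias of the module of shape (out_channels). If bias is True, then the values of the bias are sampled from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$ where $k = \frac{\text{groups}}{C_{in} \times \prod_{i=0}^{2} \text{kernel\_size}[i]}$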
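By default the weight has shape (out_channels, in_channels / groups, *kernel_size), the bias has shape (out_channels,), and both start out drawn from U(−√k, √k) with k = groups / (C_in · ∏ kernel_size). This sketch (arbitrary sizes) builds a grouped layer and checks the shapes and the bound; the small tolerance only absorbs float32 rounding of √k:

```python
import math

import torch
from torch import nn

torch.manual_seed(0)
C_in, C_out, groups = 6, 4, 2
kernel_size = (3, 2, 5)
conv = nn.Conv3d(C_in, C_out, kernel_size, groups=groups)

# weight: (out_channels, in_channels / groups, kD, kH, kW); bias: (out_channels,)
print(conv.weight.shape)  # torch.Size([4, 3, 3, 2, 5])
print(conv.bias.shape)    # torch.Size([4])

# k = groups / (C_in * prod(kernel_size)) = 1 / 90 here.
k = groups / (C_in * math.prod(kernel_size))
bound = math.sqrt(k)
print(conv.weight.abs().max().item() <= bound + 1e-6)
print(conv.bias.abs().max().item() <= bound + 1e-6)
```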
Examples:
>>> import torch
>>> from torch import nn
>>> # With square kernels and equal stride
>>> m = nn.Conv3d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
>>> input = torch.randn(20, 16, 10, 50, 100)
>>> output = m(input)
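The output sizes of the two layers in the examples can be checked directly. This sketch uses a batch of 2 instead of 20 only to keep memory use small; the spatial sizes are unchanged:

```python
import torch
from torch import nn

# Same layers as in the examples; batch reduced from 20 to 2.
x = torch.randn(2, 16, 10, 50, 100)

m_square = nn.Conv3d(16, 33, 3, stride=2)
m_rect = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

with torch.no_grad():
    y_square = m_square(x)
    y_rect = m_rect(x)

print(y_square.shape)  # torch.Size([2, 33, 4, 24, 49])
print(y_rect.shape)    # torch.Size([2, 33, 8, 50, 99])
```

For the second layer, the depth padding of 4 with stride 2 gives floor((10 + 8 − 2 − 1) / 2 + 1) = 8, while height keeps 50 because padding 2 exactly compensates the kernel of 5.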