torch.nn.functional.grid_sample
- torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
Compute grid sample.
Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.
Currently, only spatial (4-D) and volumetric (5-D) inputs are supported.
In the spatial (4-D) case, for input with shape $(N, C, H_\text{in}, W_\text{in})$ and grid with shape $(N, H_\text{out}, W_\text{out}, 2)$, the output will have shape $(N, C, H_\text{out}, W_\text{out})$.
For each output location output[n, :, h, w], the size-2 vector grid[n, h, w] specifies the input pixel locations x (along the width) and y (along the height), which are used to interpolate the output value output[n, :, h, w]. In the case of 5-D inputs, grid[n, d, h, w] specifies the x, y, z pixel locations for interpolating output[n, :, d, h, w]. The mode argument selects the nearest, bilinear, or bicubic interpolation method used to sample the input pixels.
grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, most of its values should lie in the range [-1, 1]. For example, x = -1, y = -1 refers to the left-top pixel of input, and x = 1, y = 1 refers to the right-bottom pixel of input (to the pixel's center if align_corners=True, to its outer corner if align_corners=False).
If grid has values outside the range [-1, 1], the corresponding outputs are handled as defined by padding_mode. Options are:
- padding_mode="zeros": use 0 for out-of-bound grid locations,
- padding_mode="border": use border values for out-of-bound grid locations,
- padding_mode="reflection": use values at locations reflected by the border for out-of-bound grid locations. A location far away from the border keeps being reflected until it is in bound, e.g., the (normalized) pixel location x = -3.5 reflects by border -1 and becomes x' = 1.5, then reflects by border 1 and becomes x'' = 0.5.
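The reflection rule for padding_mode="reflection" can be checked numerically. Below is a minimal pure-Python sketch that works in normalized coordinates with borders -1 and 1, starting from the x = -3.5 example. The helper names reflect_steps and reflect_closed_form are hypothetical and only illustrate the arithmetic; this is not PyTorch's kernel, which (as an assumption here) performs the equivalent reflection on unnormalized pixel coordinates.

```python
from typing import List


def reflect_steps(x: float, low: float = -1.0, high: float = 1.0) -> List[float]:
    """Mirror x at whichever border it has crossed until it lies in [low, high].

    Returns the coordinate after every reflection, starting with x itself.
    """
    steps = [x]
    while x < low or x > high:
        # Reflecting at a border b maps x to 2 * b - x.
        x = 2 * low - x if x < low else 2 * high - x
        steps.append(x)
    return steps


def reflect_closed_form(x: float, low: float = -1.0, high: float = 1.0) -> float:
    """Same end result without looping: count the full border-to-border spans traversed."""
    span = high - low
    flips, extra = divmod(abs(x - low), span)
    # An even number of spans leaves us measuring up from `low`; an odd number, down from `high`.
    return low + extra if int(flips) % 2 == 0 else high - extra


print(reflect_steps(-3.5))        # [-3.5, 1.5, 0.5]
print(reflect_closed_form(-3.5))  # 0.5
```

Mirroring 1.5 at the border 1 gives 2 * 1 - 1.5 = 0.5, so the location ends up inside [-1, 1] after two reflections.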
Note
This function is often used in conjunction with affine_grid() to build Spatial Transformer Networks.
Note
When using the CUDA backend, this operation may induce nondeterministic behaviour in its backward pass that is not easily switched off. Please see the notes on Reproducibility for background.
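As a sketch of the affine_grid() pairing mentioned in the first note, the example below builds a sampling grid from an identity affine matrix and feeds it to grid_sample(). The tensor names and sizes are illustrative assumptions. Because both calls use the same align_corners value, the identity grid lands on pixel centers and the input should be reproduced up to floating-point error; the second grid shows that the output spatial size follows the grid's shape, not the input's.

```python
import torch
import torch.nn.functional as F

# Illustrative batch: 2 single-channel 4x5 images with distinct pixel values.
N, C, H, W = 2, 1, 4, 5
inp = torch.arange(N * C * H * W, dtype=torch.float32).reshape(N, C, H, W)

# One 2x3 affine matrix per batch element; the identity leaves coordinates unchanged.
theta = torch.tensor([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]]).repeat(N, 1, 1)

# affine_grid produces normalized (x, y) sampling locations of shape (N, H_out, W_out, 2).
grid = F.affine_grid(theta, [N, C, H, W], align_corners=False)

# Using the same align_corners value in both calls samples exactly at pixel centers.
out = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

# A grid requested at twice the resolution yields an (N, C, 2H, 2W) output.
grid_up = F.affine_grid(theta, [N, C, 2 * H, 2 * W], align_corners=False)
up = F.grid_sample(inp, grid_up, mode='bilinear', padding_mode='border', align_corners=False)

print(grid.shape, out.shape, up.shape)
print(torch.allclose(out, inp, atol=1e-4))
```

In a Spatial Transformer Network, theta would instead be predicted by a small localization network, and gradients flow back to it through both grid_sample() and affine_grid().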
Note
NaN values in grid would be interpreted as -1.
- Parameters
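input (Tensor) – input of shape $(N, C, H_\text{in}, W_\text{in})$ (4-D case) or $(N, C, D_\text{in}, H_\text{in}, W_\text{in})$ (5-D case)
grid (Tensor) – flow-field of shape $(N, H_\text{out}, W_\text{out}, 2)$ (4-D case) or $(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)$ (5-D case)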
mode (str) – interpolation mode to calculate output values: 'bilinear' | 'nearest' | 'bicubic'. Default: 'bilinear'. Note: mode='bicubic' supports only 4-D input. When mode='bilinear' and the input is 5-D, the interpolation mode used internally is actually trilinear; when the input is 4-D, it is genuinely bilinear.
padding_mode (str) – padding mode for outside grid values: 'zeros' | 'border' | 'reflection'. Default: 'zeros'
align_corners (bool, optional) – Geometrically, we consider the pixels of the input as squares rather than points. If set to True, the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. This option parallels the align_corners option in interpolate(), so whichever option is used here should also be used there to resize the input image before grid sampling. Default: False
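The geometric meaning of align_corners can be made concrete by mapping a normalized coordinate to a pixel coordinate, where pixel i is a unit square centered at i. The pure-Python sketch below assumes the two mappings implied by the description above (extrema at the corner pixels' centers versus at their outer corners); the function names unnormalize and relative_position are hypothetical helpers, not part of torch.

```python
def unnormalize(coord: float, size: int, align_corners: bool) -> float:
    """Map a normalized grid coordinate in [-1, 1] to a pixel coordinate.

    Pixel i is treated as a unit square whose center sits at coordinate i.
    """
    if align_corners:
        # -1 and 1 hit the centers of the first and last pixels (0 and size - 1).
        return (coord + 1) / 2 * (size - 1)
    # -1 and 1 hit the outer edges of the first and last pixels (-0.5 and size - 0.5).
    return ((coord + 1) * size - 1) / 2


def relative_position(coord: float, size: int, align_corners: bool) -> float:
    """Fraction of the image extent (0 = left edge, 1 = right edge) that coord samples."""
    return (unnormalize(coord, size, align_corners) + 0.5) / size


for ac in (True, False):
    # align_corners=True: [0.6875, 0.71875]; align_corners=False: [0.75, 0.75]
    print(ac, [relative_position(0.5, size, ac) for size in (4, 8)])
```

With align_corners=True, the same normalized coordinate x = 0.5 samples a different fraction of the image at widths 4 and 8, which is the resolution dependence described in the warning on align_corners; with align_corners=False it stays at 0.75 regardless of resolution.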
- Returns
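output Tensor of shape $(N, C, H_\text{out}, W_\text{out})$ (4-D case) or $(N, C, D_\text{out}, H_\text{out}, W_\text{out})$ (5-D case)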
- Return type
output (Tensor)
Warning
When align_corners=True, the grid positions depend on the pixel size relative to the input image size, and so the locations sampled by grid_sample() will differ for the same input given at different resolutions (that is, after being upsampled or downsampled). The default behavior up to version 1.2.0 was align_corners=True. Since then, the default behavior has been changed to align_corners=False, in order to bring it in line with the default for interpolate().
Note
mode='bicubic' is implemented using the cubic convolution algorithm with $\alpha = -0.75$. The constant $\alpha$ might differ from package to package. For example, PIL and OpenCV use -0.5 and -0.75 respectively. This algorithm may “overshoot” the range of values it is interpolating. For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. Clamp the results with torch.clamp() to ensure they are within the valid range.
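To see where the overshoot comes from, the pure-Python sketch below evaluates the cubic convolution kernel in one dimension, assuming $\alpha = -0.75$ (the value documented for mode='bicubic'). The helpers cubic_weight and cubic_interp_1d are hypothetical; a 2-D bicubic lookup is assumed to apply the same weights along x and then along y. The kernel goes negative for neighbours at distance between 1 and 2, so a plateau of 255 next to zeros is pushed above 255, and the mirror case dips below 0.

```python
def cubic_weight(distance: float, a: float = -0.75) -> float:
    """Cubic convolution kernel W(d) for a sample at distance d from the query point."""
    d = abs(distance)
    if d <= 1:
        return (a + 2) * d**3 - (a + 3) * d**2 + 1
    if d < 2:
        # Negative for 1 < d < 2 when a < 0: the source of the overshoot.
        return a * d**3 - 5 * a * d**2 + 8 * a * d - 4 * a
    return 0.0


def cubic_interp_1d(p0: float, p1: float, p2: float, p3: float,
                    t: float, a: float = -0.75) -> float:
    """Interpolate between p1 (t = 0) and p2 (t = 1) using four neighbouring samples."""
    weights = [cubic_weight(1 + t, a), cubic_weight(t, a),
               cubic_weight(1 - t, a), cubic_weight(2 - t, a)]
    return sum(p * w for p, w in zip((p0, p1, p2, p3), weights))


overshoot = cubic_interp_1d(0.0, 255.0, 255.0, 0.0, 0.5)    # 302.8125, above 255
undershoot = cubic_interp_1d(255.0, 0.0, 0.0, 255.0, 0.5)   # -47.8125, below 0

# Clamping to the valid range, as torch.clamp(x, 0, 255) would do for a tensor.
clamped = max(0.0, min(255.0, overshoot))
print(overshoot, undershoot, clamped)
```

Because the four weights always sum to 1, a constant signal is reproduced exactly; the overshoot only appears near sharp transitions.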