torch.functional.meshgrid
- torch.functional.meshgrid(*tensors, indexing=None)
Creates grids of coordinates specified by the 1D inputs in `tensors`.
This is helpful when you want to visualize data over some range of inputs. See below for a plotting example.
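Given $N$ 1D tensors $T_0 \ldots T_{N-1}$ as inputs with corresponding sizes $S_0 \ldots S_{N-1}$, this creates $N$ N-dimensional tensors $G_0 \ldots G_{N-1}$, each with shape $(S_0, \ldots, S_{N-1})$, where the output $G_i$ is constructed by expanding $T_i$ to the result shape.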
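The construction above can be sketched for `indexing='ij'` by reshaping each input so that its values run along its own dimension, then expanding it to the common shape. The helper `meshgrid_ij_sketch` is hypothetical and written only for illustration; it is a minimal sketch under that assumption, not PyTorch's implementation, and it checks itself against `torch.meshgrid`:

```python
import torch

def meshgrid_ij_sketch(*tensors):
    """Hypothetical helper: rebuild meshgrid(..., indexing='ij') by viewing
    each input T_i along dimension i and expanding it to (S_0, ..., S_{N-1})."""
    shape = tuple(t.numel() for t in tensors)  # (S_0, ..., S_{N-1})
    grids = []
    for i, t in enumerate(tensors):
        view_shape = [1] * len(tensors)
        view_shape[i] = -1  # T_i's values vary along dimension i only
        grids.append(t.reshape(view_shape).expand(shape))
    return tuple(grids)

a = torch.tensor([1, 2])
b = torch.tensor([10, 20, 30])
c = torch.tensor([100, 200, 300, 400])

ref = torch.meshgrid(a, b, c, indexing='ij')
mine = meshgrid_ij_sketch(a, b, c)
for g_ref, g_mine in zip(ref, mine):
    print(g_ref.shape, torch.equal(g_ref, g_mine))  # torch.Size([2, 3, 4]) True
```

Because `expand` returns a broadcast view rather than a copy, call `.clone()` on a grid produced by the sketch before modifying it in place.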
Note
0D inputs are treated equivalently to 1D inputs of a single element.
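A quick check of this note, comparing a 0D tensor with a one-element 1D tensor holding the same value (the values are arbitrary):

```python
import torch

x = torch.tensor([1, 2, 3])
s = torch.tensor(7)      # 0D (scalar) tensor
s1 = torch.tensor([7])   # 1D tensor with a single element

gx0, gs0 = torch.meshgrid(x, s, indexing='ij')
gx1, gs1 = torch.meshgrid(x, s1, indexing='ij')

print(gs0.shape)                                     # torch.Size([3, 1])
print(torch.equal(gx0, gx1), torch.equal(gs0, gs1))  # True True
```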
Warning
torch.meshgrid(*tensors) currently has the same behavior as calling numpy.meshgrid(*arrays, indexing='ij').
In the future, torch.meshgrid will transition to indexing='xy' as the default.
pytorch/pytorch#50276 tracks this issue with the goal of migrating to NumPy’s behavior.
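Passing `indexing` explicitly makes a call independent of whichever default is in effect. The following sketch checks the NumPy correspondence stated above for both modes; it assumes NumPy is installed alongside PyTorch, and the input values are arbitrary:

```python
import numpy as np
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5])

# Explicit indexing='ij' matches numpy.meshgrid(..., indexing='ij').
ti, tj = torch.meshgrid(x, y, indexing='ij')
ni, nj = np.meshgrid(x.numpy(), y.numpy(), indexing='ij')

# indexing='xy' matches NumPy's default indexing.
txi, txj = torch.meshgrid(x, y, indexing='xy')
nxi, nxj = np.meshgrid(x.numpy(), y.numpy())

print(ti.shape, txi.shape)  # torch.Size([3, 2]) torch.Size([2, 3])
print(np.array_equal(ti.numpy(), ni), np.array_equal(tj.numpy(), nj))      # True True
print(np.array_equal(txi.numpy(), nxi), np.array_equal(txj.numpy(), nxj))  # True True
```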
See also
torch.cartesian_prod() has the same effect but it collects the data in a tensor of vectors.
- Parameters
tensors (list of Tensor) – list of scalars or 1-dimensional tensors. Scalars will be treated as tensors of size $(1,)$ automatically.
indexing (str, optional) – the indexing mode, either “xy” or “ij”, defaults to “ij”. See warning for future changes.
If “xy” is selected, the first dimension corresponds to the cardinality of the second input and the second dimension corresponds to the cardinality of the first input.
If “ij” is selected, the dimensions are in the same order as the cardinality of the inputs.
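The two modes differ only in the order of the first two output dimensions. A sketch with three inputs of distinct lengths makes the swap visible (the sizes and values are arbitrary):

```python
import torch

x = torch.tensor([1, 2, 3])     # 3 elements
y = torch.tensor([4, 5])        # 2 elements
z = torch.tensor([6, 7, 8, 9])  # 4 elements

ij = torch.meshgrid(x, y, z, indexing='ij')
xy = torch.meshgrid(x, y, z, indexing='xy')

print(ij[0].shape)  # torch.Size([3, 2, 4]): dimensions follow input order
print(xy[0].shape)  # torch.Size([2, 3, 4]): first two dimensions swapped

# Each 'xy' grid equals the matching 'ij' grid with dimensions 0 and 1 swapped.
print(all(torch.equal(g_ij.transpose(0, 1), g_xy) for g_ij, g_xy in zip(ij, xy)))  # True
```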
- Returns
If the input has $N$ tensors of size $S_0 \ldots S_{N-1}$, then the output will also have $N$ tensors, where each tensor is of shape $(S_0, \ldots, S_{N-1})$.
- Return type
seq (sequence of Tensors)
Example:
>>> import torch
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([4, 5, 6])

Observe the element-wise pairings across the grid, (1, 4),
(1, 5), ..., (3, 6). This is the same thing as the
cartesian product.
>>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
>>> grid_x
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
>>> grid_y
tensor([[4, 5, 6],
        [4, 5, 6],
        [4, 5, 6]])

This correspondence can be seen when these grids are
stacked properly.
>>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
...             torch.cartesian_prod(x, y))
True

`torch.meshgrid` is commonly used to produce a grid for
plotting.
>>> import matplotlib.pyplot as plt
>>> xs = torch.linspace(-5, 5, steps=100)
>>> ys = torch.linspace(-5, 5, steps=100)
>>> x, y = torch.meshgrid(xs, ys, indexing='xy')
>>> z = torch.sin(torch.sqrt(x * x + y * y))
>>> ax = plt.axes(projection='3d')
>>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
>>> plt.show()
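The plotting example relies on the grids to evaluate a function at every point without Python loops. A minimal sketch of that idea, using the arbitrary function f(x, y) = x * exp(-y) chosen only for illustration, compares the vectorized result with an explicit double loop:

```python
import math
import torch

xs = torch.linspace(-1.0, 1.0, steps=5)
ys = torch.linspace(0.0, 2.0, steps=3)

# Vectorized: evaluate f(x, y) = x * exp(-y) at every grid point at once.
gx, gy = torch.meshgrid(xs, ys, indexing='ij')
z_grid = gx * torch.exp(-gy)

# Reference: the same values computed point by point in a double loop.
z_loop = torch.empty(len(xs), len(ys))
for i in range(len(xs)):
    for j in range(len(ys)):
        z_loop[i, j] = xs[i].item() * math.exp(-ys[j].item())

print(z_grid.shape)                    # torch.Size([5, 3])
print(torch.allclose(z_grid, z_loop))  # True
```

With `indexing='ij'`, `z_grid[i, j]` holds f(xs[i], ys[j]).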
