torch.masked

Created On: Aug 15, 2022 | Last Updated On: Jun 17, 2025

Introduction

Motivation

Warning

The PyTorch API of masked tensors is in the prototype stage and may change in the future.

MaskedTensor serves as an extension to torch.Tensor that provides the user with the ability to:

  • use any masked semantics (e.g. variable length tensors, nan* operators, etc.)

  • differentiate between 0 and NaN gradients

  • support various sparse applications (see tutorial below)

“Specified” and “unspecified” have a long history in PyTorch without formal semantics and certainly without consistency; indeed, MaskedTensor was born out of a build-up of issues that the vanilla torch.Tensor class could not properly address. Thus, a primary goal of MaskedTensor is to become the source of truth for said “specified” and “unspecified” values in PyTorch, where they are a first-class citizen instead of an afterthought. In turn, this should further unlock sparsity’s potential, enable safer and more consistent operators, and provide a smoother and more intuitive experience for users and developers alike.

What is a MaskedTensor?

A MaskedTensor is a tensor subclass that consists of 1) an input (data), and 2) a mask. The mask tells us which entries from the input should be included or ignored.

By way of example, suppose that we wanted to mask out all values that are equal to 0 (represented by the gray) and take the max:

[Image: tensor_comparison.jpg — taking the max of a vanilla tensor (top) vs. a MaskedTensor with the 0’s masked out (bottom)]

On top is the vanilla tensor example, while the bottom is the MaskedTensor where all the 0’s are masked out. This clearly yields a different result depending on whether we have the mask, but this flexible structure allows the user to systematically ignore any elements they’d like during computation.

We’ve written a number of tutorials to help users onboard, such as the MaskedTensor Overview and Advanced Semantics tutorials referenced below.

Supported Operators

Unary Operators

Unary operators are operators that contain only a single input. Applying them to MaskedTensors is relatively straightforward: if the data is not masked out at a given index, we apply the operator; otherwise, the result remains masked out.
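A minimal sketch of this behavior, assuming the prototype torch.masked API:

```python
import torch
from torch.masked import masked_tensor

# The middle element is masked out; abs() applies only to the other entries.
mt = masked_tensor(torch.tensor([-1.0, -2.0, 3.0]),
                   torch.tensor([True, False, True]))
result = torch.abs(mt)  # the masked entry stays masked in the result
print(result)
```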

The available unary operators are:

abs

Computes the absolute value of each element in input.

absolute

Alias for torch.abs().

acos

Returns a new tensor with the arccosine (in radians) of each element in input.

arccos

Alias for torch.acos().

acosh

Returns a new tensor with the inverse hyperbolic cosine of the elements of input.

arccosh

Alias for torch.acosh().

angle

Computes the element-wise angle (in radians) of the given input tensor.

asin

Returns a new tensor with the arcsine of the elements (in radians) in the input tensor.

arcsin

Alias for torch.asin().

asinh

Returns a new tensor with the inverse hyperbolic sine of the elements of input.

arcsinh

Alias for torch.asinh().

atan

Returns a new tensor with the arctangent of the elements (in radians) in the input tensor.

arctan

Alias for torch.atan().

atanh

Returns a new tensor with the inverse hyperbolic tangent of the elements of input.

arctanh

Alias for torch.atanh().

bitwise_not

Computes the bitwise NOT of the given input tensor.

ceil

Returns a new tensor with the ceil of the elements of input, the smallest integer greater than or equal to each element.

clamp

Clamps all elements in input into the range [min, max].

clip

Alias for torch.clamp().

conj_physical

Computes the element-wise conjugate of the given input tensor.

cos

Returns a new tensor with the cosine of the elements of input given in radians.

cosh

Returns a new tensor with the hyperbolic cosine of the elements of input.

deg2rad

Returns a new tensor with each of the elements of input converted from angles in degrees to radians.

digamma

Alias for torch.special.digamma().

erf

Alias for torch.special.erf().

erfc

Alias for torch.special.erfc().

erfinv

Alias for torch.special.erfinv().

exp

Returns a new tensor with the exponential of the elements of the input tensor input.

exp2

Alias for torch.special.exp2().

expm1

Alias for torch.special.expm1().

fix

Alias for torch.trunc().

floor

Returns a new tensor with the floor of the elements of input, the largest integer less than or equal to each element.

frac

Computes the fractional portion of each element in input.

lgamma

Computes the natural logarithm of the absolute value of the gamma function on input.

log

Returns a new tensor with the natural logarithm of the elements of input.

log10

Returns a new tensor with the logarithm to the base 10 of the elements of input.

log1p

Returns a new tensor with the natural logarithm of (1 + input).

log2

Returns a new tensor with the logarithm to the base 2 of the elements of input.

logit

Alias for torch.special.logit().

i0

Alias for torch.special.i0().

isnan

Returns a new tensor with boolean elements representing if each element of input is NaN or not.

nan_to_num

Replaces NaN, positive infinity, and negative infinity values in input with the values specified by nan, posinf, and neginf, respectively.

neg

Returns a new tensor with the negative of the elements of input.

negative

Alias for torch.neg().

positive

Returns input.

pow

Takes the power of each element in input with exponent and returns a tensor with the result.

rad2deg

Returns a new tensor with each of the elements of input converted from angles in radians to degrees.

reciprocal

Returns a new tensor with the reciprocal of the elements of input.

round

Rounds elements of input to the nearest integer.

rsqrt

Returns a new tensor with the reciprocal of the square-root of each of the elements of input.

sigmoid

Alias for torch.special.expit().

sign

Returns a new tensor with the signs of the elements of input.

sgn

This function is an extension of torch.sign() to complex tensors.

signbit

Tests if each element of input has its sign bit set or not.

sin

Returns a new tensor with the sine of the elements in the input tensor, where each value in this input tensor is in radians.

sinc

Alias for torch.special.sinc().

sinh

Returns a new tensor with the hyperbolic sine of the elements of input.

sqrt

Returns a new tensor with the square-root of the elements of input.

square

Returns a new tensor with the square of the elements of input.

tan

Returns a new tensor with the tangent of the elements in the input tensor, where each value in this input tensor is in radians.

tanh

Returns a new tensor with the hyperbolic tangent of the elements of input.

trunc

Returns a new tensor with the truncated integer values of the elements of input.

The available in-place unary operators are all of the above except:

angle

Computes the element-wise angle (in radians) of the given input tensor.

positive

Returns input.

signbit

Tests if each element of input has its sign bit set or not.

isnan

Returns a new tensor with boolean elements representing if each element of input is NaN or not.

Binary Operators

As you may have seen in the tutorial, MaskedTensor also has binary operations implemented, with the caveat that the masks of the two MaskedTensors must match or else an error will be raised. As noted in the error, if you need support for a particular operator or have proposed semantics for how it should behave instead, please open an issue on GitHub. For now, we have decided to go with the most conservative implementation to ensure that users know exactly what is going on and are being intentional about their decisions with masked semantics.
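A short sketch of the mask-matching rule (assuming the prototype torch.masked API; the exact exception type raised on mismatch is not specified here, so we catch broadly):

```python
import torch
from torch.masked import masked_tensor

a = masked_tensor(torch.tensor([1.0, 2.0]), torch.tensor([True, False]))
b = masked_tensor(torch.tensor([3.0, 4.0]), torch.tensor([True, False]))
print(torch.add(a, b))  # masks match, so the addition succeeds

c = masked_tensor(torch.tensor([3.0, 4.0]), torch.tensor([False, True]))
try:
    torch.add(a, c)     # mismatched masks are rejected
except Exception as e:
    print("masks must match:", e)
```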

The available binary operators are:

add

Adds other, scaled by alpha, to input.

atan2

Element-wise arctangent of input_i / other_i with consideration of the quadrant.

arctan2

Alias for torch.atan2().

bitwise_and

Computes the bitwise AND of input and other.

bitwise_or

Computes the bitwise OR of input and other.

bitwise_xor

Computes the bitwise XOR of input and other.

bitwise_left_shift

Computes the left arithmetic shift of input by other bits.

bitwise_right_shift

Computes the right arithmetic shift of input by other bits.

div

Divides each element of the input input by the corresponding element of other.

divide

Alias for torch.div().

floor_divide

Computes input divided by other, elementwise, and floors the result.

fmod

Applies C++'s std::fmod entrywise.

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

mul

Multiplies input by other.

multiply

Alias for torch.mul().

nextafter

Return the next floating-point value after input towards other, elementwise.

remainder

Computes Python's modulus operation entrywise.

sub

Subtracts other, scaled by alpha, from input.

subtract

Alias for torch.sub().

true_divide

Alias for torch.div() with rounding_mode=None.

eq

Computes element-wise equality.

ne

Computes input ≠ other element-wise.

le

Computes input ≤ other element-wise.

ge

Computes input ≥ other element-wise.

greater

Alias for torch.gt().

greater_equal

Alias for torch.ge().

gt

Computes input > other element-wise.

less_equal

Alias for torch.le().

lt

Computes input < other element-wise.

less

Alias for torch.lt().

maximum

Computes the element-wise maximum of input and other.

minimum

Computes the element-wise minimum of input and other.

fmax

Computes the element-wise maximum of input and other.

fmin

Computes the element-wise minimum of input and other.

not_equal

Alias for torch.ne().

The available in-place binary operators are all of the above except:

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

equal

True if two tensors have the same size and elements, False otherwise.

fmin

Computes the element-wise minimum of input and other.

minimum

Computes the element-wise minimum of input and other.

fmax

Computes the element-wise maximum of input and other.

Reductions

The following reductions are available (with autograd support). For more information, the Overview tutorial details some examples of reductions, while the Advanced semantics tutorial has some further in-depth discussion about how we decided on certain reduction semantics.

sum

Returns the sum of all elements in the input tensor.

mean

Returns the mean value of all elements in the input tensor.

amin

Returns the minimum value of each slice of the input tensor in the given dimension(s) dim.

amax

Returns the maximum value of each slice of the input tensor in the given dimension(s) dim.

argmin

Returns the indices of the minimum value(s) of the flattened tensor or along a dimension.

argmax

Returns the indices of the maximum value of all elements in the input tensor.

prod

Returns the product of all elements in the input tensor.

all

Tests if all elements in input evaluate to True.

norm

Returns the matrix norm or vector norm of a given tensor.

var

Calculates the variance over the dimensions specified by dim.

std

Calculates the standard deviation over the dimensions specified by dim.
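For instance, a masked sum only aggregates the unmasked entries (a sketch assuming the prototype torch.masked API):

```python
import torch
from torch.masked import masked_tensor

data = torch.arange(6, dtype=torch.float).reshape(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, True]])
mt = masked_tensor(data, mask)

print(mt.sum())        # aggregates only the unmasked entries: 0 + 2 + 4 + 5
print(mt.sum(dim=1))   # per-row sums over unmasked entries
```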

View and select functions

We’ve included a number of view and select functions as well; intuitively, these operators will apply to both the data and the mask and then wrap the result in a MaskedTensor. For a quick example, consider select():

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True],
...                      [False, True, False, False],
...                      [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor([--, 5.0000, --, --])

The following ops are currently supported:

atleast_1d

Returns a 1-dimensional view of each input tensor with zero dimensions.

broadcast_tensors

Broadcasts the given tensors according to broadcasting semantics.

broadcast_to

Broadcasts input to the shape shape.

cat

Concatenates the given sequence of tensors in tensors in the given dimension.

chunk

Attempts to split a tensor into the specified number of chunks.

column_stack

Creates a new tensor by horizontally stacking the tensors in tensors.

dsplit

Splits input, a tensor with three or more dimensions, into multiple tensors depthwise according to indices_or_sections.

flatten

Flattens input by reshaping it into a one-dimensional tensor.

hsplit

Splits input, a tensor with one or more dimensions, into multiple tensors horizontally according to indices_or_sections.

hstack

Stack tensors in sequence horizontally (column wise).

kron

Computes the Kronecker product, denoted by ⊗, of input and other.

meshgrid

Creates grids of coordinates specified by the 1D inputs in tensors.

narrow

Returns a new tensor that is a narrowed version of input tensor.

nn.functional.unfold

Extract sliding local blocks from a batched input tensor.

ravel

Return a contiguous flattened tensor.

select

Slices the input tensor along the selected dimension at the given index.

split

Splits the tensor into chunks.

stack

Concatenates a sequence of tensors along a new dimension.

t

Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.

transpose

Returns a tensor that is a transposed version of input.

vsplit

Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections.

vstack

Stack tensors in sequence vertically (row wise).

Tensor.expand

Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

Tensor.expand_as

Expand this tensor to the same size as other.

Tensor.reshape

Returns a tensor with the same data and number of elements as self but with the specified shape.

Tensor.reshape_as

Returns this tensor as the same shape as other.

Tensor.unfold

Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.

Tensor.view

Returns a new tensor with the same data as the self tensor but of a different shape.
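As with select() above, these view functions operate on the data and the mask in lockstep; a brief sketch with reshape (assuming the prototype torch.masked API):

```python
import torch
from torch.masked import masked_tensor

mt = masked_tensor(torch.arange(6, dtype=torch.float),
                   torch.tensor([True, False, True, True, False, True]))
reshaped = mt.reshape(2, 3)  # data and mask are both reshaped to (2, 3)
print(reshaped)
```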