
torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor

Returns a tensor where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of tensor input. (A stricter definition would call this distribution multivariate; refer to torch.distributions.multinomial.Multinomial for more details.)
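To make the distinction with torch.distributions.multinomial.Multinomial concrete: torch.multinomial returns the index chosen by each draw, whereas a Multinomial sample is a vector of per-category counts. The following is a minimal sketch with illustrative weights (not taken from this page) showing that counting the returned indices with torch.bincount gives a vector of the same form as a Multinomial sample.

```python
import torch

# Illustrative, unnormalized weights for four categories.
weights = torch.tensor([1.0, 2.0, 3.0, 4.0])
n = 10

# torch.multinomial: one category index per draw.
indices = torch.multinomial(weights, n, replacement=True)

# Multinomial.sample(): how many of the n draws landed in each category.
counts = torch.distributions.Multinomial(
    total_count=n, probs=weights / weights.sum()
).sample()

# Counting the indices yields a count vector of the same shape and total.
index_counts = torch.bincount(indices, minlength=weights.numel())

print(indices.shape, counts.shape, index_counts.shape)
```

Both count vectors have one entry per category and sum to n; they differ only because they come from independent random draws.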

Note

The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum.
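A quick empirical check of the note above, assuming a fixed seed and illustrative weights: when sampling with replacement, the frequency of each index should approach that index's weight divided by the row sum.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

# Non-negative, finite weights with a non-zero sum; they do not sum to one.
weights = torch.tensor([0.0, 1.0, 3.0])
num_draws = 100_000

samples = torch.multinomial(weights, num_draws, replacement=True)

# Empirical frequency of each index versus the normalized weights.
freq = torch.bincount(samples, minlength=weights.numel()).float() / num_draws
expected = weights / weights.sum()  # tensor([0.0000, 0.2500, 0.7500])
print(freq, expected)
```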

Indices are ordered from left to right according to when each was sampled (first samples are placed in first column).
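One way to observe this ordering, sketched here with illustrative weights: draw two indices without replacement from many identical, heavily skewed rows. Because column 0 holds the first draw, which is taken from the full distribution, it should be the heavy index in roughly 1000/1001 of the rows.

```python
import torch

torch.manual_seed(0)

# 10,000 identical rows; index 0 carries almost all of the weight.
weights = torch.tensor([1000.0, 1.0]).repeat(10_000, 1)

# Two draws per row without replacement:
# column 0 is the first draw, column 1 the second.
samples = torch.multinomial(weights, 2, replacement=False)

# Fraction of rows whose first draw was index 0.
first_is_zero = (samples[:, 0] == 0).float().mean().item()
print(first_is_zero)
```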

If input is a vector, out is a vector of size num_samples.

If input is a matrix with m rows, out is a matrix of shape (m × num_samples).
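The two shape rules above can be checked directly. This sketch uses illustrative inputs: a single distribution as a 1-D tensor, and m = 4 distributions stored as the rows of a 2-D tensor.

```python
import torch

vec = torch.tensor([0.2, 0.5, 0.3])  # 1-D input: one distribution
mat = torch.rand(4, 6) + 0.1         # 2-D input: m = 4 rows, all weights > 0

out_vec = torch.multinomial(vec, 3, replacement=True)
out_mat = torch.multinomial(mat, 3, replacement=True)

print(out_vec.shape)  # torch.Size([3])
print(out_mat.shape)  # torch.Size([4, 3])
```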

If replacement is True, samples are drawn with replacement.

If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
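A minimal sketch of the without-replacement guarantee, using illustrative random weights: no index repeats within a row, and drawing every category yields a permutation of all indices in each row.

```python
import torch

weights = torch.rand(5, 8) + 0.1  # 5 rows, 8 categories, all weights positive

# Draw 5 of 8 categories per row: no index appears twice in a row.
partial = torch.multinomial(weights, 5, replacement=False)
no_repeats = all(len(set(row.tolist())) == 5 for row in partial)

# Draw all 8 categories: each row is a permutation of 0..7.
full = torch.multinomial(weights, 8, replacement=False)
sorted_rows, _ = full.sort(dim=1)
is_permutation = torch.equal(sorted_rows, torch.arange(8).expand(5, 8))

print(no_repeats, is_permutation)  # True True
```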

Note

When drawing without replacement, num_samples must be no greater than the number of non-zero elements in input (or the minimum number of non-zero elements in any row of input if it is a matrix).
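The bound in the note can be computed with torch.count_nonzero. The helper below, max_samples_without_replacement, is hypothetical (it is not part of PyTorch); it returns the minimum number of non-zero entries across rows. The weights are illustrative. The first row matches the page's example, which successfully samples 2 indices from a row with 2 non-zero weights.

```python
import torch

def max_samples_without_replacement(input: torch.Tensor) -> int:
    """Hypothetical helper: the minimum count of non-zero entries per row,
    i.e. the bound on num_samples described in the note for replacement=False."""
    nonzero_per_row = torch.count_nonzero(input, dim=-1)
    return int(nonzero_per_row.min().item())

weights = torch.tensor([[0.0, 10.0, 3.0, 0.0],   # 2 non-zero entries
                        [1.0, 1.0, 1.0, 0.0]])   # 3 non-zero entries
k = max_samples_without_replacement(weights)
print(k)  # 2

samples = torch.multinomial(weights, k, replacement=False)
print(samples.shape)  # torch.Size([2, 2])
```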

Parameters
  • input (Tensor) – the input tensor containing probabilities

  • num_samples (int) – number of samples to draw

  • replacement (bool, optional) – whether to draw with replacement or not. Default: False

Keyword Arguments
  • generator (torch.Generator, optional) – a pseudorandom number generator for sampling

  • out (Tensor, optional) – the output tensor.
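A short sketch of the two keyword arguments, assuming the page's example weights and an arbitrary seed of 42. Generators seeded identically produce identical draws, and out writes the result into a preallocated LongTensor instead of allocating a new one.

```python
import torch

weights = torch.tensor([0.0, 10.0, 3.0, 0.0])

# Two generators with the same seed give the same samples.
g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)
a = torch.multinomial(weights, 4, replacement=True, generator=g1)
b = torch.multinomial(weights, 4, replacement=True, generator=g2)
print(torch.equal(a, b))  # True

# Write into a preallocated LongTensor via `out`, again with seed 42.
buf = torch.empty(4, dtype=torch.long)
torch.multinomial(weights, 4, replacement=True,
                  generator=torch.Generator().manual_seed(42), out=buf)
print(torch.equal(buf, a))  # True
```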

Example:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)  # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5)  # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])