torch.multinomial
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
Returns a tensor where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of tensor input. (A stricter name would be multivariate; refer to torch.distributions.multinomial.Multinomial for more details.)

Note
The rows of input do not need to sum to one (in which case the values are used as weights), but they must be non-negative, finite, and have a non-zero sum.

Indices are ordered from left to right according to when each was sampled (the first sample is placed in the first column).
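As a rough illustration of what "used as weights" means, the helper below (row_probabilities, a hypothetical name, not part of PyTorch) applies the documented checks to a single row and then divides by the row sum. It is a plain-Python sketch of the semantics, with a list standing in for a tensor row, not PyTorch's implementation.

```python
import math

def row_probabilities(weights):
    """Hypothetical helper: validate one row of weights and normalize it."""
    # Documented requirement: every entry is non-negative and finite ...
    if any(w < 0 or not math.isfinite(w) for w in weights):
        raise ValueError("weights must be non-negative and finite")
    total = sum(weights)
    # ... and the row has a non-zero sum.
    if total == 0:
        raise ValueError("weights must have a non-zero sum")
    # Dividing by the row sum turns arbitrary weights into probabilities.
    return [w / total for w in weights]

probs = row_probabilities([0, 10, 3, 0])
# probs == [0.0, 10/13, 3/13, 0.0]
```

So the weights [0, 10, 3, 0] behave like probabilities of 0, 10/13, 3/13 and 0 for indices 0 through 3.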
If input is a vector, out is a vector of size num_samples.

If input is a matrix with m rows, out is a matrix of shape (m × num_samples).

If replacement is True, samples are drawn with replacement.

If not, they are drawn without replacement: once an index has been drawn for a row, it cannot be drawn again for that row.
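The following plain-Python sketch models the documented behavior under stated assumptions: multinomial_sketch and sample_row are hypothetical helpers, Python lists stand in for tensors, and random.Random stands in for torch.Generator. Sampling without replacement is modeled as successive draws in which an already-drawn index gets weight zero; this illustrates the semantics and is not how PyTorch implements the operation.

```python
import random

def sample_row(weights, num_samples, replacement, rng):
    """Hypothetical helper: draw num_samples indices from one row of weights."""
    weights = list(weights)  # copy, so zeroing entries does not touch the caller's data
    nonzero = sum(1 for w in weights if w > 0)
    if not replacement and num_samples > nonzero:
        raise ValueError("not enough non-zero weights to sample without replacement")
    out = []
    for _ in range(num_samples):
        # random.choices picks one index with probability weight / sum(weights).
        idx = rng.choices(range(len(weights)), weights=weights)[0]
        out.append(idx)  # indices are recorded in the order they are drawn
        if not replacement:
            weights[idx] = 0  # a drawn index cannot be drawn again for this row
    return out

def multinomial_sketch(inp, num_samples, replacement=False, rng=None):
    """Vector in -> list of num_samples; matrix (list of m rows) -> m x num_samples."""
    rng = rng if rng is not None else random.Random()
    if inp and isinstance(inp[0], (list, tuple)):
        # Each row is sampled independently from its own distribution.
        return [sample_row(row, num_samples, replacement, rng) for row in inp]
    return sample_row(inp, num_samples, replacement, rng)

rng = random.Random(0)
vector_out = multinomial_sketch([0, 10, 3, 0], 2, rng=rng)
matrix_out = multinomial_sketch([[0, 10, 3, 0], [1, 1, 1, 1]], 2, rng=rng)
with_repl = multinomial_sketch([0, 10, 3, 0], 4, replacement=True, rng=rng)
```

With input [0, 10, 3, 0] and two draws without replacement, the sketch can only return indices 1 and 2, in either order, and asking for a third draw without replacement fails because only two weights are non-zero.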
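Note
When drawing without replacement, num_samples must not exceed the number of non-zero elements in input (or the minimum number of non-zero elements across the rows of input, if it is a matrix).
- Parameters
input (Tensor) – the input tensor containing probabilities
num_samples (int) – number of samples to draw
replacement (bool, optional) – whether to draw with replacement or not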
- Keyword Arguments
generator (torch.Generator, optional) – a pseudorandom number generator for sampling
out (Tensor, optional) – the output tensor.
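Passing a seeded generator makes the draws reproducible. In PyTorch this would look something like generator=torch.Generator().manual_seed(0); the sketch below shows the same idea with Python's random.Random as a stand-in, and draw is a hypothetical helper that samples with replacement. It does not call PyTorch.

```python
import random

def draw(weights, num_samples, generator):
    """Hypothetical helper: num_samples draws with replacement, proportional to weights."""
    return generator.choices(range(len(weights)), weights=weights, k=num_samples)

weights = [0, 10, 3, 0]
first = draw(weights, 8, random.Random(1234))
second = draw(weights, 8, random.Random(1234))
# Two generators seeded identically produce the same sequence of indices.
print(first == second)
```

Reusing a single generator across calls advances its state, so consecutive calls generally return different samples while the overall sequence stays reproducible from the seed.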
Example:
>>> import torch
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)  # create a tensor of weights
>>> torch.multinomial(weights, 2)  # only indices 1 and 2 have non-zero weight; their order is random
tensor([1, 2])
>>> torch.multinomial(weights, 5)  # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])
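To see why index 1 dominates the replacement=True output above, the sketch below estimates draw frequencies with Python's random.choices as a stand-in for torch.multinomial (an assumption for illustration; it does not call PyTorch). Each index should appear with a frequency close to its weight divided by 13, and the zero-weight indices 0 and 3 should never appear.

```python
import random
from collections import Counter

weights = [0, 10, 3, 0]
rng = random.Random(0)
n = 100_000
# Draw with replacement, analogous to torch.multinomial(weights, n, replacement=True).
draws = rng.choices(range(len(weights)), weights=weights, k=n)
freq = Counter(draws)
for idx, w in enumerate(weights):
    # Empirical frequency next to the expected probability w / sum(weights).
    print(idx, freq[idx] / n, w / sum(weights))
```

Index 1 has an expected share of about 0.77, which is why a short draw such as tensor([ 2,  1,  1,  1]) is mostly 1s.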