torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)

Sample from the Gumbel-Softmax distribution (Link 1, Link 2) and optionally discretize.

Parameters
  • logits (Tensor) – [KEEP_ELLIPSIS, num_features] unnormalized log probabilities

  • tau (float) – non-negative scalar temperature

  • hard (bool) – if True, the returned samples will be discretized as one-hot vectors, but will be differentiated as if they were the soft samples in autograd

  • dim (int) – A dimension along which softmax will be computed. Default: -1.
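To illustrate how `logits`, `tau`, and `dim` interact, here is a minimal sketch of the standard Gumbel-Softmax relaxation: add Gumbel(0, 1) noise to the logits, divide by the temperature, and apply softmax along `dim`. The helper name `gumbel_softmax_sketch` and its `gumbels` argument are hypothetical and used only for illustration; this is not presented as the library's exact implementation.

```python
import torch


def gumbel_softmax_sketch(logits, tau=1.0, dim=-1, gumbels=None):
    """Hypothetical helper: soft Gumbel-Softmax sample (illustration only)."""
    if gumbels is None:
        # Draw i.i.d. Gumbel(0, 1) noise with the same shape as the logits.
        gumbels = torch.distributions.Gumbel(0.0, 1.0).sample(logits.shape)
    # Perturb the logits, scale by the temperature, normalize along `dim`.
    return torch.softmax((logits + gumbels) / tau, dim=dim)


torch.manual_seed(0)
logits = torch.randn(4, 6)

# Reuse the same noise so that only the temperature differs between samples.
noise = torch.distributions.Gumbel(0.0, 1.0).sample(logits.shape)
peaked = gumbel_softmax_sketch(logits, tau=0.1, gumbels=noise)
flat = gumbel_softmax_sketch(logits, tau=5.0, gumbels=noise)

# Lower temperatures concentrate mass on the largest perturbed logit.
print(peaked.max(dim=-1).values)
print(flat.max(dim=-1).values)
```

With the noise held fixed, a smaller `tau` pushes each sample toward a one-hot vector, while a larger `tau` pushes it toward a uniform distribution.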

Returns

Sampled tensor of same shape as logits from the Gumbel-Softmax distribution. If hard=True, the returned samples will be one-hot, otherwise they will be probability distributions that sum to 1 across dim.
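The properties described under Returns can be checked directly. The following sketch is illustrative only; the tensor sizes and the seed are arbitrary choices, and comparisons use a tolerance because the results are floating-point values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)

soft = F.gumbel_softmax(logits, tau=1.0, hard=False, dim=-1)
hard = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)

# Both results keep the shape of the input.
print(soft.shape, hard.shape)

# Soft samples are probability distributions along `dim`.
print(torch.allclose(soft.sum(dim=-1), torch.ones(8)))

# Hard samples have a single active entry per row.
print(((hard > 0.5).sum(dim=-1) == 1).all())
```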

Return type

Tensor

Note

This function is here for legacy reasons and may be removed from nn.functional in the future.

Note

The main trick for hard is to do y_hard - y_soft.detach() + y_soft

It achieves two things:

  • makes the output value exactly one-hot (since we add then subtract y_soft value)

  • makes the gradient equal to y_soft gradient (since we strip all other gradients)
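The two effects listed above can be observed in a small sketch. Here `y_soft` is a plain softmax standing in for the soft sample (inside `gumbel_softmax` it would also include Gumbel noise), and `weights` is an arbitrary vector chosen only to produce a non-trivial gradient; both are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 5, requires_grad=True)

# Stand-in soft sample (assumption: no Gumbel noise, to keep the sketch short).
y_soft = F.softmax(logits, dim=-1)

# Discretize: one-hot at the argmax of each row. No gradient flows through this.
index = y_soft.argmax(dim=-1, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)

# Straight-through: the forward value matches y_hard (up to floating-point
# rounding), while the only differentiable path is through y_soft.
y_st = y_hard - y_soft.detach() + y_soft

weights = torch.arange(5.0)  # arbitrary downstream weighting
(grad_st,) = torch.autograd.grad((y_st * weights).sum(), logits, retain_graph=True)
(grad_soft,) = torch.autograd.grad((y_soft * weights).sum(), logits)

print(torch.allclose(y_st, y_hard))        # forward pass behaves like y_hard
print(torch.allclose(grad_st, grad_soft))  # backward pass behaves like y_soft
```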

Examples::
    >>> import torch
    >>> import torch.nn.functional as F
    >>> logits = torch.randn(20, 32)
    >>> # Sample soft categorical using reparametrization trick:
    >>> F.gumbel_softmax(logits, tau=1, hard=False)
    >>> # Sample hard categorical using "Straight-through" trick:
    >>> F.gumbel_softmax(logits, tau=1, hard=True)