CrossEntropyLoss
- class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
This criterion computes the cross entropy loss between input logits and target.
It is useful when training a classification problem with C classes. If provided, the optional argument weight should be a 1D Tensor assigning a weight to each of the classes. This is particularly useful when you have an unbalanced training set.
The input is expected to contain the unnormalized logits for each class (which do not need to be positive or sum to 1, in general). input has to be a Tensor of size $(C)$ for unbatched input, $(minibatch, C)$ or $(minibatch, C, d_1, d_2, \dots, d_K)$ with $K \geq 1$ for the K-dimensional case. The last is useful for higher-dimensional inputs, such as computing cross entropy loss per-pixel for 2D images.
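As a rough illustration of the two points above (per-class weights and K-dimensional inputs), the following sketch uses made-up weight values and tensor sizes; none of the numbers come from the documentation itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical imbalanced 3-class problem: class 2 is rare, so it gets a larger weight.
class_weights = torch.tensor([1.0, 1.0, 5.0])  # 1D tensor of size C
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Batched input: (minibatch, C) logits with (minibatch,) class indices.
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])
batched_loss = criterion(logits, labels)

# K-dimensional input (K = 2): per-pixel logits of shape (minibatch, C, H, W)
# with a per-pixel label map of shape (minibatch, H, W).
pixel_logits = torch.randn(2, 3, 8, 8)
pixel_labels = torch.randint(0, 3, (2, 8, 8))
pixel_loss = criterion(pixel_logits, pixel_labels)

# Treating every pixel as its own minibatch item gives the same value.
flat_logits = pixel_logits.permute(0, 2, 3, 1).reshape(-1, 3)
flat_loss = criterion(flat_logits, pixel_labels.reshape(-1))

print(batched_loss.item(), pixel_loss.item(), flat_loss.item())
```

With the default 'mean' reduction both calls return a scalar, and the per-pixel loss matches the loss computed on the flattened pixels.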
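The target that this criterion expects should contain either:

- Class indices in the range $[0, C)$, where $C$ is the number of classes; if ignore_index is specified, this loss also accepts this class index (this index may not necessarily be in the class range). The unreduced (i.e. with reduction set to 'none') loss for this case can be described as:

  $$\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^{C} \exp(x_{n,c})} \cdot \mathbb{1}\{y_n \neq \text{ignore\_index}\}$$

  where $x$ is the input, $y$ is the target, $w$ is the weight, $C$ is the number of classes, and $N$ spans the minibatch dimension as well as $d_1, \dots, d_K$ for the K-dimensional case. If reduction is not 'none' (default 'mean'), then

  $$\ell(x, y) = \begin{cases} \sum_{n=1}^{N} \frac{1}{\sum_{n=1}^{N} w_{y_n} \cdot \mathbb{1}\{y_n \neq \text{ignore\_index}\}} \, l_n, & \text{if reduction} = \text{'mean'}; \\ \sum_{n=1}^{N} l_n, & \text{if reduction} = \text{'sum'}. \end{cases}$$

  Note that this case is equivalent to applying LogSoftmax on an input, followed by NLLLoss.

- Probabilities for each class; useful when labels beyond a single class per minibatch item are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with reduction set to 'none') loss for this case can be described as:

  $$\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -\sum_{c=1}^{C} w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^{C} \exp(x_{n,i})} \, y_{n,c}$$

  where $x$ is the input, $y$ is the target, $w$ is the weight, $C$ is the number of classes, and $N$ spans the minibatch dimension as well as $d_1, \dots, d_K$ for the K-dimensional case. If reduction is not 'none' (default 'mean'), then

  $$\ell(x, y) = \begin{cases} \frac{\sum_{n=1}^{N} l_n}{N}, & \text{if reduction} = \text{'mean'}; \\ \sum_{n=1}^{N} l_n, & \text{if reduction} = \text{'sum'}. \end{cases}$$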
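The two target conventions can be checked numerically against a hand-written version of the loss. The sketch below is only an illustration: the tensor sizes and weight values are arbitrary, and the per-item expressions are spelled out in the comments.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 3
x = torch.randn(N, C)                 # input logits
y = torch.tensor([0, 2, 1, 2])        # class-index targets
w = torch.tensor([0.5, 1.0, 2.0])     # per-class weights (illustrative values)

log_probs = F.log_softmax(x, dim=1)

# Class-index case: l_n = -w[y_n] * log_softmax(x)[n, y_n]
l_idx = -w[y] * log_probs[torch.arange(N), y]
mean_idx = l_idx.sum() / w[y].sum()   # 'mean' divides by the summed target weights

ref_idx_none = nn.CrossEntropyLoss(weight=w, reduction="none")(x, y)
ref_idx_mean = nn.CrossEntropyLoss(weight=w, reduction="mean")(x, y)

# The same value via LogSoftmax followed by NLLLoss.
nll_mean = nn.NLLLoss(weight=w)(nn.LogSoftmax(dim=1)(x), y)

# Class-probability case: l_n = -sum_c w[c] * log_softmax(x)[n, c] * p[n, c]
p = torch.randn(N, C).softmax(dim=1)  # soft targets, each row sums to 1
l_prob = -(w * log_probs * p).sum(dim=1)
mean_prob = l_prob.sum() / N          # 'mean' divides by N in this case

ref_prob_none = nn.CrossEntropyLoss(weight=w, reduction="none")(x, p)
ref_prob_mean = nn.CrossEntropyLoss(weight=w, reduction="mean")(x, p)

print(torch.allclose(l_idx, ref_idx_none), torch.allclose(l_prob, ref_prob_none))
```

Note the different denominators: with class indices the weighted mean divides by the sum of the selected class weights, while with class probabilities it divides by the number of items.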
Note
The performance of this criterion is generally better when target contains class indices, as this allows for optimized computation. Consider providing target as class probabilities only when a single class label per minibatch item is too restrictive.
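To make the trade-off concrete, the sketch below (with arbitrary sizes and labels) shows that one-hot probability targets reproduce the class-index loss when no class weights are used, so probabilities are only needed for labels such as the hypothetical blended row at the end.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()     # no class weights, default 'mean' reduction

logits = torch.randn(4, 5)
labels = torch.tensor([1, 0, 4, 2])

# Class-index targets and their one-hot (probability) equivalent.
one_hot = F.one_hot(labels, num_classes=5).float()
loss_from_indices = criterion(logits, labels)
loss_from_probs = criterion(logits, one_hot)

# A blended label has no class-index form: here item 0 is assumed to be
# 70% class 1 and 30% class 3 (made-up proportions).
blended = one_hot.clone()
blended[0] = torch.tensor([0.0, 0.7, 0.0, 0.3, 0.0])
loss_blended = criterion(logits, blended)

print(loss_from_indices.item(), loss_from_probs.item(), loss_blended.item())
```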
- Parameters
weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C.
size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True
ignore_index (int, optional) – Specifies a target value that is ignored and does not contribute to the input gradient. When size_average is True, the loss is averaged over non-ignored targets. Note that ignore_index is only applicable when the target contains class indices.
reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'
label_smoothing (float, optional) – A float in [0.0, 1.0]. Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in Rethinking the Inception Architecture for Computer Vision. Default: 0.0.
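The interaction of ignore_index, reduction, and label_smoothing may be easier to see on a small example. This is a minimal sketch with arbitrary logits and labels; the smoothed-target construction assumes no class weights and no ignored targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)
labels = torch.tensor([0, -100, 2, 1])    # -100 is the default ignore_index

per_item = nn.CrossEntropyLoss(reduction="none")(logits, labels)
sum_loss = nn.CrossEntropyLoss(reduction="sum")(logits, labels)
mean_loss = nn.CrossEntropyLoss(reduction="mean")(logits, labels)

# The ignored item contributes 0, and 'mean' averages over the 3 remaining items.
num_kept = (labels != -100).sum()

# label_smoothing mixes the one-hot target with a uniform distribution:
# (1 - eps) * one_hot + eps / C
eps = 0.1
clean_labels = torch.tensor([0, 1, 2, 1])
smoothed_loss = nn.CrossEntropyLoss(label_smoothing=eps)(logits, clean_labels)
mixed_targets = (1 - eps) * F.one_hot(clean_labels, num_classes=3).float() + eps / 3
manual_smoothed = nn.CrossEntropyLoss()(logits, mixed_targets)

print(per_item, sum_loss.item(), mean_loss.item(), smoothed_loss.item())
```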
- Shape:
Input: Shape $(C)$, $(N, C)$ or $(N, C, d_1, d_2, \dots, d_K)$ with $K \geq 1$ in the case of K-dimensional loss.
Target: If containing class indices, shape $()$, $(N)$ or $(N, d_1, d_2, \dots, d_K)$ with $K \geq 1$ in the case of K-dimensional loss, where each value should be in $[0, C)$. The target data type is required to be long when using class indices. If containing class probabilities, the target must be the same shape as the input, and each value should be in $[0, 1]$. This means the target data type is required to be float when using class probabilities. Note that PyTorch does not strictly enforce probability constraints on the class probabilities and that it is the user's responsibility to ensure target contains valid probability distributions (see the examples section below for more details).
Output: If reduction is 'none', shape $()$, $(N)$ or $(N, d_1, d_2, \dots, d_K)$ with $K \geq 1$ in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar.
where:
$C$ = number of classes, $N$ = batch size.
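The shape rules can be exercised directly with reduction='none', which keeps one loss value per target element. The sizes below (C = 5, N = 4, and a 7 x 9 spatial grid) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 5
criterion = nn.CrossEntropyLoss(reduction="none")

# Unbatched: input (C), target () -> output ()
out_unbatched = criterion(torch.randn(C), torch.tensor(3))

# Batched: input (N, C), target (N) -> output (N)
out_batched = criterion(torch.randn(4, C), torch.randint(0, C, (4,)))

# K-dimensional (K = 2): input (N, C, d1, d2), target (N, d1, d2) -> output (N, d1, d2)
out_kdim = criterion(torch.randn(4, C, 7, 9), torch.randint(0, C, (4, 7, 9)))

# Class probabilities: float target with the same shape as the input -> output (N)
out_probs = criterion(torch.randn(4, C), torch.randn(4, C).softmax(dim=1))

print(out_unbatched.shape)  # torch.Size([])
print(out_batched.shape)    # torch.Size([4])
print(out_kdim.shape)       # torch.Size([4, 7, 9])
print(out_probs.shape)      # torch.Size([4])
```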
Examples
>>> import torch
>>> import torch.nn as nn
>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()
Note
When target contains class probabilities, it should consist of soft labels, that is, each target entry should represent a probability distribution over the possible classes for a given data sample, with individual probabilities between [0,1] and the total distribution summing to 1. This is why the softmax() function is applied to the target in the class probabilities example above.
PyTorch does not validate whether the values provided in target lie in the range [0,1] or whether the distribution of each data sample sums to 1. No warning will be raised and it is the user's responsibility to ensure that target contains valid probability distributions. Providing arbitrary values may yield misleading loss values and unstable gradients during training.
Examples
>>> # Example of target with incorrectly specified class probabilities
>>> loss = nn.CrossEntropyLoss()
>>> torch.manual_seed(283)
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5)
>>> # Provided target class probabilities are not in range [0,1]
>>> target
tensor([[ 0.7105,  0.4446,  2.0297,  0.2671, -0.6075],
        [-1.0496, -0.2753, -0.3586,  0.9270,  1.0027],
        [ 0.7551,  0.1003,  1.3468, -0.3581, -0.9569]])
>>> # Provided target class probabilities do not sum to 1
>>> target.sum(axis=1)
tensor([2.8444, 0.2462, 0.8873])
>>> # No error message and possible misleading loss value
>>> loss(input, target).item()
4.6379876136779785
>>>
>>> # Example of target with correctly specified class probabilities
>>> # Use .softmax() to ensure true probability distribution
>>> target_new = target.softmax(dim=1)
>>> # New target class probabilities all in range [0,1]
>>> target_new
tensor([[0.1559, 0.1195, 0.5830, 0.1000, 0.0417],
        [0.0496, 0.1075, 0.0990, 0.3579, 0.3860],
        [0.2607, 0.1355, 0.4711, 0.0856, 0.0471]])
>>> # New target class probabilities sum to 1
>>> target_new.sum(axis=1)
tensor([1.0000, 1.0000, 1.0000])
>>> loss(input, target_new).item()
2.55349063873291
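Because the criterion does not validate probability targets, a caller may want a check of its own before computing the loss. The helper below, is_valid_prob_target, is a hypothetical name introduced only for this sketch and is not part of PyTorch; the tolerance is an arbitrary choice.

```python
import torch


def is_valid_prob_target(target: torch.Tensor, class_dim: int = 1, atol: float = 1e-5) -> bool:
    """Hypothetical helper: True if every slice along class_dim is a probability distribution."""
    # Every entry must lie in [0, 1].
    in_range = bool(((target >= 0) & (target <= 1)).all())
    # Entries along the class dimension must sum to 1 (within a tolerance).
    sums = target.sum(dim=class_dim)
    sums_to_one = bool(torch.allclose(sums, torch.ones_like(sums), atol=atol))
    return in_range and sums_to_one


# The invalid target values from the example above, copied as literals.
raw = torch.tensor([[0.7105, 0.4446, 2.0297, 0.2671, -0.6075],
                    [-1.0496, -0.2753, -0.3586, 0.9270, 1.0027],
                    [0.7551, 0.1003, 1.3468, -0.3581, -0.9569]])

print(is_valid_prob_target(raw))                  # False
print(is_valid_prob_target(raw.softmax(dim=1)))   # True
```

A check like this adds a small cost per batch, so it may fit better as an assertion during debugging than as a permanent part of a training loop.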