BCEWithLogitsLoss
- class torch.nn.modules.loss.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss because, by combining the operations into one layer, it takes advantage of the log-sum-exp trick.
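To make the stability claim concrete, here is a minimal sketch comparing three ways of computing the per-element loss on extreme logits. The helper `stable_bce_with_logits` is hypothetical and uses one common stable rewriting, max(x, 0) - x*y + log(1 + exp(-|x|)); it is not claimed to be the exact code path PyTorch uses internally.

```python
import torch
import torch.nn.functional as F

def stable_bce_with_logits(x, y):
    # Hypothetical helper: algebraically equal to
    # -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))],
    # but exp() only ever receives a non-positive argument.
    return torch.clamp(x, min=0) - x * y + torch.log1p(torch.exp(-torch.abs(x)))

# Logits include extreme values where sigmoid saturates to exactly 0 or 1 in float32.
x = torch.tensor([-200.0, -1.5, 0.0, 1.5, 200.0])
y = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])

manual = stable_bce_with_logits(x, y)
fused = F.binary_cross_entropy_with_logits(x, y, reduction="none")
# Naive two-step version: sigmoid first, then BCE on the saturated probabilities.
naive = F.binary_cross_entropy(torch.sigmoid(x), y, reduction="none")

print(manual)
print(fused)
print(naive)
```

On the moderate logits all three agree; on the extreme ones the two-step version loses information because the sigmoid has already saturated before the logarithm is taken.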
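The unreduced (i.e. with reduction set to 'none') loss can be described as:

$$\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -w_n \left[ y_n \cdot \log \sigma(x_n) + (1 - y_n) \cdot \log(1 - \sigma(x_n)) \right],$$

where $N$ is the batch size. If reduction is not 'none' (default 'mean'), then

$$\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean'}; \\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'}. \end{cases}$$

This is used for measuring the error of a reconstruction in, for example, an auto-encoder. Note that the targets t[i] should be numbers between 0 and 1.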
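As an illustration of how the three reduction modes relate, the sketch below evaluates the per-element binary cross-entropy, -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], directly with unit weights and compares it with the module's output. The input values are arbitrary.

```python
import torch
from torch import nn

logits = torch.tensor([0.5, -1.0, 2.0, 0.0])
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])

per_element = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)
mean_loss = nn.BCEWithLogitsLoss(reduction="mean")(logits, targets)
sum_loss = nn.BCEWithLogitsLoss(reduction="sum")(logits, targets)

# Direct evaluation of the per-element formula with unit weights.
# Safe here only because the logits are moderate.
sig = torch.sigmoid(logits)
by_formula = -(targets * torch.log(sig) + (1 - targets) * torch.log(1 - sig))

print(per_element)          # one loss value per input element
print(mean_loss, sum_loss)  # scalars
```

With reduction='mean' the result is the average of the per-element vector, and with reduction='sum' it is their total.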
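It’s possible to trade off recall and precision by adding weights to positive examples. In the case of multi-label classification the loss can be described as:

$$\ell_c(x, y) = L_c = \{l_{1,c}, \dots, l_{N,c}\}^\top, \quad l_{n,c} = -w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c}) + (1 - y_{n,c}) \cdot \log(1 - \sigma(x_{n,c})) \right],$$

where $c$ is the class number ($c > 1$ for multi-label binary classification, $c = 1$ for single-label binary classification), $n$ is the number of the sample in the batch and $p_c$ is the weight of the positive answer for the class $c$.

$p_c > 1$ increases the recall, $p_c < 1$ increases the precision.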
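A small sketch of how the positive weight enters the loss: with hard 0/1 targets, the per-element loss of every positive entry is multiplied by its class weight, while negative entries are left unchanged. The logits, targets, and weight values below are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Two samples, three classes (multi-label setting).
logits = torch.tensor([[0.3, -1.2, 2.0],
                       [-0.7, 0.5, 1.1]])
targets = torch.tensor([[1.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0]])
pos_weight = torch.tensor([3.0, 1.0, 0.5])  # hypothetical per-class weights

plain = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
weighted = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=pos_weight, reduction="none"
)

# Positive entries are scaled by the class weight; negative entries are untouched.
expected = torch.where(targets == 1, plain * pos_weight, plain)
print(weighted)
print(expected)
```

A weight above 1 makes missed positives more expensive, which pushes the model toward higher recall; a weight below 1 does the opposite.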
For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to $\frac{300}{100} = 3$. The loss would act as if the dataset contains $3 \times 100 = 300$ positive examples.
Examples
>>> import torch
>>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
>>> output = torch.full([10, 64], 1.5)  # A prediction (logit)
>>> pos_weight = torch.ones([64])  # All weights are equal to 1
>>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
>>> criterion(output, target)  # -log(sigmoid(1.5))
tensor(0.20...)
In the above example, the pos_weight tensor’s elements correspond to the 64 distinct classes in a multi-label binary classification scenario. Each element in pos_weight is designed to adjust the loss function based on the imbalance between negative and positive samples for the respective class. This approach is useful in datasets with varying levels of class imbalance, ensuring that the loss calculation accurately accounts for the distribution in each class.
- Parameters
weight (Tensor, optional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.
size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True
reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'
pos_weight (Tensor, optional) – a weight of positive examples to be broadcasted with target. Must be a tensor with equal size along the class dimension to the number of classes. Pay close attention to PyTorch’s broadcasting semantics in order to achieve the desired operations. For a target of size [B, C, H, W] (where B is batch size), a pos_weight of size [B, C, H, W] will apply different pos_weights to each element of the batch, while a pos_weight of size [C, H, W] will apply the same pos_weights across the batch. To apply the same positive weight along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. Default: None
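The broadcasting behaviour of pos_weight can be checked with a short sketch. It assumes a target of size [B, C, H, W] and a made-up set of three per-class weights, reshaped to [C, 1, 1] so that each class weight applies to every spatial position and every batch element.

```python
import torch
from torch import nn

B, C, H, W = 2, 3, 4, 4
logits = torch.randn(B, C, H, W)
targets = torch.randint(0, 2, (B, C, H, W)).float()

per_class = torch.tensor([1.0, 2.0, 5.0])  # hypothetical class weights
pos_weight = per_class.view(C, 1, 1)       # broadcasts over B, H and W

loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none")(logits, targets)

# Same weights written out explicitly at size [C, H, W]; the result should be identical.
full = pos_weight.expand(C, H, W).contiguous()
loss_full = nn.BCEWithLogitsLoss(pos_weight=full, reduction="none")(logits, targets)

print(loss.shape)
```

Passing a plain [C] tensor here would instead be aligned against the last dimension (W) under the usual broadcasting rules, which is why the trailing singleton dimensions matter.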
- Shape:
Input: $(*)$, where $*$ means any number of dimensions.
Target: $(*)$, same shape as the input.
Output: scalar. If reduction is 'none', then $(*)$, same shape as input.
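A quick check of these shapes, using an arbitrary 3-dimensional input and soft targets drawn from [0, 1):

```python
import torch
from torch import nn

logits = torch.randn(2, 3, 4)
targets = torch.rand(2, 3, 4)  # any values between 0 and 1 are valid targets

reduced = nn.BCEWithLogitsLoss()(logits, targets)                    # default 'mean'
unreduced = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)

print(reduced.shape)    # torch.Size([])
print(unreduced.shape)  # torch.Size([2, 3, 4])
```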
Examples
>>> import torch
>>> from torch import nn
>>> loss = nn.BCEWithLogitsLoss()
>>> input = torch.randn(3, requires_grad=True)
>>> target = torch.empty(3).random_(2)
>>> output = loss(input, target)
>>> output.backward()