I am doing multi-class segmentation using UNet. The input to my model is HxWxC and my output layer is:

outputs = layers.Conv2D(n_classes, (1, 1), activation='sigmoid')(decoder0)

Using SparseCategoricalCrossentropy I can train the network fine. Now I would also like to try the dice coefficient as the loss function, implemented as follows:

def dice_loss(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.math.sigmoid(y_pred)
    numerator = 2 * tf.reduce_sum(y_true * y_pred) + smooth
    denominator = tf.reduce_sum(y_true + y_pred) + smooth
    return 1 - numerator / denominator

However, I am actually getting an increasing loss instead of a decreasing one. I have checked multiple sources, but all the material I find uses dice loss for binary classification, not multi-class. So my question is: is there a problem with the implementation?

asked Dec 3, 2020 at 12:06
Hamza Yerlikaya

  • @DavidS thanks, that did fix the problem. Commented Dec 5, 2020 at 14:09

4 Answers

The problem is that your dice loss doesn't account for the number of classes you have and instead assumes the binary case, which might explain the increase in your loss.

You should implement a generalized dice loss that accounts for all the classes and returns a value computed over all of them.

Something like the following:

from tensorflow.keras import backend as K

def dice_coef_9cat(y_true, y_pred, smooth=1e-7):
    '''
    Dice coefficient for 10 categories. Ignores background pixel label 0
    Pass to model as metric during compile statement
    '''
    # One-hot encode the integer labels and drop the background channel (index 0)
    y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'), num_classes=10)[...,1:])
    y_pred_f = K.flatten(y_pred[...,1:])
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))

def dice_coef_9cat_loss(y_true, y_pred):
    '''
    Dice loss to minimize. Pass to model as loss during compile statement
    '''
    return 1 - dice_coef_9cat(y_true, y_pred)

This snippet is taken from https://github.com/keras-team/keras/issues/9395#issuecomment-370971561

This is for 9 foreground categories plus background (num_classes=10); adjust it to the number of categories you have.
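
To make "accounting for all the classes" concrete, here is a minimal NumPy sketch of a soft dice loss that is computed per class and then averaged. It is not the snippet above: the function name multiclass_dice_loss, the (batch, H, W) label layout and the per-class averaging are assumptions chosen for illustration, and unlike the snippet it keeps the background class.

```python
import numpy as np

def multiclass_dice_loss(y_true, y_pred, n_classes, smooth=1e-6):
    """Soft Dice loss averaged over classes (illustrative sketch).

    y_true: integer labels, shape (batch, H, W)
    y_pred: per-class probabilities, shape (batch, H, W, n_classes)
    """
    # One-hot encode the integer labels so they match y_pred's shape.
    y_true_1h = np.eye(n_classes, dtype=np.float32)[y_true]
    # Sum over batch and spatial axes, keeping one value per class.
    axes = (0, 1, 2)
    intersect = np.sum(y_true_1h * y_pred, axis=axes)
    denom = np.sum(y_true_1h + y_pred, axis=axes)
    dice_per_class = (2.0 * intersect + smooth) / (denom + smooth)
    # Average the per-class scores, then turn the score into a loss.
    return 1.0 - float(np.mean(dice_per_class))

labels = np.array([[[0, 1], [2, 1]]])            # shape (1, 2, 2)
perfect = np.eye(3, dtype=np.float32)[labels]    # prediction equal to the one-hot target
uniform = np.full((1, 2, 2, 3), 1.0 / 3.0)       # uninformative prediction

loss_perfect = multiclass_dice_loss(labels, perfect, n_classes=3)
loss_uniform = multiclass_dice_loss(labels, uniform, n_classes=3)
print(loss_perfect, loss_uniform)
```

A prediction that matches the target gives a loss close to 0, while the uniform prediction gives a clearly larger value, which is the behaviour a loss to be minimized should have.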

answered Dec 4, 2020 at 9:09
David

1 Comment

This does work; however, it always trains worse than SparseCategoricalCrossentropy. I think the problem is that y_pred is not converted to one-hot. When it is converted, the loss calculates as expected; however, it results in a ValueError: No gradients provided for any variable: error.

If you are doing multi-class segmentation, the 'softmax' activation function should be used.

I would recommend using one-hot encoded ground-truth masks. This needs to be done outside of the loss calculation code.
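
As a rough illustration of one-hot encoding the masks before they reach any loss function, here is a small NumPy sketch. The helper name to_one_hot and the (H, W) mask layout are assumptions; in a TensorFlow input pipeline, tf.one_hot plays the same role.

```python
import numpy as np

def to_one_hot(mask, n_classes):
    """Convert an integer mask of shape (H, W) to one-hot of shape (H, W, n_classes)."""
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(n_classes, dtype=np.float32)[mask]

mask = np.array([[0, 2],
                 [1, 1]])
one_hot = to_one_hot(mask, n_classes=3)
print(one_hot.shape)
```

Each pixel's channel vector contains a single 1 at the index of its class, so taking the argmax over the last axis recovers the original mask.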

The generalized dice loss and other segmentation losses are implemented at the following link:

https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py

answered Mar 16, 2021 at 5:47
teriyaki

2 Comments

"This needs to be done outside of the loss calculation code": is there a reason to do the one-hot encoding outside of the loss function?
Actually, you can do the one-hot encoding inside the loss function, but if you have multiple loss functions, you would need to add the one-hot encoding to each of them, which makes the code longer.

You can use Dice Loss from the segmentation-models-pytorch library, which supports multi-class segmentation. You can install the library with:

pip install -U segmentation-models-pytorch

This is the reference:

https://smp.readthedocs.io/en/latest/losses.html

answered Nov 26, 2023 at 9:19
Hamzah Al-Qadasi

1 Comment

Unfortunately this DiceLoss implementation does not work for multiclass. It says the Target size (torch.Size([5, 1, 512, 512])) must be the same as input size (torch.Size([5, 2, 512, 512])), but there is no documentation on constructing the right shape (maybe one-hot in the target?). The package seems not to be maintained anymore either.

Not sure why, but the last layer has "sigmoid" as its activation function. For multi-class segmentation it has to be "softmax", not "sigmoid".

Also, you are using SparseCategoricalCrossentropy together with a multi-channel output. If the last layer had just one channel (when doing multi-class segmentation), then SparseCategoricalCrossentropy would make sense, but when the output has multiple channels, the loss to use is "CategoricalCrossentropy".

Your loss is increasing because the activation and the output channels don't match (as mentioned above).

Change

outputs = layers.Conv2D(n_classes, (1, 1), activation='sigmoid')(decoder0)

to

outputs = layers.Conv2D(n_classes, (1, 1), activation='softmax')(decoder0)
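
To see why the choice of activation matters, here is a small NumPy sketch comparing the two on the scores of a single pixel. The logits are made-up values and the helper functions are written out only for illustration; they are not the Keras implementations.

```python
import numpy as np

def sigmoid(z):
    # Squashes each score independently into (0, 1).
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(z - np.max(z, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

logits = np.array([2.0, 1.0, -1.0])   # hypothetical scores for one pixel, 3 classes
sig = sigmoid(logits)
soft = softmax(logits)
print(sig.sum(), soft.sum())
```

The softmax outputs sum to 1 across the classes, so they can be read as a distribution over mutually exclusive labels. The sigmoid outputs are independent per channel and generally do not sum to 1.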

answered Sep 1, 2021 at 11:10
Vedant Joshi
