Ground truth shape: torch.Size([4, 1, 512, 512])
Predicted shape: torch.Size([4, 59, 512, 512])
Loss: smp.losses.DiceLoss(mode='multiclass')
Model: smp.Unet(encoder_name='resnet18', encoder_weights=None, in_channels=3, classes=59, activation='softmax2d')
I also tried removing activation='softmax2d', but the error is still there. Here is the complete traceback:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<timed eval> in <module>
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule) 458 ) 459 --> 460 self._run(model) 461 462 assert self.state.stopped
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model) 756 757 # dispatch `start_training` or `start_evaluating` or `start_predicting`--> 758 self.dispatch() 759 760 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self) 797 self.accelerator.start_predicting(self) 798 else:--> 799 self.accelerator.start_training(self) 800 801 def run_stage(self):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer) 94 95 def start_training(self, trainer: 'pl.Trainer') -> None:---> 96 self.training_type_plugin.start_training(trainer) 97 98 def start_evaluating(self, trainer: 'pl.Trainer') -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer) 142 def start_training(self, trainer: 'pl.Trainer') -> None: 143 # double dispatch to initiate the training loop--> 144 self._results = trainer.run_stage() 145 146 def start_evaluating(self, trainer: 'pl.Trainer') -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self) 807 if self.predicting: 808 return self.run_predict()--> 809 return self.run_train() 810 811 def _pre_training_routine(self):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self) 842 self.progress_bar_callback.disable() 843 --> 844 
self.run_sanity_check(self.lightning_module) 845 846 self.checkpoint_connector.has_trained = False
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model) 1110 1111 # run eval step-> 1112 self.run_evaluation() 1113 1114 self.on_sanity_check_end()
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, on_epoch) 965 # lightning module methods 966 with self.profiler.profile("evaluation_step_and_end"):--> 967 output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx) 968 output = self.evaluation_loop.evaluation_step_end(output) 969 
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py in evaluation_step(self, batch, batch_idx, dataloader_idx) 172 model_ref._current_fx_name = "validation_step" 173 with self.trainer.profiler.profile("validation_step"):--> 174 output = self.trainer.accelerator.validation_step(args) 175 176 # capture any logged information
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, args) 224 225 with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():--> 226 return self.training_type_plugin.validation_step(*args) 227 228 def test_step(self, args: List[Union[Any, int]]) -> Optional[STEP_OUTPUT]:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs) 159 160 def validation_step(self, *args, **kwargs):--> 161 return self.lightning_module.validation_step(*args, **kwargs) 162 163 def test_step(self, *args, **kwargs):
<ipython-input-148-9c5092280cd4> in validation_step(self, batch, batch_idx) 54 outputs=self(image) 55 print('before',outputs.shape,segment.shape,segment.max())---> 56 loss=self.criterion(outputs,segment) 57 print('loss',loss) 58 # 
dice=dice_coef_multilabel(labels,segment)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 725 result = self._slow_forward(*input, **kwargs) 726 else:--> 727 result = self.forward(*input, **kwargs) 728 for hook in itertools.chain( 729 _global_forward_hooks.values(),
/opt/conda/lib/python3.7/site-packages/segmentation_models_pytorch/losses/dice.py in forward(self, y_pred, y_true) 93 y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # H, C, H*W 94 else:---> 95 y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 96 y_true = y_true.permute(0, 2, 1) # H, C, H*W 97 
RuntimeError: one_hot is only applicable to index tensor.
|