RuntimeError: one_hot is only applicable to index tensor. #454

Answered by MKaczkow
talhaanwarch asked this question in Q&A

Ground truth shape: torch.Size([4, 1, 512, 512])
Predicted shape: torch.Size([4, 59, 512, 512])
Loss: smp.losses.DiceLoss(mode='multiclass')
Model:

smp.Unet(
    encoder_name='resnet18',
    encoder_weights=None,
    in_channels=3,
    classes=59,
    activation='softmax2d',
)

I also tried removing activation='softmax2d', but the error is still there.
Here is the traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<timed eval> in <module>

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
--> 460         self._run(model)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
--> 758         self.dispatch()

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
--> 799             self.accelerator.start_training(self)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
---> 96         self.training_type_plugin.start_training(trainer)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
--> 144         self._results = trainer.run_stage()

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
--> 809         return self.run_train()

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
--> 844         self.run_sanity_check(self.lightning_module)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
-> 1112             self.run_evaluation()

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, on_epoch)
--> 967                     output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py in evaluation_step(self, batch, batch_idx, dataloader_idx)
--> 174                 output = self.trainer.accelerator.validation_step(args)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, args)
--> 226             return self.training_type_plugin.validation_step(*args)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
--> 161         return self.lightning_module.validation_step(*args, **kwargs)

<ipython-input-148-9c5092280cd4> in validation_step(self, batch, batch_idx)
     54     outputs=self(image)
     55     print('before',outputs.shape,segment.shape,segment.max())
---> 56     loss=self.criterion(outputs,segment)
     57     print('loss',loss)
     58 #     dice=dice_coef_multilabel(labels,segment)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
--> 727             result = self.forward(*input, **kwargs)

/opt/conda/lib/python3.7/site-packages/segmentation_models_pytorch/losses/dice.py in forward(self, y_pred, y_true)
     93                 y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # H, C, H*W
     94             else:
---> 95                 y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C
     96                 y_true = y_true.permute(0, 2, 1)  # H, C, H*W

RuntimeError: one_hot is only applicable to index tensor.
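For context, a minimal sketch (not the asker's actual code; shapes shrunk and tensors randomly generated for illustration) that triggers the same error when the ground-truth mask is passed as a floating-point tensor:

import torch
import segmentation_models_pytorch as smp

criterion = smp.losses.DiceLoss(mode='multiclass')

# Illustrative shapes, not the asker's real data: (N, C, H, W) logits, (N, H, W) mask.
logits = torch.randn(2, 59, 64, 64)
mask = torch.randint(0, 59, (2, 64, 64)).float()  # wrong dtype: float32

criterion(logits, mask)         # RuntimeError: one_hot is only applicable to index tensor.
criterion(logits, mask.long())  # works once the indices are torch.int64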

For anyone looking for an answer: the ground-truth input tensor (y_true) must have type torch.LongTensor (or torch.cuda.LongTensor). If not, this line:

y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C

from the smp.losses module fails. This is also described on SO. Though it's unclear to me why LongTensor is specifically required in torch.nn.functional.one_hot (any integer type would do, I guess?), the docs state that it is, and .long() solves the problem.
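Based on the accepted answer, a hedged sketch of how the fix might look inside the asker's validation_step; the variable names mirror the traceback, the rest of the LightningModule is hypothetical. The squeeze also drops the extra channel dimension of the [4, 1, 512, 512] mask, since mode='multiclass' expects (N, H, W) class indices.

# Hypothetical LightningModule method, mirroring the traceback.
def validation_step(self, batch, batch_idx):
    image, segment = batch                    # segment: (N, 1, H, W), possibly not int64
    outputs = self(image)                     # (N, 59, H, W) predictions
    segment = segment.squeeze(1).long()       # -> (N, H, W), torch.int64 as F.one_hot requires
    loss = self.criterion(outputs, segment)   # criterion = smp.losses.DiceLoss(mode='multiclass')
    return loss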

Replies: 2 comments 4 replies


Try to squeeze the ground truth tensor.

3 replies
@talhaanwarch

I squeezed the ground truth to torch.Size([4, 512, 512]), but the error is still there.

@talhaanwarch

I needed to convert the ground truth to .long().

@ljb-1

I have the same problem. Can you tell me how to solve it?


I have the same problem. Can you tell me how to solve it?

1 reply
@MKaczkow

For anyone looking for an answer: the ground-truth input tensor (y_true) must have type torch.LongTensor (or torch.cuda.LongTensor). If not, this line:

y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C

from the smp.losses module fails. This is also described on SO. Though it's unclear to me why LongTensor is specifically required in torch.nn.functional.one_hot (any integer type would do, I guess?), the docs state that it is, and .long() solves the problem.
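The constraint can be checked in isolation with plain PyTorch (no smp involved); F.one_hot only accepts torch.int64 indices:

import torch
import torch.nn.functional as F

idx = torch.tensor([0, 2, 1])              # default integer tensor is torch.int64

F.one_hot(idx, num_classes=3)              # works
F.one_hot(idx.float(), num_classes=3)      # RuntimeError: one_hot is only applicable to index tensor.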

Answer selected by qubvel
Category
Q&A
Labels
None yet
4 participants
@talhaanwarch @qubvel @ljb-1 @MKaczkow
