diff --git a/segmentation_models_pytorch/losses/dice.py b/segmentation_models_pytorch/losses/dice.py index b8baae98..38a074f2 100644 --- a/segmentation_models_pytorch/losses/dice.py +++ b/segmentation_models_pytorch/losses/dice.py @@ -70,7 +70,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: bs = y_true.size(0) num_classes = y_pred.size(1) - dims = (0, 2) + dims = (2) if self.mode == BINARY_MODE: y_true = y_true.view(bs, 1, -1)