
Add softmax_focal_loss() to allow multi-class focal loss #7760


Open

dhruvbird wants to merge 1 commit into main

Conversation

dhruvbird

In image segmentation tasks, focal loss is useful when trying to classify an image pixel as one of N classes. Unfortunately, sigmoid_focal_loss() isn't useful in such cases. I found that others have been asking for this as well in #3250, so I decided to submit a PR for the same.

Putting up this PR to get feedback about whether this sounds okay.
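For context, here is a minimal sketch of what such a multi-class focal loss could look like; this is not the PR's actual diff, and the signature, the gamma default, and the reduction handling are assumptions:

import torch
import torch.nn.functional as F

def softmax_focal_loss(inputs, targets, gamma=2.0, reduction="mean"):
    # Per-example cross entropy; with no class weights,
    # exp(-ce) recovers the softmax probability of the target class.
    ce = F.cross_entropy(inputs, targets, reduction="none")
    pt = torch.exp(-ce)
    loss = ((1.0 - pt) ** gamma) * ce
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss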

pytorch-bot bot commented Jul 26, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7760

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Merge Blocking SEV

There is 1 active merge-blocking SEV. Please review it before merging.

If you must merge, use @pytorchbot merge -f.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines +130 to +134
# Cross Entropy Loss computes:
# pt = softmax(...)
# loss = -1.0 * log(pt)
#
# Hence, exp(-loss) == pt
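(Aside: the identity in this comment does hold for unweighted cross entropy with reduction="none"; a quick check, with arbitrary example shapes:)

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)               # arbitrary (N, C) example
targets = torch.randint(0, 3, (4,))
ce = F.cross_entropy(logits, targets, reduction="none")
pt = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
assert torch.allclose(torch.exp(-ce), pt)   # exp(-loss) == pt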

@rehno-lindeque commented Jul 26, 2023


Sadly I don't think this works anymore:

loss = -weight_t * log(pt)
exp(-loss) == exp(weight_t * log(pt))
exp(-loss) == exp(log(pt ** weight_t))    (by the power rule of logarithms)
exp(-loss) == pt ** weight_t
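(A quick runnable check of this, with made-up class weights; exp(-loss) matches pt ** weight_t rather than pt:)

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
targets = torch.randint(0, 3, (4,))
weight = torch.tensor([0.2, 0.3, 0.5])    # arbitrary class weights
ce_w = F.cross_entropy(logits, targets, weight=weight, reduction="none")
pt = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
assert torch.allclose(torch.exp(-ce_w), pt ** weight[targets])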

From what I've seen, I think this will instead require

p = softmax(...)
ce_loss = nll_loss(p, weight=weight)
pt = p[:, targets]

Hence,
((1 - pt) ** gamma) * ce_loss
== ((1 - pt) ** gamma) * (-weight_t * log(pt))
== -weight_t * ((1 - pt) ** gamma) * log(pt)

(Just to mention, I'm not a torchvision contributor, I'm commenting mostly because I'm looking for a good implementation myself)

dhruvbird (Author)

@rehno-lindeque Could you please check the file changes and see if the code looks okay? I think the comment may sound misleading. The code computes cross entropy loss twice, once without the weights and once with them, so that the unweighted loss can be exponentiated to recover the probability.

@rehno-lindeque commented Jul 26, 2023

Oh, I did indeed miss that; you are right, the code is functionally correct. Sorry about that.

I would still humbly suggest breaking the cross_entropy down into log_softmax followed by nll_loss, as I described, for performance reasons though.

That is, calculating cross entropy twice might be rather expensive, especially if the tensors are images. It's only necessary to calculate it once if you take pt directly from the log_softmax result.

From the CrossEntropyLoss docs: "Note that this case is equivalent to the combination of LogSoftmax and NLLLoss."

EDIT: Oops I also just noticed that softmax should be log_softmax in my comment from last night.

log_p = log_softmax(input=inputs, dim=1)
ce_loss = nll_loss(input=log_p, target=targets, weight=weight)

pt = exp(log_p.take_along_dim(indices=targets.unsqueeze(dim=1), dim=1))

I think this breakdown may also make it easier to extend to targets given as class probabilities in the future.
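Pulling this suggestion together, a hedged end-to-end sketch; the function name, defaults, and reduction handling are assumptions rather than the PR's final code. It also covers segmentation-shaped inputs of (N, C, H, W) with targets of shape (N, H, W):

import torch
import torch.nn.functional as F

def softmax_focal_loss(inputs, targets, weight=None, gamma=2.0, reduction="mean"):
    # Single softmax pass: both the weighted CE term and pt come
    # from the same log_softmax result.
    log_p = F.log_softmax(inputs, dim=1)
    ce = F.nll_loss(log_p, targets, weight=weight, reduction="none")
    # Probability of the target class, per example (or per pixel).
    pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1).exp()
    loss = ((1.0 - pt) ** gamma) * ce
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss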
