Add softmax_focal_loss() to allow multi-class focal loss #7760
Conversation
In image segmentation tasks, focal loss is useful when trying to classify an image pixel as one of N classes. Unfortunately, `sigmoid_focal_loss()` isn't useful in such cases. I found that others have been asking for this as well in pytorch#3250, so I decided to submit a PR for it. Putting up this PR to get feedback about whether this sounds okay.
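For context, the existing `torchvision.ops.sigmoid_focal_loss` expects float targets with the same shape as the inputs, which makes the dense multi-class case awkward. A small illustration (the one-hot workaround shown is just one possible stopgap, not this PR's approach):

```python
import torch
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss

logits = torch.randn(2, 5, 8, 8)          # [N, C, H, W] per-class logits
targets = torch.randint(0, 5, (2, 8, 8))  # [N, H, W] integer class labels

# sigmoid_focal_loss wants float targets shaped like the inputs, so integer
# labels must first be one-hot encoded and permuted to channels-first:
one_hot = F.one_hot(targets, num_classes=5).permute(0, 3, 1, 2).float()
loss = sigmoid_focal_loss(logits, one_hot, reduction="mean")

# Note this still treats each class as an independent binary problem, which
# is why a softmax-based variant is being proposed here.
```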
# Cross Entropy Loss computes:
# pt = softmax(...)
# loss = -1.0 * log(pt)
#
# Hence, exp(-loss) == pt
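Without class weights, that identity is easy to verify numerically (a standalone check, not part of the diff):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
targets = torch.randint(0, 3, (4,))

loss = F.cross_entropy(logits, targets, reduction="none")
pt = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)

# Unweighted cross entropy is -log(pt), so exp(-loss) recovers pt.
assert torch.allclose(torch.exp(-loss), pt)
```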
Sadly I don't think this works anymore once class weights are involved:

loss = -weight_t * log(pt)
exp(-loss) == exp(weight_t * log(pt))
exp(-loss) == exp(log(pt ** weight_t))   (due to the power rule of logarithms)
exp(-loss) == pt ** weight_t
From what I've seen, I think this will instead require:

p = softmax(...)
ce_loss = nll_loss(p, weight=weight)
pt = p[:, targets]

Hence,

((1 - pt) ** gamma) * ce_loss
  == ((1 - pt) ** gamma) * (-weight_t * log(pt))
  == -weight_t * ((1 - pt) ** gamma) * log(pt)
(Just to mention: I'm not a torchvision contributor; I'm commenting mostly because I'm looking for a good implementation myself.)
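A quick numerical check of the weighted-case derivation above (illustrative, not from the thread):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
targets = torch.randint(0, 3, (4,))
weight = torch.tensor([0.5, 1.0, 2.0])

loss = F.cross_entropy(logits, targets, weight=weight, reduction="none")
pt = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)

# With class weights, exp(-loss) gives pt ** weight_t rather than pt.
assert torch.allclose(torch.exp(-loss), pt ** weight[targets])
```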
@rehno-lindeque Could you please check the file changes and see if the code looks okay? I think the comment may sound misleading. The code computes cross entropy loss twice, once without the weights and once with them, so that the unweighted result can be exponentiated to recover the probability.
Oh, I did indeed miss that; you are right, the code is functionally correct. Sorry about that.
I would still humbly suggest breaking down the cross_entropy into log_softmax followed by nll_loss as I described, for performance reasons. That is, calculating cross entropy twice might be rather expensive, especially if the tensors are images. It's only necessary to calculate it once if you take pt directly from the log_softmax result.
From the CrossEntropyLoss docs:

> Note that this case is equivalent to the combination of LogSoftmax and NLLLoss.
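That equivalence can be confirmed directly (a minimal standalone check):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
targets = torch.randint(0, 3, (4,))

assert torch.allclose(
    F.cross_entropy(logits, targets),
    F.nll_loss(F.log_softmax(logits, dim=1), targets),
)
```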
EDIT: Oops, I also just noticed that `softmax` should be `log_softmax` in my comment from last night.
log_p = log_softmax(input=inputs, dim=1)
ce_loss = nll_loss(input=log_p, target=targets, weight=weight)
pt = exp(log_p.take_along_dim(indices=targets.unsqueeze(dim=1), dim=1))
I think this breakdown may also make it easier to extend to targets as class probabilities in future.
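Putting the corrected pieces together, a minimal end-to-end sketch for image-shaped inputs (the function name and the `gamma` default here are illustrative assumptions, not the PR's final API):

```python
import torch
import torch.nn.functional as F

def softmax_focal_loss_sketch(inputs, targets, weight=None, gamma=2.0):
    # inputs: [N, C, H, W] logits; targets: [N, H, W] integer class indices.
    log_p = F.log_softmax(inputs, dim=1)
    # Weighted per-pixel cross entropy, derived from log_p in a single pass.
    ce_loss = F.nll_loss(log_p, targets, weight=weight, reduction="none")
    # Target-class probability taken directly from log_p, avoiding a second
    # cross-entropy computation.
    pt = torch.exp(log_p.gather(1, targets.unsqueeze(1)).squeeze(1))
    return ((1.0 - pt) ** gamma) * ce_loss

inputs = torch.randn(2, 5, 8, 8)
targets = torch.randint(0, 5, (2, 8, 8))
loss = softmax_focal_loss_sketch(inputs, targets).mean()
```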