[ML] Multiclass maximise minimum recall #1113
Conversation
}
doReduce(frame.readRows(numberThreads, 0, frame.numberRows(),
                        readCategoryCounts, &rowMask),
         copyCategoryCounts, reduceCategoryCounts, result);
The check was unnecessary here since readCategoryCounts can't fail.
Good work altogether. I just have one comment wrt. the initialization of w0 in the optimization procedure.
// We want to solve max_w{min_j{recall(class_j)}} = max_w{min_j{c_j(w) / n_j}}
// where c_j(w) and n_j are the number of correct predictions for weight w and
// the count of class_j in the sample set, respectively. We use the equivalent
// formulation
//
//   min_w{max_j{f_j(w)}} = min_w{max_j{1 - c_j(w) / n_j}}
//
// We can write f_j(w) as
//
//   f_j(w) = sum_i{1 - 1{argmax_k(w_k p_k(i)) == j}} / n_j                  (1)
//
// where the sum runs over the samples of class_j, 1{.} denotes the indicator
// function and p_k(i) is the predicted probability of class_k for sample i.
// (1) has a smooth relaxation obtained by replacing the indicator with
// softmax_j(w_1 p_1(i), ..., w_K p_K(i)). Note that the resulting objective
// isn't convex so we use multiple restarts.
nice! 👍
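For context (not part of the diff), the following is a minimal sketch of how the smoothed objective max_j{f_j(w)} from the comment above could be evaluated. The names smoothObjective, probabilities and classOf are illustrative assumptions, not the actual ml-cpp API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: evaluate max_j{f_j(w)} with the softmax relaxation, where f_j(w) is
// the sum over samples of class j of (1 - softmax_j(weighted probabilities)),
// divided by n_j. probabilities[i] holds the predicted class probabilities for
// sample i and classOf[i] its true class; both names are assumed for illustration.
double smoothObjective(const std::vector<double>& w,
                       const std::vector<std::vector<double>>& probabilities,
                       const std::vector<std::size_t>& classOf,
                       std::size_t numberClasses) {
    std::vector<double> loss(numberClasses, 0.0);
    std::vector<double> counts(numberClasses, 0.0);
    for (std::size_t i = 0; i < probabilities.size(); ++i) {
        // Softmax of the weighted probabilities, evaluated at the true class.
        double norm = 0.0;
        for (std::size_t k = 0; k < numberClasses; ++k) {
            norm += std::exp(w[k] * probabilities[i][k]);
        }
        std::size_t j = classOf[i];
        loss[j] += 1.0 - std::exp(w[j] * probabilities[i][j]) / norm;
        counts[j] += 1.0;
    }
    double result = 0.0;
    for (std::size_t j = 0; j < numberClasses; ++j) {
        if (counts[j] > 0.0) {
            result = std::max(result, loss[j] / counts[j]);
        }
    }
    return result;
}

Minimising this over w (subject to the caveat that it is non-convex, hence the restarts) pushes up the worst per-class recall.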
lib/maths/CDataFrameUtils.cc (outdated)
for (std::size_t j = 0; j < numberClasses; ++j) {
    interpolate(j) = CSampling::uniformSample(rng, 0.0, 1.0);
}
w0 = (a + interpolate.cwiseProduct(b - a)).array().exp();
It seems to me that this should be at the beginning of the for-loop. Otherwise, you try with w0 = (1, 1, ..., 1) first, and this can be outside of your bounds.
I actually intended to do this. My thought was that we should always include the best solution in the vicinity of doing nothing, i.e. all weights being equal. Since we use line search with backtracking, we then guarantee that we never do worse than not reweighting at all, which is a nice property since the optimisation objective is complex. WDYT?
I reworked the initialisation slightly and added a comment.
I still think it is worth trying the all-ones vector for the reason outlined above. I also now bake in the fact that we expect the weights to be (roughly) a monotonically decreasing function of the class recalls. (This isn't guaranteed because it depends on how close the probabilities are and on what the predicted classes are for the error cases.)
I also reduced the number of restarts: trying a wider variety of class counts and recall ranges, I didn't see evidence that we needed as many restarts after this change. See this commit.
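To make the initialisation strategy being discussed concrete, here is a hedged sketch: the first candidate is the all-ones weight vector, so with a backtracking line search the result is never worse than not reweighting, and subsequent restarts are sampled inside the log-space bounds as in the diff above. The names objective, minimise, a and b are assumptions for illustration, not the real ml-cpp routines.

#include <Eigen/Dense>
#include <cstddef>
#include <functional>
#include <random>

// Sketch of a multi-restart search for the class weights. "minimise" stands in
// for a local optimisation with backtracking line search starting from w0, and
// "objective" for the smoothed min-recall objective; a and b are lower and
// upper bounds on the log-weights.
Eigen::VectorXd bestWeights(const std::function<double(const Eigen::VectorXd&)>& objective,
                            const std::function<Eigen::VectorXd(const Eigen::VectorXd&)>& minimise,
                            const Eigen::VectorXd& a,
                            const Eigen::VectorXd& b,
                            std::size_t numberRestarts) {
    std::mt19937 rng;
    std::uniform_real_distribution<double> u01{0.0, 1.0};

    // First candidate: equal weights, i.e. leave the predicted classes unchanged.
    Eigen::VectorXd w0 = Eigen::VectorXd::Ones(a.size());
    Eigen::VectorXd best = minimise(w0);
    double fbest = objective(best);

    for (std::size_t restart = 1; restart < numberRestarts; ++restart) {
        // Random restart: interpolate uniformly between the bounds in log space.
        Eigen::VectorXd interpolate(a.size());
        for (Eigen::Index j = 0; j < a.size(); ++j) {
            interpolate(j) = u01(rng);
        }
        w0 = (a + interpolate.cwiseProduct(b - a)).array().exp().matrix();
        Eigen::VectorXd candidate = minimise(w0);
        double f = objective(candidate);
        if (f < fbest) {
            fbest = f;
            best = candidate;
        }
    }
    return best;
}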
Thank you for explaining. LGTM
This implements maximising the minimum class recall for multiclass classification when assigning class labels, which is often a better objective when the classes are imbalanced in the training data.