
[ML] Multiclass maximise minimum recall #1113


Merged — 12 commits, Apr 7, 2020

Conversation

tveasey
Contributor

@tveasey tveasey commented Apr 2, 2020

This implements maximise-minimum class recall for multiclass classification when assigning class labels, which is often a better objective when classes are imbalanced in the training data.
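For illustration, here is a minimal standalone sketch of the objective (names and signature are hypothetical, not the ml-cpp implementation): labels are assigned by the weighted argmax over the predicted class probabilities, and the quantity being maximised is the smallest per-class recall.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Assign each sample the class maximising the weighted probability
// weights[k] * probabilities[i][k], then return the minimum per-class
// recall. Illustrative only.
double minimumRecall(const std::vector<std::vector<double>>& probabilities,
                     const std::vector<std::size_t>& actual,
                     const std::vector<double>& weights) {
    std::size_t numberClasses{weights.size()};
    std::vector<double> correct(numberClasses, 0.0);
    std::vector<double> counts(numberClasses, 0.0);
    for (std::size_t i = 0; i < actual.size(); ++i) {
        std::size_t predicted{0};
        for (std::size_t k = 1; k < numberClasses; ++k) {
            if (weights[k] * probabilities[i][k] >
                weights[predicted] * probabilities[i][predicted]) {
                predicted = k;
            }
        }
        counts[actual[i]] += 1.0;
        if (predicted == actual[i]) {
            correct[actual[i]] += 1.0;
        }
    }
    double result{1.0};
    for (std::size_t j = 0; j < numberClasses; ++j) {
        if (counts[j] > 0.0) {
            result = std::min(result, correct[j] / counts[j]);
        }
    }
    return result;
}
```

With equal weights an imbalanced classifier can score zero recall on a minority class; choosing the weights to maximise this quantity trades a little majority-class recall for a higher minimum.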

}
doReduce(frame.readRows(numberThreads, 0, frame.numberRows(),
readCategoryCounts, &rowMask),
copyCategoryCounts, reduceCategoryCounts, result);
Contributor Author


The check was unnecessary here since readCategoryCounts can't fail.

Contributor

@valeriy42 valeriy42 left a comment


Good work altogether. I just have one comment regarding the initialization of w0 in the optimization procedure.

Comment on lines +1146 to +1158
// We want to solve max_w{min_j{recall(class_j)}} = max_w{min_j{c_j(w) / n_j}}
// where c_j(w) and n_j are the correct predictions for weights w and the count
// of class_j in the sample set, respectively. We use the equivalent formulation
//
// min_w{max_j{f_j(w)}} = min_w{max_j{1 - c_j(w) / n_j}}
//
// We can write f_j(w) as
//
// f_j(w) = sum_i{1 - 1{argmax_k(w_k p_{i,k}) == j}} / n_j (1)
//
// where i runs over the samples of class_j, 1{.} denotes the indicator function
// and p_{i,k} is the predicted probability of class_k for sample i. (1) has a
// smooth relaxation given by f_j(w) = sum_i{1 - softmax_j(w .* p_i)} / n_j.
// Note that this isn't convex so we use multiple restarts.
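The relaxation this comment describes might be sketched as the following standalone code (illustrative only, assuming each p_i is a vector of per-sample class probabilities; this is not the ml-cpp implementation):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Smooth relaxation of 1 - recall per class: replace the hard argmax
// indicator with softmax_j(w .* p_i), so the objective is differentiable
// in w. Indices: i runs over samples, k over classes. Illustrative only.
double smoothObjective(const std::vector<std::vector<double>>& probabilities,
                       const std::vector<std::size_t>& actual,
                       const std::vector<double>& weights) {
    std::size_t numberClasses{weights.size()};
    std::vector<double> loss(numberClasses, 0.0);
    std::vector<double> counts(numberClasses, 0.0);
    for (std::size_t i = 0; i < actual.size(); ++i) {
        double Z{0.0};
        for (std::size_t k = 0; k < numberClasses; ++k) {
            Z += std::exp(weights[k] * probabilities[i][k]);
        }
        std::size_t j{actual[i]};
        counts[j] += 1.0;
        loss[j] += 1.0 - std::exp(weights[j] * probabilities[i][j]) / Z;
    }
    // max_j{f_j(w)} with f_j(w) = sum_i{1 - softmax_j(w .* p_i)} / n_j.
    double result{0.0};
    for (std::size_t j = 0; j < numberClasses; ++j) {
        if (counts[j] > 0.0) {
            result = std::max(result, loss[j] / counts[j]);
        }
    }
    return result;
}
```

Raising the weight of a poorly recalled class increases its softmax responses and so lowers this objective, which is what gradient-based minimisation exploits.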
Contributor


nice! 👍

Comment on lines 1243 to 1246
for (std::size_t j = 0; j < numberClasses; ++j) {
interpolate(j) = CSampling::uniformSample(rng, 0.0, 1.0);
}
w0 = (a + interpolate.cwiseProduct(b - a)).array().exp();
Contributor


It seems to me that this should be at the beginning of the for-loop. Otherwise, you first try w0 = (1, 1, ..., 1), which can be outside of your bounds.

Contributor Author

@tveasey tveasey Apr 3, 2020


I actually intended to do this. My thought was that we should always include the best solution in the vicinity of doing nothing, i.e. all weights equal. Since we use line search with backtracking, we then guarantee that we never do worse than not reweighting at all, which is a nice property given that the optimisation objective is complex. WDYT?

Contributor Author


I reworked initialisation slightly and added a comment.

I still think it is worth trying the all-ones vector for the reason outlined. I also now bake in the fact that we expect the weights to be (roughly) a monotonically decreasing function of the class recalls. (This isn't guaranteed because it depends on how close the probabilities are and what the predicted classes are for the error cases.)

I also reduced the number of restarts: trying a wider variety of numbers of classes and ranges of recalls, I didn't see evidence that we needed as many restarts after this change. See this commit.
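The initialisation strategy discussed above (the all-ones "do nothing" candidate first, then random restarts within the bounds) might be sketched as follows; the names, the bounds representation, and the plain standard-library sampling are assumptions, not the ml-cpp code (which uses CSampling and Eigen):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// Candidate starting points for the restarts: always try w = (1, ..., 1)
// first, so with backtracking line search we can never end up worse than
// not reweighting at all; the remaining candidates are sampled
// log-uniformly inside the bounds [a, b], given in log space.
// Illustrative only.
std::vector<std::vector<double>>
initialCandidates(std::size_t numberClasses, std::size_t restarts,
                  const std::vector<double>& a, const std::vector<double>& b,
                  std::mt19937& rng) {
    std::vector<std::vector<double>> candidates;
    candidates.emplace_back(numberClasses, 1.0); // the "do nothing" baseline
    std::uniform_real_distribution<double> u01{0.0, 1.0};
    for (std::size_t r = 1; r < restarts; ++r) {
        std::vector<double> w0(numberClasses);
        for (std::size_t j = 0; j < numberClasses; ++j) {
            // Interpolate between the log-space bounds, then exponentiate,
            // mirroring w0 = (a + u .* (b - a)).exp() in the diff above.
            w0[j] = std::exp(a[j] + u01(rng) * (b[j] - a[j]));
        }
        candidates.push_back(std::move(w0));
    }
    return candidates;
}
```

Each candidate would then seed one run of the line-search optimisation, keeping the best result over all restarts.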

Contributor

@valeriy42 valeriy42 left a comment


Thank you for explaining. LGTM
