Skip to content

[ML] Fix weights to maximize minimum recall for multiclass classification when the training data is missing classes #1239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 15, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions lib/maths/CDataFrameUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1072,27 +1072,32 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
std::size_t numberSamples{
static_cast<std::size_t>(std::min(1000.0, rowMask.manhattan()))};

TStratifiedSamplerPtr sampler;
std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler(
numberThreads, frame, targetColumn, rng, numberSamples, rowMask);

TSizeVec rowIndices;
frame.readRows(1, 0, frame.numberRows(),
[&](TRowItr beginRows, TRowItr endRows) {
for (auto row = beginRows; row != endRows; ++row) {
sampler->sample(*row);
}
},
&rowMask);
sampler->finishSampling(rng, rowIndices);
std::sort(rowIndices.begin(), rowIndices.end());
core::CPackedBitVector sampleMask;
for (auto row : rowIndices) {
sampleMask.extend(false, row - sampleMask.size());
sampleMask.extend(true);
if (numberSamples < static_cast<std::size_t>(rowMask.manhattan())) {
TStratifiedSamplerPtr sampler;
std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler(
numberThreads, frame, targetColumn, rng, numberSamples, rowMask);

TSizeVec rowIndices;
frame.readRows(1, 0, frame.numberRows(),
[&](TRowItr beginRows, TRowItr endRows) {
for (auto row = beginRows; row != endRows; ++row) {
sampler->sample(*row);
}
},
&rowMask);
sampler->finishSampling(rng, rowIndices);
std::sort(rowIndices.begin(), rowIndices.end());
LOG_TRACE(<< "# row indices = " << rowIndices.size());

for (auto row : rowIndices) {
sampleMask.extend(false, row - sampleMask.size());
sampleMask.extend(true);
}
sampleMask.extend(false, rowMask.size() - sampleMask.size());
} else {
sampleMask = rowMask;
}
sampleMask.extend(false, rowMask.size() - sampleMask.size());
LOG_TRACE(<< "# row indices = " << rowIndices.size());

// Compute the count of each class in the sample set.
auto readClassCountsAndRecalls = core::bindRetrievableState(
Expand All @@ -1119,6 +1124,11 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
copyClassCountsAndRecalls, reduceClassCountsAndRecalls, classCountsAndRecalls);
TDoubleVector classCounts{classCountsAndRecalls.topRows(numberClasses)};
TDoubleVector classRecalls{classCountsAndRecalls.bottomRows(numberClasses)};
// If a class count is zero the derivative of the loss function w.r.t. that
// class is not well defined. In practice, the objective is independent of
// such a class so the choice for its count and recall is not important,
// but the count must be non-zero so that we don't run into a NaN cascade.
classCounts = classCounts.cwiseMax(1.0);
classRecalls.array() /= classCounts.array();
LOG_TRACE(<< "class counts = " << classCounts.transpose());
LOG_TRACE(<< "class recalls = " << classRecalls.transpose());
Expand Down Expand Up @@ -1219,15 +1229,15 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
// and use this to bias initial values.

TMinAccumulator minLoss;
TDoubleVector minWeights;
TDoubleVector minLossWeights;
TDoubleVector w0{TDoubleVector::Ones(numberClasses)};
for (std::size_t i = 0; i < 5; ++i) {
CLbfgs<TDoubleVector> lbfgs{5};
double loss;
std::tie(w0, loss) = lbfgs.minimize(objective, objectiveGradient, std::move(w0));
LOG_TRACE(<< "weights* = " << w0.transpose() << ", loss* = " << loss);
if (minLoss.add(loss)) {
minWeights = std::move(w0);
minLossWeights = std::move(w0);
}
w0 = TDoubleVector::Ones(numberClasses);
for (std::size_t j = 1; j < numberClasses; ++j) {
Expand All @@ -1237,10 +1247,10 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,

// Since we take argmax_i w_i p_i we can multiply by a constant. We arrange for
// the largest weight to always be one.
minWeights.array() /= minWeights.maxCoeff();
LOG_TRACE(<< "weights = " << minWeights.transpose());
minLossWeights.array() /= minLossWeights.maxCoeff();
LOG_TRACE(<< "weights = " << minLossWeights.transpose());

return minWeights;
return minLossWeights;
}

void CDataFrameUtils::removeMetricColumns(const core::CDataFrame& frame, TSizeVec& columnMask) {
Expand Down