Skip to content

Commit 505960a

Browse files
authored
[ML] Fix weights to maximize minimum recall for multiclass classification when the training data is missing classes (#1239)
1 parent 2f4a2d9 commit 505960a

File tree

1 file changed

+36
-24
lines changed

1 file changed

+36
-24
lines changed

lib/maths/CDataFrameUtils.cc

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,27 +1072,34 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
10721072
std::size_t numberSamples{
10731073
static_cast<std::size_t>(std::min(1000.0, rowMask.manhattan()))};
10741074

1075-
TStratifiedSamplerPtr sampler;
1076-
std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler(
1077-
numberThreads, frame, targetColumn, rng, numberSamples, rowMask);
1078-
1079-
TSizeVec rowIndices;
1080-
frame.readRows(1, 0, frame.numberRows(),
1081-
[&](TRowItr beginRows, TRowItr endRows) {
1082-
for (auto row = beginRows; row != endRows; ++row) {
1083-
sampler->sample(*row);
1084-
}
1085-
},
1086-
&rowMask);
1087-
sampler->finishSampling(rng, rowIndices);
1088-
std::sort(rowIndices.begin(), rowIndices.end());
10891075
core::CPackedBitVector sampleMask;
1090-
for (auto row : rowIndices) {
1091-
sampleMask.extend(false, row - sampleMask.size());
1092-
sampleMask.extend(true);
1076+
1077+
// No need to sample if were going to use every row we've been given.
1078+
if (numberSamples < static_cast<std::size_t>(rowMask.manhattan())) {
1079+
TStratifiedSamplerPtr sampler;
1080+
std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler(
1081+
numberThreads, frame, targetColumn, rng, numberSamples, rowMask);
1082+
1083+
TSizeVec rowIndices;
1084+
frame.readRows(1, 0, frame.numberRows(),
1085+
[&](TRowItr beginRows, TRowItr endRows) {
1086+
for (auto row = beginRows; row != endRows; ++row) {
1087+
sampler->sample(*row);
1088+
}
1089+
},
1090+
&rowMask);
1091+
sampler->finishSampling(rng, rowIndices);
1092+
std::sort(rowIndices.begin(), rowIndices.end());
1093+
LOG_TRACE(<< "# row indices = " << rowIndices.size());
1094+
1095+
for (auto row : rowIndices) {
1096+
sampleMask.extend(false, row - sampleMask.size());
1097+
sampleMask.extend(true);
1098+
}
1099+
sampleMask.extend(false, rowMask.size() - sampleMask.size());
1100+
} else {
1101+
sampleMask = rowMask;
10931102
}
1094-
sampleMask.extend(false, rowMask.size() - sampleMask.size());
1095-
LOG_TRACE(<< "# row indices = " << rowIndices.size());
10961103

10971104
// Compute the count of each class in the sample set.
10981105
auto readClassCountsAndRecalls = core::bindRetrievableState(
@@ -1119,6 +1126,11 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
11191126
copyClassCountsAndRecalls, reduceClassCountsAndRecalls, classCountsAndRecalls);
11201127
TDoubleVector classCounts{classCountsAndRecalls.topRows(numberClasses)};
11211128
TDoubleVector classRecalls{classCountsAndRecalls.bottomRows(numberClasses)};
1129+
// If a class count is zero the derivative of the loss functin w.r.t. that
1130+
// class is not well defined. The objective is independent of such a class
1131+
// so the choice for its count and recall is not important. However, its
1132+
// count must be non-zero so that we don't run into a NaN cascade.
1133+
classCounts = classCounts.cwiseMax(1.0);
11221134
classRecalls.array() /= classCounts.array();
11231135
LOG_TRACE(<< "class counts = " << classCounts.transpose());
11241136
LOG_TRACE(<< "class recalls = " << classRecalls.transpose());
@@ -1219,15 +1231,15 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
12191231
// and use this to bias initial values.
12201232

12211233
TMinAccumulator minLoss;
1222-
TDoubleVector minWeights;
1234+
TDoubleVector minLossWeights;
12231235
TDoubleVector w0{TDoubleVector::Ones(numberClasses)};
12241236
for (std::size_t i = 0; i < 5; ++i) {
12251237
CLbfgs<TDoubleVector> lbfgs{5};
12261238
double loss;
12271239
std::tie(w0, loss) = lbfgs.minimize(objective, objectiveGradient, std::move(w0));
12281240
LOG_TRACE(<< "weights* = " << w0.transpose() << ", loss* = " << loss);
12291241
if (minLoss.add(loss)) {
1230-
minWeights = std::move(w0);
1242+
minLossWeights = std::move(w0);
12311243
}
12321244
w0 = TDoubleVector::Ones(numberClasses);
12331245
for (std::size_t j = 1; j < numberClasses; ++j) {
@@ -1237,10 +1249,10 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
12371249

12381250
// Since we take argmax_i w_i p_i we can multiply by a constant. We arrange for
12391251
// the largest weight to always be one.
1240-
minWeights.array() /= minWeights.maxCoeff();
1241-
LOG_TRACE(<< "weights = " << minWeights.transpose());
1252+
minLossWeights.array() /= minLossWeights.maxCoeff();
1253+
LOG_TRACE(<< "weights = " << minLossWeights.transpose());
12421254

1243-
return minWeights;
1255+
return minLossWeights;
12441256
}
12451257

12461258
void CDataFrameUtils::removeMetricColumns(const core::CDataFrame& frame, TSizeVec& columnMask) {

0 commit comments

Comments
 (0)