Skip to content

Commit 21b7c55

Browse files
authored
[ML] Fix weights to maximize minimum recall for multiclass classification (#1235)
Backport #1231.
1 parent 0a0cfcf commit 21b7c55

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

lib/maths/CDataFrameUtils.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,8 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
10691069
using TMinAccumulator = CBasicStatistics::SMin<double>::TAccumulator;
10701070

10711071
CPRNG::CXorOShiro128Plus rng;
1072-
std::size_t numberSamples{std::min(std::size_t{1000}, rowMask.size())};
1072+
std::size_t numberSamples{
1073+
static_cast<std::size_t>(std::min(1000.0, rowMask.manhattan()))};
10731074

10741075
TStratifiedSamplerPtr sampler;
10751076
std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler(
@@ -1168,7 +1169,8 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
11681169
};
11691170

11701171
TDoubleVector objective_;
1171-
doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), computeObjective, &rowMask),
1172+
doReduce(frame.readRows(numberThreads, 0, frame.numberRows(),
1173+
computeObjective, &sampleMask),
11721174
copyObjective, reduceObjective, objective_);
11731175
return objective_.maxCoeff();
11741176
};
@@ -1203,7 +1205,7 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
12031205

12041206
TDoubleMatrix objectiveAndGradient;
12051207
doReduce(frame.readRows(numberThreads, 0, frame.numberRows(),
1206-
computeObjectiveAndGradient, &rowMask),
1208+
computeObjectiveAndGradient, &sampleMask),
12071209
copyObjectiveAndGradient, reduceObjectiveAndGradient, objectiveAndGradient);
12081210
std::size_t max;
12091211
objectiveAndGradient.col(0).maxCoeff(&max);

lib/maths/unittest/CDataFrameUtilsTest.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,8 +1308,8 @@ BOOST_AUTO_TEST_CASE(testMaximumMinimumRecallClassWeights) {
13081308
BOOST_TEST_REQUIRE(minRecalls[0][0] > 1.1 * minRecalls[0][1]);
13091309

13101310
// The minimum and maximum class recalls are close: we're at the global maximum.
1311-
BOOST_TEST_REQUIRE(1.02 * minRecalls[0][0] > maxRecalls[0][0]);
1312-
BOOST_TEST_REQUIRE(1.02 * minRecalls[1][0] > maxRecalls[1][0]);
1311+
BOOST_TEST_REQUIRE(1.06 * minRecalls[0][0] > maxRecalls[0][0]);
1312+
BOOST_TEST_REQUIRE(1.06 * minRecalls[1][0] > maxRecalls[1][0]);
13131313
}
13141314
}
13151315

0 commit comments

Comments
 (0)