diff --git a/lib/maths/CDataFrameUtils.cc b/lib/maths/CDataFrameUtils.cc index ffa157347a..83bfe51745 100644 --- a/lib/maths/CDataFrameUtils.cc +++ b/lib/maths/CDataFrameUtils.cc @@ -1069,7 +1069,8 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, using TMinAccumulator = CBasicStatistics::SMin::TAccumulator; CPRNG::CXorOShiro128Plus rng; - std::size_t numberSamples{std::min(std::size_t{1000}, rowMask.size())}; + std::size_t numberSamples{ + static_cast(std::min(1000.0, rowMask.manhattan()))}; TStratifiedSamplerPtr sampler; std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler( @@ -1168,7 +1169,8 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, }; TDoubleVector objective_; - doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), computeObjective, &rowMask), + doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), + computeObjective, &sampleMask), copyObjective, reduceObjective, objective_); return objective_.maxCoeff(); }; @@ -1203,7 +1205,7 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, TDoubleMatrix objectiveAndGradient; doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), - computeObjectiveAndGradient, &rowMask), + computeObjectiveAndGradient, &sampleMask), copyObjectiveAndGradient, reduceObjectiveAndGradient, objectiveAndGradient); std::size_t max; objectiveAndGradient.col(0).maxCoeff(&max); diff --git a/lib/maths/unittest/CDataFrameUtilsTest.cc b/lib/maths/unittest/CDataFrameUtilsTest.cc index b005e46096..f8b678031a 100644 --- a/lib/maths/unittest/CDataFrameUtilsTest.cc +++ b/lib/maths/unittest/CDataFrameUtilsTest.cc @@ -1308,8 +1308,8 @@ BOOST_AUTO_TEST_CASE(testMaximumMinimumRecallClassWeights) { BOOST_TEST_REQUIRE(minRecalls[0][0] > 1.1 * minRecalls[0][1]); // The minimum and maximum class recalls are close: we're at the global maximum. - BOOST_TEST_REQUIRE(1.02 * minRecalls[0][0] > maxRecalls[0][0]); - BOOST_TEST_REQUIRE(1.02 * minRecalls[1][0] > maxRecalls[1][0]); + BOOST_TEST_REQUIRE(1.06 * minRecalls[0][0] > maxRecalls[0][0]); + BOOST_TEST_REQUIRE(1.06 * minRecalls[1][0] > maxRecalls[1][0]); } }