From c38a1dd9c72cdab520e8c74a05b67eaba9527e05 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Tue, 12 May 2020 10:54:21 +0100 Subject: [PATCH 1/2] Maximise minimum recall for multiclass should have been computing the objective on the sample set --- lib/maths/CDataFrameUtils.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/maths/CDataFrameUtils.cc b/lib/maths/CDataFrameUtils.cc index ffa157347a..83bfe51745 100644 --- a/lib/maths/CDataFrameUtils.cc +++ b/lib/maths/CDataFrameUtils.cc @@ -1069,7 +1069,8 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, using TMinAccumulator = CBasicStatistics::SMin::TAccumulator; CPRNG::CXorOShiro128Plus rng; - std::size_t numberSamples{std::min(std::size_t{1000}, rowMask.size())}; + std::size_t numberSamples{ + static_cast(std::min(1000.0, rowMask.manhattan()))}; TStratifiedSamplerPtr sampler; std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler( @@ -1168,7 +1169,8 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, }; TDoubleVector objective_; - doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), computeObjective, &rowMask), + doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), + computeObjective, &sampleMask), copyObjective, reduceObjective, objective_); return objective_.maxCoeff(); }; @@ -1203,7 +1205,7 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, TDoubleMatrix objectiveAndGradient; doReduce(frame.readRows(numberThreads, 0, frame.numberRows(), - computeObjectiveAndGradient, &rowMask), + computeObjectiveAndGradient, &sampleMask), copyObjectiveAndGradient, reduceObjectiveAndGradient, objectiveAndGradient); std::size_t max; objectiveAndGradient.col(0).maxCoeff(&max); From fc36e308f32f232f882d15d8ee59489c48c39b4a Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Tue, 12 May 2020 14:58:10 +0100 Subject: [PATCH 2/2] Relax test threshold slightly --- lib/maths/unittest/CDataFrameUtilsTest.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/maths/unittest/CDataFrameUtilsTest.cc b/lib/maths/unittest/CDataFrameUtilsTest.cc index b005e46096..f8b678031a 100644 --- a/lib/maths/unittest/CDataFrameUtilsTest.cc +++ b/lib/maths/unittest/CDataFrameUtilsTest.cc @@ -1308,8 +1308,8 @@ BOOST_AUTO_TEST_CASE(testMaximumMinimumRecallClassWeights) { BOOST_TEST_REQUIRE(minRecalls[0][0] > 1.1 * minRecalls[0][1]); // The minimum and maximum class recalls are close: we're at the global maximum. - BOOST_TEST_REQUIRE(1.02 * minRecalls[0][0] > maxRecalls[0][0]); - BOOST_TEST_REQUIRE(1.02 * minRecalls[1][0] > maxRecalls[1][0]); + BOOST_TEST_REQUIRE(1.06 * minRecalls[0][0] > maxRecalls[0][0]); + BOOST_TEST_REQUIRE(1.06 * minRecalls[1][0] > maxRecalls[1][0]); } }