Skip to content

Commit 1b8aebe

Browse files
authored
[7.7][ML] Fix a bug with progress reporting during model training (#1009)
Backport #1001.
1 parent a45db7d commit 1b8aebe

File tree

2 files changed

+35
-17
lines changed

2 files changed

+35
-17
lines changed

lib/maths/CBoostedTreeFactory.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const std::size_t MAX_NUMBER_TREES{static_cast<std::size_t>(2.0 / MIN_ETA + 0.5)
5555
// for progress monitoring because we don't know what value we'll choose in the
5656
// line search. Assuming it is less than one avoids a large pause in progress if
5757
// it is reduced in the line search.
58-
const double LINE_SEARCH_ETA_MARGIN{0.5};
58+
const double MAIN_LOOP_ETA_SCALE_FOR_PROGRESS{0.5};
5959

6060
double computeEta(std::size_t numberRegressors) {
6161
// eta is the learning rate. There is a lot of empirical evidence that
@@ -315,7 +315,7 @@ void CBoostedTreeFactory::selectFeaturesAndEncodeCategories(const core::CDataFra
315315
.minimumFrequencyToOneHotEncode(m_MinimumFrequencyToOneHotEncode)
316316
.rowMask(m_TreeImpl->allTrainingRowsMask())
317317
.columnMask(std::move(regressors)));
318-
m_TreeImpl->m_TrainingProgress.increment(1);
318+
m_TreeImpl->m_TrainingProgress.increment(100);
319319
}
320320

321321
void CBoostedTreeFactory::determineFeatureDataTypes(const core::CDataFrame& frame) const {
@@ -741,7 +741,7 @@ void CBoostedTreeFactory::initializeUnsetEta(core::CDataFrame& frame) {
741741

742742
m_TreeImpl->m_TrainingProgress.incrementRange(
743743
static_cast<int>(this->mainLoopNumberSteps(m_TreeImpl->m_Eta)) -
744-
static_cast<int>(this->mainLoopNumberSteps(LINE_SEARCH_ETA_MARGIN * eta)));
744+
static_cast<int>(this->mainLoopNumberSteps(MAIN_LOOP_ETA_SCALE_FOR_PROGRESS * eta)));
745745
}
746746
}
747747

@@ -1162,7 +1162,7 @@ void CBoostedTreeFactory::initializeTrainingProgressMonitoring(const core::CData
11621162
//
11631163
// This comprises:
11641164
// - The cost of category encoding and feature selection which we count as
1165-
// one unit,
1165+
// one hundred units,
11661166
// - One unit for estimating the expected gain and sum curvature per node,
11671167
// - MAX_LINE_SEARCH_ITERATIONS * "maximum number trees" units per regularization
11681168
// parameter which isn't user defined,
@@ -1178,7 +1178,7 @@ void CBoostedTreeFactory::initializeTrainingProgressMonitoring(const core::CData
11781178
? *m_TreeImpl->m_EtaOverride
11791179
: computeEta(frame.numberColumns())};
11801180

1181-
std::size_t totalNumberSteps{2};
1181+
std::size_t totalNumberSteps{101};
11821182
std::size_t lineSearchMaximumNumberTrees{computeMaximumNumberTrees(eta)};
11831183
if (m_TreeImpl->m_RegularizationOverride.softTreeDepthLimit() == boost::none) {
11841184
totalNumberSteps += MAX_LINE_SEARCH_ITERATIONS * lineSearchMaximumNumberTrees;
@@ -1196,10 +1196,10 @@ void CBoostedTreeFactory::initializeTrainingProgressMonitoring(const core::CData
11961196
totalNumberSteps += MAX_LINE_SEARCH_ITERATIONS * lineSearchMaximumNumberTrees;
11971197
}
11981198
if (m_TreeImpl->m_EtaOverride == boost::none) {
1199-
totalNumberSteps += MAX_LINE_SEARCH_ITERATIONS * lineSearchMaximumNumberTrees *
1200-
computeMaximumNumberTrees(LINE_SEARCH_ETA_MARGIN * eta);
1199+
totalNumberSteps += MAX_LINE_SEARCH_ITERATIONS *
1200+
computeMaximumNumberTrees(MAIN_LOOP_ETA_SCALE_FOR_PROGRESS * eta);
12011201
}
1202-
totalNumberSteps += this->mainLoopNumberSteps(LINE_SEARCH_ETA_MARGIN * eta);
1202+
totalNumberSteps += this->mainLoopNumberSteps(MAIN_LOOP_ETA_SCALE_FOR_PROGRESS * eta);
12031203
LOG_TRACE(<< "total number steps = " << totalNumberSteps);
12041204
m_TreeImpl->m_TrainingProgress = core::CLoopProgress{
12051205
totalNumberSteps, m_TreeImpl->m_Instrumentation->progressCallback(), 1.0, 1024};

lib/maths/unittest/CBoostedTreeTest.cc

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,30 @@ using TMemoryMappedFloatVector = maths::boosted_tree::CLoss::TMemoryMappedFloatV
4747
namespace {
4848

4949
class CTestInstrumentation : public maths::CDataFrameAnalysisInstrumentationInterface {
50+
public:
51+
using TIntVec = std::vector<int>;
52+
5053
public:
5154
CTestInstrumentation()
5255
: m_TotalFractionalProgress{0}, m_MemoryUsage{0}, m_MaxMemoryUsage{0} {}
5356

54-
int progress() const { return m_TotalFractionalProgress.load(); }
57+
int progress() const {
58+
return (100 * m_TotalFractionalProgress.load()) / 65536;
59+
}
60+
TIntVec tenPercentProgressPoints() const {
61+
return m_TenPercentProgressPoints;
62+
}
5563
std::int64_t maxMemoryUsage() const { return m_MaxMemoryUsage.load(); }
5664

5765
void updateProgress(double fractionalProgress) override {
58-
m_TotalFractionalProgress.fetch_add(
59-
static_cast<int>(65536.0 * fractionalProgress + 0.5));
66+
int progress{m_TotalFractionalProgress.fetch_add(
67+
static_cast<int>(65536.0 * fractionalProgress + 0.5))};
68+
// Access to m_TenPercentProgressPoints needn't be protected because progress is only written from one thread and
69+
// the tests arrange that it is never read at the same time it is being written.
70+
if (m_TenPercentProgressPoints.empty() ||
71+
100 * progress > 65536 * (m_TenPercentProgressPoints.back() + 10)) {
72+
m_TenPercentProgressPoints.push_back(10 * ((10 * progress) / 65536));
73+
}
6074
}
6175

6276
void updateMemoryUsage(std::int64_t delta) override {
@@ -73,6 +87,7 @@ class CTestInstrumentation : public maths::CDataFrameAnalysisInstrumentationInte
7387

7488
private:
7589
std::atomic_int m_TotalFractionalProgress;
90+
TIntVec m_TenPercentProgressPoints;
7691
std::atomic<std::int64_t> m_MemoryUsage;
7792
std::atomic<std::int64_t> m_MaxMemoryUsage;
7893
};
@@ -1392,23 +1407,26 @@ BOOST_AUTO_TEST_CASE(testProgressMonitoring) {
13921407
finished.store(true);
13931408
}};
13941409

1395-
int lastTotalFractionalProgress{0};
13961410
int lastProgressReport{0};
13971411

13981412
bool monotonic{true};
1399-
std::size_t percentage{0};
1413+
int percentage{0};
14001414
while (finished.load() == false) {
1401-
if (instrumentation.progress() > lastProgressReport) {
1415+
if (instrumentation.progress() > percentage) {
14021416
LOG_DEBUG(<< percentage << "% complete");
14031417
percentage += 10;
1404-
lastProgressReport += 6554;
14051418
}
1406-
monotonic &= (instrumentation.progress() >= lastTotalFractionalProgress);
1407-
lastTotalFractionalProgress = instrumentation.progress();
1419+
monotonic &= (instrumentation.progress() >= lastProgressReport);
1420+
lastProgressReport = instrumentation.progress();
14081421
}
14091422
worker.join();
14101423

14111424
BOOST_TEST_REQUIRE(monotonic);
1425+
LOG_DEBUG(<< "progress points = "
1426+
<< core::CContainerPrinter::print(instrumentation.tenPercentProgressPoints()));
1427+
BOOST_REQUIRE_EQUAL("[0, 10, 20, 30, 40, 50, 60, 70, 80, 90]",
1428+
core::CContainerPrinter::print(
1429+
instrumentation.tenPercentProgressPoints()));
14121430

14131431
core::startDefaultAsyncExecutor();
14141432
}

0 commit comments

Comments
 (0)