Skip to content

Commit 14bbed7

Browse files
authored
[7.4][ML] Progress monitoring for regression model training (elastic#599)
Backport elastic#595.
1 parent 8c09604 commit 14bbed7

11 files changed

+155
-41
lines changed

include/maths/CBoostedTree.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,12 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
132132
~CBoostedTree() override;
133133

134134
//! Train on the examples in the data frame supplied to the constructor.
135-
void train(TProgressCallback recordProgress = noop) override;
135+
void train() override;
136136

137137
//! Write the predictions to the data frame supplied to the constructor.
138138
//!
139139
//! \warning This can only be called after train.
140-
void predict(TProgressCallback recordProgress = noop) const override;
140+
void predict() const override;
141141

142142
//! Write the trained model to \p writer.
143143
//!
@@ -160,7 +160,7 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
160160
using TImplUPtr = std::unique_ptr<CBoostedTreeImpl>;
161161

162162
private:
163-
CBoostedTree(core::CDataFrame& frame, TImplUPtr&& impl);
163+
CBoostedTree(core::CDataFrame& frame, TProgressCallback recordProgress, TImplUPtr&& impl);
164164

165165
private:
166166
TImplUPtr m_Impl;

include/maths/CBoostedTreeFactory.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class CBoostedTreeImpl;
3030
class MATHS_EXPORT CBoostedTreeFactory final {
3131
public:
3232
using TBoostedTreeUPtr = std::unique_ptr<CBoostedTree>;
33+
using TProgressCallback = CBoostedTree::TProgressCallback;
3334

3435
public:
3536
//! Construct a boosted tree object from parameters.
@@ -38,7 +39,8 @@ class MATHS_EXPORT CBoostedTreeFactory final {
3839

3940
//! Construct a boosted tree object from its serialized version.
4041
static TBoostedTreeUPtr constructFromString(std::stringstream& jsonStringStream,
41-
core::CDataFrame& frame);
42+
core::CDataFrame& frame,
43+
TProgressCallback recordProgress = noop);
4244

4345
~CBoostedTreeFactory();
4446
CBoostedTreeFactory(CBoostedTreeFactory&) = delete;
@@ -66,7 +68,7 @@ class MATHS_EXPORT CBoostedTreeFactory final {
6668
//! Set the number of training examples we need per feature we'll include.
6769
CBoostedTreeFactory& rowsPerFeature(std::size_t rowsPerFeature);
6870
//! Set the callback function for progress monitoring.
69-
CBoostedTreeFactory& progressCallback(CBoostedTree::TProgressCallback callback);
71+
CBoostedTreeFactory& progressCallback(TProgressCallback callback);
7072

7173
//! Estimate the maximum bookkeeping memory that training the boosted tree on a data
7274
//! frame with \p numberRows rows and \p numberColumns columns will use.
@@ -103,19 +105,20 @@ class MATHS_EXPORT CBoostedTreeFactory final {
103105
//! Read overrides for hyperparameters and if necessary estimate the initial
104106
//! values for \f$\lambda\f$ and \f$\gamma\f$ which match the gain from an
105107
//! overfit tree.
106-
void initializeHyperparameters(core::CDataFrame& frame,
107-
CBoostedTree::TProgressCallback recordProgress) const;
108+
void initializeHyperparameters(core::CDataFrame& frame) const;
108109

109110
//! Initialize the state for hyperparameter optimisation.
110111
void initializeHyperparameterOptimisation() const;
111112

112113
//! Get the number of hyperparameter tuning rounds to use.
113114
std::size_t numberHyperparameterTuningRounds() const;
114115

116+
static void noop(double);
117+
115118
private:
116119
double m_MinimumFrequencyToOneHotEncode;
117120
TBoostedTreeImplUPtr m_TreeImpl;
118-
CBoostedTree::TProgressCallback m_ProgressCallback;
121+
TProgressCallback m_RecordProgress = noop;
119122
};
120123
}
121124
}

include/maths/CDataFrameRegressionModel.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ class MATHS_EXPORT CDataFrameRegressionModel {
3333
CDataFrameRegressionModel& operator=(const CDataFrameRegressionModel&) = delete;
3434

3535
//! Train on the examples in the data frame supplied to the constructor.
36-
virtual void train(TProgressCallback recordProgress = noop) = 0;
36+
virtual void train() = 0;
3737

3838
//! Write the predictions to the data frame supplied to the constructor.
3939
//!
4040
//! \warning This can only be called after train.
41-
virtual void predict(TProgressCallback recordProgress = noop) const = 0;
41+
virtual void predict() const = 0;
4242

4343
//! Write this model to \p writer.
4444
//!
@@ -52,12 +52,13 @@ class MATHS_EXPORT CDataFrameRegressionModel {
5252
virtual std::size_t columnHoldingPrediction(std::size_t numberColumns) const = 0;
5353

5454
protected:
55-
CDataFrameRegressionModel(core::CDataFrame& frame) : m_Frame{frame} {}
56-
core::CDataFrame& frame() const { return m_Frame; }
57-
static void noop(double);
55+
CDataFrameRegressionModel(core::CDataFrame& frame, TProgressCallback recordProgress);
56+
core::CDataFrame& frame() const;
57+
const TProgressCallback& progressRecorder() const;
5858

5959
private:
6060
core::CDataFrame& m_Frame;
61+
TProgressCallback m_RecordProgress;
6162
};
6263
}
6364
}

lib/api/CDataFrameBoostedTreeRunner.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ CDataFrameBoostedTreeRunner::CDataFrameBoostedTreeRunner(const CDataFrameAnalysi
8686
maths::CBoostedTreeFactory::constructFromParameters(
8787
this->spec().numberThreads(), std::make_unique<maths::boosted_tree::CMse>()));
8888

89+
m_BoostedTreeFactory->progressCallback(this->progressRecorder());
8990
if (lambda >= 0.0) {
9091
m_BoostedTreeFactory->lambda(lambda);
9192
}
@@ -139,8 +140,8 @@ void CDataFrameBoostedTreeRunner::runImpl(const TStrVec& featureNames,
139140

140141
m_BoostedTree = m_BoostedTreeFactory->buildFor(
141142
frame, dependentVariableColumn - featureNames.begin());
142-
m_BoostedTree->train(this->progressRecorder());
143-
m_BoostedTree->predict(this->progressRecorder());
143+
m_BoostedTree->train();
144+
m_BoostedTree->predict();
144145
}
145146

146147
std::size_t CDataFrameBoostedTreeRunner::estimateBookkeepingMemoryUsage(

lib/api/unittest/CDataFrameAnalyzerTest.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ void addRegressionTestData(TStrVec fieldNames,
261261
std::unique_ptr<maths::CBoostedTree> tree =
262262
treeFactory.buildFor(*frame, weights.size());
263263

264-
tree->train(ml::maths::CBoostedTree::TProgressCallback());
264+
tree->train();
265265

266266
frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
267267
for (auto row = beginRows; row != endRows; ++row) {

lib/maths/CBoostedTree.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,18 @@ double CArgMinMse::value() const {
5757
}
5858
}
5959

60-
CBoostedTree::CBoostedTree(core::CDataFrame& frame, TImplUPtr&& impl)
61-
: CDataFrameRegressionModel{frame}, m_Impl{std::move(impl)} {
60+
CBoostedTree::CBoostedTree(core::CDataFrame& frame, TProgressCallback recordProgress, TImplUPtr&& impl)
61+
: CDataFrameRegressionModel{frame, std::move(recordProgress)}, m_Impl{std::move(impl)} {
6262
}
6363

6464
CBoostedTree::~CBoostedTree() = default;
6565

66-
void CBoostedTree::train(TProgressCallback recordProgress) {
67-
m_Impl->train(this->frame(), recordProgress);
66+
void CBoostedTree::train() {
67+
m_Impl->train(this->frame(), this->progressRecorder());
6868
}
6969

70-
void CBoostedTree::predict(TProgressCallback recordProgress) const {
71-
m_Impl->predict(this->frame(), recordProgress);
70+
void CBoostedTree::predict() const {
71+
m_Impl->predict(this->frame(), this->progressRecorder());
7272
}
7373

7474
void CBoostedTree::write(core::CRapidJsonConcurrentLineWriter& writer) const {

lib/maths/CBoostedTreeFactory.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ CBoostedTreeFactory::buildFor(core::CDataFrame& frame, std::size_t dependentVari
3737
this->selectFeaturesAndEncodeCategories(frame);
3838

3939
if (this->initializeFeatureSampleDistribution()) {
40-
this->initializeHyperparameters(frame, m_ProgressCallback);
40+
this->initializeHyperparameters(frame);
4141
this->initializeHyperparameterOptimisation();
4242
}
4343

4444
// TODO can only use factory to create one object since this is moved. This seems trappy.
45-
return TBoostedTreeUPtr{new CBoostedTree(frame, std::move(m_TreeImpl))};
45+
return TBoostedTreeUPtr{new CBoostedTree(frame, m_RecordProgress, std::move(m_TreeImpl))};
4646
}
4747

4848
std::size_t CBoostedTreeFactory::numberHyperparameterTuningRounds() const {
@@ -180,8 +180,7 @@ bool CBoostedTreeFactory::initializeFeatureSampleDistribution() const {
180180
return false;
181181
}
182182

183-
void CBoostedTreeFactory::initializeHyperparameters(core::CDataFrame& frame,
184-
CBoostedTree::TProgressCallback recordProgress) const {
183+
void CBoostedTreeFactory::initializeHyperparameters(core::CDataFrame& frame) const {
185184

186185
m_TreeImpl->m_Lambda = m_TreeImpl->m_LambdaOverride.value_or(0.0);
187186
m_TreeImpl->m_Gamma = m_TreeImpl->m_GammaOverride.value_or(0.0);
@@ -232,7 +231,7 @@ void CBoostedTreeFactory::initializeHyperparameters(core::CDataFrame& frame,
232231
LOG_TRACE(<< "loss = " << L[0] << ", # leaves = " << T[0]
233232
<< ", sum square weights = " << W[0]);
234233

235-
auto forest = m_TreeImpl->trainForest(frame, trainingRowMask, recordProgress);
234+
auto forest = m_TreeImpl->trainForest(frame, trainingRowMask, m_RecordProgress);
236235

237236
std::tie(L[1], T[1], W[1]) =
238237
m_TreeImpl->regularisedLoss(frame, trainingRowMask, forest);
@@ -272,10 +271,12 @@ CBoostedTreeFactory::constructFromParameters(std::size_t numberThreads,
272271

273272
CBoostedTreeFactory::TBoostedTreeUPtr
274273
CBoostedTreeFactory::constructFromString(std::stringstream& jsonStringStream,
275-
core::CDataFrame& frame) {
274+
core::CDataFrame& frame,
275+
TProgressCallback recordProgress) {
276276
try {
277277
TBoostedTreeUPtr treePtr{
278-
new CBoostedTree{frame, TBoostedTreeImplUPtr{new CBoostedTreeImpl{}}}};
278+
new CBoostedTree{frame, std::move(recordProgress),
279+
TBoostedTreeImplUPtr{new CBoostedTreeImpl{}}}};
279280
core::CJsonStateRestoreTraverser traverser(jsonStringStream);
280281
if (treePtr->acceptRestoreTraverser(traverser) == false || traverser.haveBadState()) {
281282
throw std::runtime_error{"failed to restore boosted tree"};
@@ -388,9 +389,8 @@ CBoostedTreeFactory& CBoostedTreeFactory::rowsPerFeature(std::size_t rowsPerFeat
388389
return *this;
389390
}
390391

391-
CBoostedTreeFactory&
392-
CBoostedTreeFactory::progressCallback(CBoostedTree::TProgressCallback callback) {
393-
m_ProgressCallback = callback;
392+
CBoostedTreeFactory& CBoostedTreeFactory::progressCallback(TProgressCallback callback) {
393+
m_RecordProgress = callback;
394394
return *this;
395395
}
396396

@@ -403,6 +403,9 @@ std::size_t CBoostedTreeFactory::numberExtraColumnsForTrain() const {
403403
return m_TreeImpl->numberExtraColumnsForTrain();
404404
}
405405

406+
void CBoostedTreeFactory::noop(double) {
407+
}
408+
406409
const double CBoostedTreeFactory::MINIMUM_ETA{1e-3};
407410
const std::size_t CBoostedTreeFactory::MAXIMUM_NUMBER_TREES{
408411
static_cast<std::size_t>(2.0 / MINIMUM_ETA + 0.5)};

lib/maths/CBoostedTreeImpl.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <maths/CBoostedTreeImpl.h>
88

9+
#include <core/CLoopProgress.h>
910
#include <core/CPersistUtils.h>
1011

1112
#include <maths/CBayesianOptimisation.h>
@@ -88,6 +89,12 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
8889

8990
LOG_TRACE(<< "Main training loop...");
9091

92+
// We account for cost of setup as one round. The main optimisation loop runs
93+
// for "m_NumberRounds + 1" rounds and training on the chosen hyperparameter
94+
// values is counted as one round. This gives a total of m_NumberRounds + 3.
95+
core::CLoopProgress progress{m_NumberRounds + 3, recordProgress};
96+
progress.increment();
97+
9198
if (this->canTrain() == false) {
9299
// Fallback to using the constant predictor which minimises the loss.
93100

@@ -121,13 +128,20 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
121128
if (this->selectNextHyperparameters(lossMoments, *m_BayesianOptimization) == false) {
122129
break;
123130
}
131+
132+
progress.increment();
133+
124134
} while (m_CurrentRound++ < m_NumberRounds);
125135

126136
LOG_TRACE(<< "Test loss = " << m_BestForestTestLoss);
127137

128138
this->restoreBestHyperparameters();
129139
m_BestForest = this->trainForest(frame, this->allTrainingRowsMask(), recordProgress);
130140
}
141+
142+
// Force progress to at least one here because we can have an early exit from the loop or take
143+
// a different path.
144+
recordProgress(1.0);
131145
}
132146

133147
void CBoostedTreeImpl::predict(core::CDataFrame& frame,

lib/maths/CDataFrameRegressionModel.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,18 @@
99
namespace ml {
1010
namespace maths {
1111

12-
void CDataFrameRegressionModel::noop(double) {
12+
CDataFrameRegressionModel::CDataFrameRegressionModel(core::CDataFrame& frame,
13+
TProgressCallback recordProgress)
14+
: m_Frame{frame}, m_RecordProgress{recordProgress} {
15+
}
16+
17+
core::CDataFrame& CDataFrameRegressionModel::frame() const {
18+
return m_Frame;
19+
}
20+
21+
const CDataFrameRegressionModel::TProgressCallback&
22+
CDataFrameRegressionModel::progressRecorder() const {
23+
return m_RecordProgress;
1324
}
1425
}
1526
}

0 commit comments

Comments
 (0)