Skip to content

Commit 1becee7

Browse files
authored
[7.5][ML] Improve boosted tree training initialisation (#697)
Backport #686.
1 parent 87ed67f commit 1becee7

16 files changed

+850
-394
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
For large data sets this change was observed to give a 10% to 20% decrease in
3737
train time. (See {ml-pull}622[#622].)
3838
* Upgrade Boost libraries to version 1.71. (See {ml-pull}638[#638].)
39+
* Improve initialisation of boosted tree training. This generally enables us to
40+
find lower-loss models faster. (See {ml-pull}686[#686].)
3941
4042
== {es} version 7.4.1
4143

include/api/CDataFrameBoostedTreeRunner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
4747
private:
4848
using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
4949
using TBoostedTreeFactoryUPtr = std::unique_ptr<maths::CBoostedTreeFactory>;
50+
using TDataSearcherUPtr = CDataFrameAnalysisSpecification::TDataSearcherUPtr;
5051
using TMemoryEstimator = std::function<void(std::int64_t)>;
5152

5253
private:
@@ -58,7 +59,8 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
5859
TMemoryEstimator memoryEstimator();
5960

6061
bool restoreBoostedTree(core::CDataFrame& frame,
61-
CDataFrameAnalysisSpecification::TDataSearcherUPtr& restoreSearcher);
62+
std::size_t dependentVariableColumn,
63+
TDataSearcherUPtr& restoreSearcher);
6264

6365
private:
6466
// Note custom config is written directly to the factory object.

include/core/CLoopProgress.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
namespace ml {
1616
namespace core {
17+
class CStatePersistInserter;
18+
class CStateRestoreTraverser;
1719

1820
//! \brief Manages recording the progress of a loop.
1921
//!
@@ -46,14 +48,35 @@ class CORE_EXPORT CLoopProgress {
4648
using TProgressCallback = std::function<void(double)>;
4749

4850
public:
51+
CLoopProgress();
4952
template<typename ITR>
50-
CLoopProgress(ITR begin, ITR end, const TProgressCallback& recordProgress, double scale = 1.0)
53+
CLoopProgress(ITR begin, ITR end, const TProgressCallback& recordProgress = noop, double scale = 1.0)
5154
: CLoopProgress(std::distance(begin, end), recordProgress, scale) {}
52-
CLoopProgress(std::size_t size, const TProgressCallback& recordProgress, double scale = 1.0);
55+
CLoopProgress(std::size_t size,
56+
const TProgressCallback& recordProgress = noop,
57+
double scale = 1.0);
58+
59+
//! Attach a new progress monitor callback.
60+
void progressCallback(const TProgressCallback& recordProgress);
5361

5462
//! Increment the progress by \p i.
5563
void increment(std::size_t i = 1);
5664

65+
//! Resume progress monitoring which was restored.
66+
void resumeRestored();
67+
68+
//! Get a checksum for this object.
69+
std::uint64_t checksum() const;
70+
71+
//! Persist by passing information to \p inserter.
72+
void acceptPersistInserter(CStatePersistInserter& inserter) const;
73+
74+
//! Populate the object from serialized data.
75+
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser);
76+
77+
private:
78+
static void noop(double);
79+
5780
private:
5881
std::size_t m_Size;
5982
std::size_t m_Steps;

include/maths/CBoostedTreeFactory.h

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <core/CDataFrame.h>
1111

1212
#include <maths/CBoostedTree.h>
13+
#include <maths/CLinearAlgebra.h>
1314
#include <maths/ImportExport.h>
1415

1516
#include <boost/optional.hpp>
@@ -43,12 +44,9 @@ class MATHS_EXPORT CBoostedTreeFactory final {
4344
TLossFunctionUPtr loss);
4445

4546
//! Construct a boosted tree object from its serialized version.
46-
static TBoostedTreeUPtr
47-
constructFromString(std::istream& jsonStringStream,
48-
core::CDataFrame& frame,
49-
TProgressCallback recordProgress = noopRecordProgress,
50-
TMemoryUsageCallback recordMemoryUsage = noopRecordMemoryUsage,
51-
TTrainingStateCallback recordTrainingState = noopRecordTrainingState);
47+
//!
48+
//! \warning Throws a runtime error on failure to restore.
49+
static CBoostedTreeFactory constructFromString(std::istream& jsonStringStream);
5250

5351
~CBoostedTreeFactory();
5452
CBoostedTreeFactory(CBoostedTreeFactory&) = delete;
@@ -93,17 +91,22 @@ class MATHS_EXPORT CBoostedTreeFactory final {
9391
TBoostedTreeUPtr buildFor(core::CDataFrame& frame, std::size_t dependentVariable);
9492

9593
private:
94+
using TDoubleDoublePr = std::pair<double, double>;
9695
using TOptionalDouble = boost::optional<double>;
9796
using TOptionalSize = boost::optional<std::size_t>;
97+
using TVector = CVectorNx1<double, 3>;
98+
using TOptionalVector = boost::optional<TVector>;
9899
using TPackedBitVectorVec = std::vector<core::CPackedBitVector>;
99100
using TBoostedTreeImplUPtr = std::unique_ptr<CBoostedTreeImpl>;
101+
using TApplyRegularizerStep =
102+
std::function<void(CBoostedTreeImpl&, double, std::size_t)>;
100103

101104
private:
102105
static const double MINIMUM_ETA;
103106
static const std::size_t MAXIMUM_NUMBER_TREES;
104107

105108
private:
106-
CBoostedTreeFactory(std::size_t numberThreads, TLossFunctionUPtr loss);
109+
CBoostedTreeFactory(bool restored, std::size_t numberThreads, TLossFunctionUPtr loss);
107110

108111
//! Compute the row masks for the missing values for each feature.
109112
void initializeMissingFeatureMasks(const core::CDataFrame& frame) const;
@@ -121,25 +124,53 @@ class MATHS_EXPORT CBoostedTreeFactory final {
121124
//! Initialize the regressors sample distribution.
122125
bool initializeFeatureSampleDistribution() const;
123126

124-
//! Read overrides for hyperparameters and if necessary estimate the initial
125-
//! values for \f$\lambda\f$ and \f$\gamma\f$ which match the gain from an
126-
//! overfit tree.
127-
void initializeHyperparameters(core::CDataFrame& frame) const;
127+
//! Set the initial values for the various hyperparameters.
128+
void initializeHyperparameters(core::CDataFrame& frame);
129+
130+
//! Estimate a good central value for the regularisation hyperparameters
131+
//! search bounding box.
132+
void initializeUnsetRegularizationHyperparameters(core::CDataFrame& frame);
133+
134+
//! Estimate the reduction in gain from a split and the total curvature of
135+
//! the loss function at a split.
136+
TDoubleDoublePr estimateTreeGainAndCurvature(core::CDataFrame& frame,
137+
const core::CPackedBitVector& trainingRowMask) const;
138+
139+
//! Perform a line search for the test loss w.r.t. a single regularization
140+
//! hyperparameter and apply Newton's method to find the minimum. The plan
141+
//! is to find a value near where the model starts to overfit.
142+
//!
143+
//! \return The interval to search during the main hyperparameter optimisation
144+
//! loop or null if this couldn't be found.
145+
TOptionalVector testLossNewtonLineSearch(core::CDataFrame& frame,
146+
core::CPackedBitVector trainingRowMask,
147+
const TApplyRegularizerStep& applyRegularizerStep,
148+
double returnedIntervalLeftEndOffset,
149+
double returnedIntervalRightEndOffset) const;
128150

129151
//! Initialize the state for hyperparameter optimisation.
130152
void initializeHyperparameterOptimisation() const;
131153

132154
//! Get the number of hyperparameter tuning rounds to use.
133155
std::size_t numberHyperparameterTuningRounds() const;
134156

157+
//! Setup monitoring for training progress.
158+
void initializeTrainingProgressMonitoring();
159+
160+
//! Refresh progress monitoring after restoring from saved training state.
161+
void resumeRestoredTrainingProgressMonitoring();
162+
135163
static void noopRecordProgress(double);
136164
static void noopRecordMemoryUsage(std::int64_t);
137165
static void noopRecordTrainingState(CDataFrameRegressionModel::TPersistFunc);
138166

139167
private:
140168
TOptionalDouble m_MinimumFrequencyToOneHotEncode;
141169
TOptionalSize m_BayesianOptimisationRestarts;
170+
bool m_Restored = false;
142171
TBoostedTreeImplUPtr m_TreeImpl;
172+
TVector m_LogGammaSearchInterval;
173+
TVector m_LogLambdaSearchInterval;
143174
TProgressCallback m_RecordProgress = noopRecordProgress;
144175
TMemoryUsageCallback m_RecordMemoryUsage = noopRecordMemoryUsage;
145176
TTrainingStateCallback m_RecordTrainingState = noopRecordTrainingState;

0 commit comments

Comments
 (0)