10
10
#include < core/CDataFrame.h>
11
11
12
12
#include < maths/CBoostedTree.h>
13
+ #include < maths/CLinearAlgebra.h>
13
14
#include < maths/ImportExport.h>
14
15
15
16
#include < boost/optional.hpp>
@@ -43,12 +44,9 @@ class MATHS_EXPORT CBoostedTreeFactory final {
43
44
TLossFunctionUPtr loss);
44
45
45
46
// ! Construct a boosted tree object from its serialized version.
46
- static TBoostedTreeUPtr
47
- constructFromString (std::istream& jsonStringStream,
48
- core::CDataFrame& frame,
49
- TProgressCallback recordProgress = noopRecordProgress,
50
- TMemoryUsageCallback recordMemoryUsage = noopRecordMemoryUsage,
51
- TTrainingStateCallback recordTrainingState = noopRecordTrainingState);
47
+ // !
48
+ // ! \warning Throws a runtime error if the restore fails.
49
+ static CBoostedTreeFactory constructFromString (std::istream& jsonStringStream);
52
50
53
51
~CBoostedTreeFactory ();
54
52
CBoostedTreeFactory (CBoostedTreeFactory&) = delete ;
@@ -93,17 +91,22 @@ class MATHS_EXPORT CBoostedTreeFactory final {
93
91
TBoostedTreeUPtr buildFor (core::CDataFrame& frame, std::size_t dependentVariable);
94
92
95
93
private:
94
+ using TDoubleDoublePr = std::pair<double , double >;
96
95
using TOptionalDouble = boost::optional<double >;
97
96
using TOptionalSize = boost::optional<std::size_t >;
97
+ using TVector = CVectorNx1<double , 3 >;
98
+ using TOptionalVector = boost::optional<TVector>;
98
99
using TPackedBitVectorVec = std::vector<core::CPackedBitVector>;
99
100
using TBoostedTreeImplUPtr = std::unique_ptr<CBoostedTreeImpl>;
101
+ using TApplyRegularizerStep =
102
+ std::function<void (CBoostedTreeImpl&, double , std::size_t )>;
100
103
101
104
private:
102
105
static const double MINIMUM_ETA;
103
106
static const std::size_t MAXIMUM_NUMBER_TREES;
104
107
105
108
private:
106
- CBoostedTreeFactory (std::size_t numberThreads, TLossFunctionUPtr loss);
109
+ CBoostedTreeFactory (bool restored, std::size_t numberThreads, TLossFunctionUPtr loss);
107
110
108
111
// ! Compute the row masks for the missing values for each feature.
109
112
void initializeMissingFeatureMasks (const core::CDataFrame& frame) const ;
@@ -121,25 +124,53 @@ class MATHS_EXPORT CBoostedTreeFactory final {
121
124
// ! Initialize the regressors sample distribution.
122
125
bool initializeFeatureSampleDistribution () const ;
123
126
124
- // ! Read overrides for hyperparameters and if necessary estimate the initial
125
- // ! values for \f$\lambda\f$ and \f$\gamma\f$ which match the gain from an
126
- // ! overfit tree.
127
- void initializeHyperparameters (core::CDataFrame& frame) const ;
127
+ // ! Set the initial values for the various hyperparameters.
128
+ void initializeHyperparameters (core::CDataFrame& frame);
129
+
130
+ // ! Estimate a good central value for the search bounding box of the
131
+ // ! regularisation hyperparameters.
132
+ void initializeUnsetRegularizationHyperparameters (core::CDataFrame& frame);
133
+
134
+ // ! Estimate the reduction in gain from a split and the total curvature of
135
+ // ! the loss function at a split.
136
+ TDoubleDoublePr estimateTreeGainAndCurvature (core::CDataFrame& frame,
137
+ const core::CPackedBitVector& trainingRowMask) const ;
138
+
139
+ // ! Perform a line search for the test loss w.r.t. a single regularization
140
+ // ! hyperparameter and apply Newton's method to find the minimum. The plan
141
+ // ! is to find a value near where the model starts to overfit.
142
+ // !
143
+ // ! \return The interval to search during the main hyperparameter optimisation
144
+ // ! loop or null if this couldn't be found.
145
+ TOptionalVector testLossNewtonLineSearch (core::CDataFrame& frame,
146
+ core::CPackedBitVector trainingRowMask,
147
+ const TApplyRegularizerStep& applyRegularizerStep,
148
+ double returnedIntervalLeftEndOffset,
149
+ double returnedIntervalRightEndOffset) const ;
128
150
129
151
// ! Initialize the state for hyperparameter optimisation.
130
152
void initializeHyperparameterOptimisation () const ;
131
153
132
154
// ! Get the number of hyperparameter tuning rounds to use.
133
155
std::size_t numberHyperparameterTuningRounds () const ;
134
156
157
+ // ! Set up monitoring for training progress.
158
+ void initializeTrainingProgressMonitoring ();
159
+
160
+ // ! Refresh progress monitoring after restoring from saved training state.
161
+ void resumeRestoredTrainingProgressMonitoring ();
162
+
135
163
static void noopRecordProgress (double );
136
164
static void noopRecordMemoryUsage (std::int64_t );
137
165
static void noopRecordTrainingState (CDataFrameRegressionModel::TPersistFunc);
138
166
139
167
private:
140
168
TOptionalDouble m_MinimumFrequencyToOneHotEncode;
141
169
TOptionalSize m_BayesianOptimisationRestarts;
170
+ bool m_Restored = false ;
142
171
TBoostedTreeImplUPtr m_TreeImpl;
172
+ TVector m_LogGammaSearchInterval;
173
+ TVector m_LogLambdaSearchInterval;
143
174
TProgressCallback m_RecordProgress = noopRecordProgress;
144
175
TMemoryUsageCallback m_RecordMemoryUsage = noopRecordMemoryUsage;
145
176
TTrainingStateCallback m_RecordTrainingState = noopRecordTrainingState;
0 commit comments