Skip to content

Commit ff3deb0

Browse files
valeriy42tveasey
andauthored
[7.13][ML] Output max num trees in hyperparameters metadata (#1877)
* [ML] Output max num trees in hyperparameters metadata (#1867) We output max number trees as a hyperparameter in the model metadata and add a unit test to ensure that we achieve reproducible results when retraining a model with all hyperparameters specified. Fixes #1853 . * [ML] Apply tree depth constraint to hyperparameter search bounding box setup (#1870) Following on from #1867, we can and should be imposing the minimum depth constraint to the hyperparameter search bounding box. This was incorrectly applied before and also fixes the issue with reproducibility based on user overrides. This is a bit cleaner than applying the constraint magically in the code to adjust hyperparameters. Co-authored-by: Tom Veasey <[email protected]>
1 parent a7c400c commit ff3deb0

14 files changed

+235
-60
lines changed

.dockerignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,17 @@ nbactions.xml
2020
#vscode files
2121
.vscode
2222
.clangd
23+
.cache
2324

2425
# gradle stuff
2526
.gradle/
2627
build/
2728
generated-resources/
2829

30+
# python environment stuff
31+
**/env/*
32+
*.pyc
33+
2934
# testing stuff
3035
**/.local*
3136
.vagrant/

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,17 @@ nbactions.xml
2020
#vscode files
2121
.vscode
2222
.clangd
23+
.cache
2324

2425
# gradle stuff
2526
.gradle/
2627
build/
2728
generated-resources/
2829

30+
# python environment stuff
31+
**/env/*
32+
*.pyc
33+
2934
# testing stuff
3035
**/.local*
3136
.vagrant/

docs/CHANGELOG.asciidoc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
* Fail gracefully on encountering unexpected state in restore from snapshot for anomaly
4747
detection. (See {ml-pull}1872[#1872].)
4848

49+
== {es} version 7.12.2
50+
51+
=== Bug Fixes
52+
53+
* Add missing hyperparamter to the model metadata. (See {ml-pull}1867[#1867].)
54+
4955
== {es} version 7.12.1
5056

5157
=== Enhancements

include/api/CInferenceModelMetadata.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class API_EXPORT CInferenceModelMetadata {
5656
void columnNames(const TStrVec& columnNames);
5757
void classValues(const TStrVec& classValues);
5858
void predictionFieldTypeResolverWriter(const TPredictionFieldTypeResolverWriter& resolverWriter);
59-
const std::string& typeString() const;
59+
static const std::string& typeString();
6060
//! Add importances \p values to the feature with index \p i to calculate total feature importance.
6161
//! Total feature importance is the mean of the magnitudes of importances for individual data points.
6262
void addToFeatureImportance(std::size_t i, const TVector& values);
@@ -67,19 +67,13 @@ class API_EXPORT CInferenceModelMetadata {
6767

6868
private:
6969
struct SHyperparameterImportance {
70-
SHyperparameterImportance(std::string hyperparameterName,
71-
double value,
72-
double absoluteImportance,
73-
double relativeImportance,
74-
bool supplied)
75-
: s_HyperparameterName(hyperparameterName), s_Value(value),
76-
s_AbsoluteImportance(absoluteImportance),
77-
s_RelativeImportance(relativeImportance), s_Supplied(supplied) {}
70+
enum EType { E_Double, E_Uint64 };
7871
std::string s_HyperparameterName;
7972
double s_Value;
8073
double s_AbsoluteImportance;
8174
double s_RelativeImportance;
8275
bool s_Supplied;
76+
EType s_Type;
8377
};
8478

8579
using TMeanAccumulator =

include/maths/CBoostedTreeUtils.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,20 @@ enum EHyperparameters {
4040
E_SoftTreeDepthTolerance,
4141
E_Eta,
4242
E_EtaGrowthRatePerTree,
43+
E_MaximumNumberTrees,
4344
E_FeatureBagFraction
4445
};
4546

4647
constexpr std::size_t NUMBER_HYPERPARAMETERS = E_FeatureBagFraction + 1; // This must be last hyperparameter
4748

4849
struct SHyperparameterImportance {
49-
SHyperparameterImportance(EHyperparameters hyperparameter,
50-
double value,
51-
double absoluteImportance,
52-
double relativeImportance,
53-
bool supplied)
54-
: s_Hyperparameter(hyperparameter), s_Value(value),
55-
s_AbsoluteImportance(absoluteImportance),
56-
s_RelativeImportance(relativeImportance), s_Supplied(supplied) {}
50+
enum EType { E_Double = 0, E_Uint64 };
5751
EHyperparameters s_Hyperparameter;
5852
double s_Value;
5953
double s_AbsoluteImportance;
6054
double s_RelativeImportance;
6155
bool s_Supplied;
56+
EType s_Type;
6257
};
6358

6459
//! Get the size of upper triangle of the loss Hessain.

include/test/CDataFrameAnalysisSpecificationFactory.h

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
7676
CDataFrameAnalysisSpecificationFactory& predictionSoftTreeDepthLimit(double limit);
7777
CDataFrameAnalysisSpecificationFactory& predictionSoftTreeDepthTolerance(double tolerance);
7878
CDataFrameAnalysisSpecificationFactory& predictionEta(double eta);
79+
CDataFrameAnalysisSpecificationFactory&
80+
predictionEtaGrowthRatePerTree(double etaGrowthRatePerTree);
7981
CDataFrameAnalysisSpecificationFactory& predictionMaximumNumberTrees(std::size_t number);
8082
CDataFrameAnalysisSpecificationFactory& predictionDownsampleFactor(double downsampleFactor);
8183
CDataFrameAnalysisSpecificationFactory& predictionFeatureBagFraction(double fraction);
@@ -119,37 +121,38 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
119121
TOptionalSize m_Columns;
120122
TOptionalSize m_MemoryLimit;
121123
std::string m_MissingString;
122-
bool m_DiskUsageAllowed = true;
124+
bool m_DiskUsageAllowed{true};
123125
// Outliers
124126
std::string m_Method;
125-
std::size_t m_NumberNeighbours = 0;
126-
bool m_ComputeFeatureInfluence = false;
127+
std::size_t m_NumberNeighbours{0};
128+
bool m_ComputeFeatureInfluence{false};
127129
// Prediction
128-
std::size_t m_NumberRoundsPerHyperparameter = 0;
129-
std::size_t m_BayesianOptimisationRestarts = 0;
130+
std::size_t m_NumberRoundsPerHyperparameter{0};
131+
std::size_t m_BayesianOptimisationRestarts{0};
130132
TStrVec m_CategoricalFieldNames;
131133
std::string m_PredictionFieldName;
132-
double m_Alpha = -1.0;
133-
double m_Lambda = -1.0;
134-
double m_Gamma = -1.0;
135-
double m_SoftTreeDepthLimit = -1.0;
136-
double m_SoftTreeDepthTolerance = -1.0;
137-
double m_Eta = -1.0;
138-
std::size_t m_MaximumNumberTrees = 0;
139-
double m_DownsampleFactor = 0.0;
140-
double m_FeatureBagFraction = -1.0;
141-
std::size_t m_NumberTopShapValues = 0;
142-
TPersisterSupplier* m_PersisterSupplier = nullptr;
143-
TRestoreSearcherSupplier* m_RestoreSearcherSupplier = nullptr;
134+
double m_Alpha{-1.0};
135+
double m_Lambda{-1.0};
136+
double m_Gamma{-1.0};
137+
double m_SoftTreeDepthLimit{-1.0};
138+
double m_SoftTreeDepthTolerance{-1.0};
139+
double m_Eta{-1.0};
140+
double m_EtaGrowthRatePerTree{-1.0};
141+
std::size_t m_MaximumNumberTrees{0};
142+
double m_DownsampleFactor{0.0};
143+
double m_FeatureBagFraction{-1.0};
144+
std::size_t m_NumberTopShapValues{0};
145+
TPersisterSupplier* m_PersisterSupplier{nullptr};
146+
TRestoreSearcherSupplier* m_RestoreSearcherSupplier{nullptr};
144147
rapidjson::Document m_CustomProcessors;
145148
// Regression
146149
TOptionalLossFunctionType m_RegressionLossFunction;
147150
TOptionalDouble m_RegressionLossFunctionParameter;
148151
// Classification
149-
std::size_t m_NumberClasses = 2;
150-
std::size_t m_NumberTopClasses = 0;
152+
std::size_t m_NumberClasses{2};
153+
std::size_t m_NumberTopClasses{0};
151154
std::string m_PredictionFieldType;
152-
bool m_EarlyStoppingEnabled = true;
155+
bool m_EarlyStoppingEnabled{true};
153156
TStrDoublePrVec m_ClassificationWeights;
154157
};
155158
}

include/test/CDataFrameAnalyzerTrainingFactory.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include <test/CRandomNumbers.h>
2222
#include <test/ImportExport.h>
2323

24+
#include <boost/optional/optional_fwd.hpp>
25+
2426
#include <string>
2527
#include <vector>
2628

@@ -35,15 +37,20 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
3537
using TLossUPtr = std::unique_ptr<maths::boosted_tree::CLoss>;
3638
using TTargetTransformer = std::function<double(double)>;
3739
using TLossFunctionType = maths::boosted_tree::ELossType;
40+
using TSizeOptional = boost::optional<std::size_t>;
3841

3942
public:
4043
static void addPredictionTestData(TLossFunctionType type,
4144
const TStrVec& fieldNames,
4245
TStrVec fieldValues,
4346
api::CDataFrameAnalyzer& analyzer,
44-
std::size_t numberExamples = 100) {
47+
std::size_t numberExamples = 100,
48+
TSizeOptional seed = {}) {
4549

4650
test::CRandomNumbers rng;
51+
if (seed) {
52+
rng.seed(seed.get());
53+
}
4754

4855
TDoubleVec weights;
4956
rng.generateUniformSamples(-1.0, 1.0, fieldNames.size() - 3, weights);
@@ -86,9 +93,13 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
8693
double eta = 0.0,
8794
std::size_t maximumNumberTrees = 0,
8895
double featureBagFraction = 0.0,
89-
double lossFunctionParameter = 1.0) {
96+
double lossFunctionParameter = 1.0,
97+
TSizeOptional seed = {}) {
9098

9199
test::CRandomNumbers rng;
100+
if (seed) {
101+
rng.seed(seed.get());
102+
}
92103

93104
TDoubleVec weights;
94105
rng.generateUniformSamples(-1.0, 1.0, fieldNames.size() - 3, weights);

lib/api/CInferenceModelMetadata.cc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <maths/CBoostedTreeUtils.h>
1111

1212
#include <cmath>
13+
#include <cstdint>
1314

1415
namespace ml {
1516
namespace api {
@@ -142,15 +143,21 @@ void CInferenceModelMetadata::writeFeatureImportanceBaseline(TRapidJsonWriter& w
142143
}
143144

144145
void CInferenceModelMetadata::writeHyperparameterImportance(TRapidJsonWriter& writer) const {
145-
// TODO use struct instead of a tuple
146146
writer.Key(JSON_HYPERPARAMETERS_TAG);
147147
writer.StartArray();
148148
for (const auto& item : m_HyperparameterImportance) {
149149
writer.StartObject();
150150
writer.Key(JSON_HYPERPARAMETER_NAME_TAG);
151151
writer.String(item.s_HyperparameterName);
152152
writer.Key(JSON_HYPERPARAMETER_VALUE_TAG);
153-
writer.Double(item.s_Value);
153+
switch (item.s_Type) {
154+
case SHyperparameterImportance::E_Double:
155+
writer.Double(item.s_Value);
156+
break;
157+
case SHyperparameterImportance::E_Uint64:
158+
writer.Uint64(static_cast<std::uint64_t>(item.s_Value));
159+
break;
160+
}
154161
if (item.s_Supplied == false) {
155162
writer.Key(JSON_ABSOLUTE_IMPORTANCE_TAG);
156163
writer.Double(item.s_AbsoluteImportance);
@@ -164,7 +171,7 @@ void CInferenceModelMetadata::writeHyperparameterImportance(TRapidJsonWriter& wr
164171
writer.EndArray();
165172
}
166173

167-
const std::string& CInferenceModelMetadata::typeString() const {
174+
const std::string& CInferenceModelMetadata::typeString() {
168175
return JSON_MODEL_METADATA_TAG;
169176
}
170177

@@ -233,15 +240,19 @@ void CInferenceModelMetadata::hyperparameterImportance(
233240
case maths::boosted_tree_detail::E_SoftTreeDepthTolerance:
234241
hyperparameterName = CDataFrameTrainBoostedTreeRunner::SOFT_TREE_DEPTH_TOLERANCE;
235242
break;
243+
case maths::boosted_tree_detail::E_MaximumNumberTrees:
244+
hyperparameterName = CDataFrameTrainBoostedTreeRunner::MAX_TREES;
245+
break;
236246
}
237247
double absoluteImportance{(std::fabs(item.s_AbsoluteImportance) < 1e-8)
238248
? 0.0
239249
: item.s_AbsoluteImportance};
240250
double relativeImportance{(std::fabs(item.s_RelativeImportance) < 1e-8)
241251
? 0.0
242252
: item.s_RelativeImportance};
243-
m_HyperparameterImportance.emplace_back(hyperparameterName, item.s_Value, absoluteImportance,
244-
relativeImportance, item.s_Supplied);
253+
m_HyperparameterImportance.push_back(
254+
{hyperparameterName, item.s_Value, absoluteImportance, relativeImportance,
255+
item.s_Supplied, static_cast<SHyperparameterImportance::EType>(item.s_Type)});
245256
}
246257
std::sort(m_HyperparameterImportance.begin(),
247258
m_HyperparameterImportance.end(), [](const auto& a, const auto& b) {

0 commit comments

Comments
 (0)