Skip to content

Commit 3c20a45

Browse files
authored
[7.x][ML] Implement MSLE loss function for regression
This PR implements the mean squared logarithmic error (MSLE) loss function for regression. It also adds a CLoss::isRegression() function to distinguish between regression and classification loss functions.
1 parent c7652eb commit 3c20a45

12 files changed

+689
-56
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
* Improve robustness of anomaly detection to bad input data. (See {ml-pull}1114[#1114].)
4343
* Adds new `num_matches` and `preferred_to_categories` fields to category output.
4444
(See {ml-pull}1062[#1062])
45+
* Adds mean squared logarithmic error (MSLE) for regression. (See {ml-pull}1101[#1101].)
4546
* Improve robustness of anomaly detection to bad input data. (See {ml-pull}1114[#1114].)
4647
* Switched data frame analytics model memory estimates from kilobytes to megabytes.
4748
(See {ml-pull}1126[#1126], issue: {issue}54506[#54506].)
@@ -55,6 +56,7 @@
5556
* Fixed background persistence of categorizer state (See {ml-pull}1137[#1137],
5657
issue: {ml-issue}1136[#1136].)
5758

59+
5860
== {es} version 7.7.0
5961

6062
=== New Features

include/api/CDataFrameTrainBoostedTreeRegressionRunner.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,22 @@ namespace api {
1818
//! \brief Runs boosted tree regression on a core::CDataFrame.
1919
class API_EXPORT CDataFrameTrainBoostedTreeRegressionRunner final
2020
: public CDataFrameTrainBoostedTreeRunner {
21+
22+
public:
23+
using TLossFunctionUPtr = std::unique_ptr<maths::boosted_tree::CLoss>;
24+
enum ELossFunctionType { E_Mse, E_Msle };
25+
2126
public:
2227
static const std::string STRATIFIED_CROSS_VALIDATION;
28+
static const std::string LOSS_FUNCTION;
29+
static const std::string MSE;
30+
static const std::string MSLE;
2331

2432
public:
2533
static const CDataFrameAnalysisConfigReader& parameterReader();
2634

35+
static TLossFunctionUPtr lossFunction(const CDataFrameAnalysisParameters& parameters);
36+
2737
//! This is not intended to be called directly: use CDataFrameTrainBoostedTreeRegressionRunnerFactory.
2838
CDataFrameTrainBoostedTreeRegressionRunner(const CDataFrameAnalysisSpecification& spec,
2939
const CDataFrameAnalysisParameters& parameters);

include/maths/CBoostedTreeLoss.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <functional>
2020
#include <memory>
2121
#include <string>
22+
#include <utility>
2223
#include <vector>
2324

2425
namespace ml {
@@ -66,6 +67,66 @@ class MATHS_EXPORT CArgMinMseImpl final : public CArgMinLossImpl {
6667
TMeanAccumulator m_MeanError;
6768
};
6869

70+
//! \brief Finds the value to add to a set of predictions which approximately
71+
//! minimises the regularised mean squared logarithmic error (MSLE).
72+
class MATHS_EXPORT CArgMinMsleImpl final : public CArgMinLossImpl {
73+
public:
74+
using TObjective = std::function<double(double)>;
75+
76+
public:
77+
CArgMinMsleImpl(double lambda);
78+
std::unique_ptr<CArgMinLossImpl> clone() const override;
79+
bool nextPass() override;
80+
void add(const TMemoryMappedFloatVector& prediction, double actual, double weight = 1.0) override;
81+
void merge(const CArgMinLossImpl& other) override;
82+
TDoubleVector value() const override;
83+
84+
// Exposed for unit testing.
85+
TObjective objective() const;
86+
87+
private:
88+
using TMinMaxAccumulator = CBasicStatistics::CMinMax<double>;
89+
using TMeanAccumulator = CBasicStatistics::SSampleMean<double>::TAccumulator;
90+
using TMeanVarAccumulator = CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
91+
using TVector = CVectorNx1<double, 3>;
92+
using TVectorMeanAccumulator = CBasicStatistics::SSampleMean<TVector>::TAccumulator;
93+
using TVectorMeanAccumulatorVec = std::vector<TVectorMeanAccumulator>;
94+
using TVectorMeanAccumulatorVecVec = std::vector<TVectorMeanAccumulatorVec>;
95+
using TDoubleDoublePr = std::pair<double, double>;
96+
using TSizeSizePr = std::pair<std::size_t, std::size_t>;
97+
98+
private:
99+
TSizeSizePr bucket(double prediction, double actual) const {
100+
auto bucketWidth{this->bucketWidth()};
101+
double bucketPrediction{(prediction - m_ExpPredictionMinMax.min()) /
102+
bucketWidth.first};
103+
std::size_t predictionBucketIndex{std::min(
104+
static_cast<std::size_t>(bucketPrediction), m_Buckets.size() - 1)};
105+
106+
double bucketActual{(actual - m_LogActualMinMax.min()) / bucketWidth.second};
107+
std::size_t actualBucketIndex{std::min(
108+
static_cast<std::size_t>(bucketActual), m_Buckets[0].size() - 1)};
109+
110+
return std::make_pair(predictionBucketIndex, actualBucketIndex);
111+
}
112+
113+
TDoubleDoublePr bucketWidth() const {
114+
double predictionBucketWidth{m_ExpPredictionMinMax.range() /
115+
static_cast<double>(m_Buckets.size())};
116+
double actualBucketWidth{m_LogActualMinMax.range() /
117+
static_cast<double>(m_Buckets[0].size())};
118+
return std::make_pair(predictionBucketWidth, actualBucketWidth);
119+
}
120+
121+
private:
122+
std::size_t m_CurrentPass = 0;
123+
TMinMaxAccumulator m_ExpPredictionMinMax;
124+
TMinMaxAccumulator m_LogActualMinMax;
125+
TVectorMeanAccumulatorVecVec m_Buckets;
126+
TMeanVarAccumulator m_MeanLogActual;
127+
TMeanAccumulator m_MeanError;
128+
};
129+
69130
//! \brief Finds the value to add to a set of predicted log-odds which minimises
70131
//! regularised cross entropy loss w.r.t. the actual categories.
71132
//!
@@ -278,6 +339,9 @@ class MATHS_EXPORT CLoss {
278339
//! Get the name of the loss function
279340
virtual const std::string& name() const = 0;
280341

342+
//! Returns true if the loss function is used for regression.
343+
virtual bool isRegression() const = 0;
344+
281345
protected:
282346
CArgMinLoss makeMinimizer(const boosted_tree_detail::CArgMinLossImpl& impl) const;
283347
};
@@ -307,6 +371,7 @@ class MATHS_EXPORT CMse final : public CLoss {
307371
TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const override;
308372
CArgMinLoss minimizer(double lambda, const CPRNG::CXorOShiro128Plus& rng) const override;
309373
const std::string& name() const override;
374+
bool isRegression() const override;
310375
};
311376

312377
//! \brief Implements loss for binomial logistic regression.
@@ -342,6 +407,7 @@ class MATHS_EXPORT CBinomialLogisticLoss final : public CLoss {
342407
TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const override;
343408
CArgMinLoss minimizer(double lambda, const CPRNG::CXorOShiro128Plus& rng) const override;
344409
const std::string& name() const override;
410+
bool isRegression() const override;
345411
};
346412

347413
//! \brief Implements loss for multinomial logistic regression.
@@ -380,10 +446,49 @@ class MATHS_EXPORT CMultinomialLogisticLoss final : public CLoss {
380446
TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const override;
381447
CArgMinLoss minimizer(double lambda, const CPRNG::CXorOShiro128Plus& rng) const override;
382448
const std::string& name() const override;
449+
bool isRegression() const override;
383450

384451
private:
385452
std::size_t m_NumberClasses;
386453
};
454+
//! \brief The MSLE loss function.
455+
//!
456+
//! DESCRIPTION:\n
457+
//! Formally, the MSLE error definition we use is \f$(\log(1+p) - \log(1+a))^2\f$.
458+
//! However, we approximate this by a quadratic form which has its minimum at p = a and
459+
//! matches the value and derivative of the MSLE loss function. For example, if the
460+
//! current prediction for the i'th training point is \f$p_i\f$, the loss is defined
461+
//! as
462+
//! <pre class="fragment">
463+
//! \f$\displaystyle l_i(p) = c_i + w_i(p - a_i)^2\f$
464+
//! </pre>
465+
//! where \f$w_i = \frac{\log(1+p_i) - \log(1+a_i)}{(1+p_i)(p_i-a_i)}\f$ and \f$c_i\f$
466+
//! is chosen so \f$l_i(p_i) = (\log(1+p_i) - \log(1+a_i))^2\f$.
467+
class MATHS_EXPORT CMsle final : public CLoss {
468+
public:
469+
static const std::string NAME;
470+
471+
public:
472+
EType type() const override;
473+
std::unique_ptr<CLoss> clone() const override;
474+
std::size_t numberParameters() const override;
475+
double value(const TMemoryMappedFloatVector& prediction,
476+
double actual,
477+
double weight = 1.0) const override;
478+
void gradient(const TMemoryMappedFloatVector& prediction,
479+
double actual,
480+
TWriter writer,
481+
double weight = 1.0) const override;
482+
void curvature(const TMemoryMappedFloatVector& prediction,
483+
double actual,
484+
TWriter writer,
485+
double weight = 1.0) const override;
486+
bool isCurvatureConstant() const override;
487+
TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const override;
488+
CArgMinLoss minimizer(double lambda, const CPRNG::CXorOShiro128Plus& rng) const override;
489+
const std::string& name() const override;
490+
bool isRegression() const override;
491+
};
387492
}
388493
}
389494
}

include/test/CDataFrameAnalysisSpecificationFactory.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <core/CDataSearcher.h>
1212

1313
#include <api/CDataFrameAnalysisSpecification.h>
14+
#include <api/CDataFrameTrainBoostedTreeRegressionRunner.h>
1415

1516
#include <test/ImportExport.h>
1617

@@ -32,6 +33,7 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
3233
using TDataSearcherUPtr = std::unique_ptr<core::CDataSearcher>;
3334
using TRestoreSearcherSupplier = std::function<TDataSearcherUPtr()>;
3435
using TSpecificationUPtr = std::unique_ptr<api::CDataFrameAnalysisSpecification>;
36+
using TRegressionLossFunction = api::CDataFrameTrainBoostedTreeRegressionRunner::ELossFunctionType;
3537

3638
public:
3739
CDataFrameAnalysisSpecificationFactory();
@@ -73,6 +75,10 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
7375
CDataFrameAnalysisSpecificationFactory&
7476
predictionRestoreSearcherSupplier(TRestoreSearcherSupplier* restoreSearcherSupplier);
7577

78+
// Regression
79+
CDataFrameAnalysisSpecificationFactory&
80+
regressionLossFunction(TRegressionLossFunction lossFunction);
81+
7682
// Classification
7783
CDataFrameAnalysisSpecificationFactory& numberClasses(std::size_t number);
7884
CDataFrameAnalysisSpecificationFactory& numberTopClasses(std::size_t number);
@@ -116,6 +122,8 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
116122
std::size_t m_NumberTopShapValues = 0;
117123
TPersisterSupplier* m_PersisterSupplier = nullptr;
118124
TRestoreSearcherSupplier* m_RestoreSearcherSupplier = nullptr;
125+
// Regression
126+
TRegressionLossFunction m_RegressionLossFunction = TRegressionLossFunction::E_Mse;
119127
// Classification
120128
std::size_t m_NumberClasses = 2;
121129
std::size_t m_NumberTopClasses = 0;

include/test/CDataFrameAnalyzerTrainingFactory.h

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@ namespace test {
2929
class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
3030
public:
3131
enum EPredictionType {
32+
E_MsleRegression,
3233
E_Regression,
3334
E_BinaryClassification,
3435
E_MulticlassClassification
3536
};
3637
using TStrVec = std::vector<std::string>;
3738
using TDoubleVec = std::vector<double>;
3839
using TDataFrameUPtr = std::unique_ptr<core::CDataFrame>;
40+
using TLossUPtr = std::unique_ptr<maths::boosted_tree::CLoss>;
41+
using TTargetTransformer = std::function<double(double)>;
3942

4043
public:
4144
template<typename T>
@@ -67,6 +70,10 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
6770
case E_Regression:
6871
return setupLinearRegressionData(fieldNames, fieldValues, analyzer,
6972
weights, regressors, targets);
73+
case E_MsleRegression:
74+
return setupLinearRegressionData(fieldNames, fieldValues, analyzer,
75+
weights, regressors, targets,
76+
[](double x) { return x * x; });
7077
case E_BinaryClassification:
7178
return setupBinaryClassificationData(fieldNames, fieldValues, analyzer,
7279
weights, regressors, targets);
@@ -76,11 +83,21 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
7683
}
7784
}();
7885

79-
std::unique_ptr<maths::boosted_tree::CLoss> loss;
80-
if (type == E_Regression) {
86+
TLossUPtr loss;
87+
switch (type) {
88+
case E_Regression:
8189
loss = std::make_unique<maths::boosted_tree::CMse>();
82-
} else {
90+
break;
91+
case E_MsleRegression:
92+
loss = std::make_unique<maths::boosted_tree::CMsle>();
93+
break;
94+
case E_BinaryClassification:
8395
loss = std::make_unique<maths::boosted_tree::CBinomialLogisticLoss>();
96+
break;
97+
case E_MulticlassClassification:
98+
// TODO
99+
loss = TLossUPtr{};
100+
break;
84101
}
85102

86103
maths::CBoostedTreeFactory treeFactory{
@@ -121,15 +138,7 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
121138
frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
122139
for (auto row = beginRows; row != endRows; ++row) {
123140
auto prediction = tree->readAndAdjustPrediction(*row);
124-
switch (type) {
125-
case E_Regression:
126-
appendPrediction(*frame, weights.size(), prediction, expectedPredictions);
127-
break;
128-
case E_BinaryClassification:
129-
case E_MulticlassClassification:
130-
appendPrediction(*frame, weights.size(), prediction, expectedPredictions);
131-
break;
132-
}
141+
appendPrediction(*frame, weights.size(), prediction, expectedPredictions);
133142
}
134143
});
135144
}
@@ -140,12 +149,16 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
140149
const TDoubleVec& weights,
141150
const TDoubleVec& regressors,
142151
TStrVec& targets);
143-
static TDataFrameUPtr setupLinearRegressionData(const TStrVec& fieldNames,
144-
TStrVec& fieldValues,
145-
api::CDataFrameAnalyzer& analyzer,
146-
const TDoubleVec& weights,
147-
const TDoubleVec& regressors,
148-
TStrVec& targets);
152+
static TDataFrameUPtr
153+
setupLinearRegressionData(const TStrVec& fieldNames,
154+
TStrVec& fieldValues,
155+
api::CDataFrameAnalyzer& analyzer,
156+
const TDoubleVec& weights,
157+
const TDoubleVec& regressors,
158+
TStrVec& targets,
159+
TTargetTransformer targetTransformer = [](double x) {
160+
return x;
161+
});
149162

150163
private:
151164
using TDouble2Vec = core::CSmallVector<double, 2>;

lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include <api/ElasticsearchStateIndex.h>
2222

2323
#include <cmath>
24+
#include <memory>
2425
#include <set>
26+
#include <string>
2527

2628
namespace ml {
2729
namespace api {
@@ -38,16 +40,30 @@ CDataFrameTrainBoostedTreeRegressionRunner::parameterReader() {
3840
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
3941
theReader.addParameter(STRATIFIED_CROSS_VALIDATION,
4042
CDataFrameAnalysisConfigReader::E_OptionalParameter);
43+
theReader.addParameter(LOSS_FUNCTION, CDataFrameAnalysisConfigReader::E_OptionalParameter,
44+
{{MSE, int{E_Mse}}, {MSLE, int{E_Msle}}});
4145
return theReader;
4246
}()};
4347
return PARAMETER_READER;
4448
}
4549

50+
CDataFrameTrainBoostedTreeRegressionRunner::TLossFunctionUPtr
51+
CDataFrameTrainBoostedTreeRegressionRunner::lossFunction(const CDataFrameAnalysisParameters& parameters) {
52+
ELossFunctionType lossFunctionType{parameters[LOSS_FUNCTION].fallback(E_Mse)};
53+
switch (lossFunctionType) {
54+
case E_Msle:
55+
return std::make_unique<maths::boosted_tree::CMsle>();
56+
case E_Mse:
57+
return std::make_unique<maths::boosted_tree::CMse>();
58+
}
59+
}
60+
4661
CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegressionRunner(
4762
const CDataFrameAnalysisSpecification& spec,
4863
const CDataFrameAnalysisParameters& parameters)
4964
: CDataFrameTrainBoostedTreeRunner{
50-
spec, parameters, std::make_unique<maths::boosted_tree::CMse>()} {
65+
spec, parameters,
66+
CDataFrameTrainBoostedTreeRegressionRunner::lossFunction(parameters)} {
5167

5268
this->boostedTreeFactory().stratifyRegressionCrossValidation(
5369
parameters[STRATIFIED_CROSS_VALIDATION].fallback(true));
@@ -117,6 +133,9 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(
117133

118134
// clang-format off
119135
const std::string CDataFrameTrainBoostedTreeRegressionRunner::STRATIFIED_CROSS_VALIDATION{"stratified_cross_validation"};
136+
const std::string CDataFrameTrainBoostedTreeRegressionRunner::LOSS_FUNCTION{"loss_function"};
137+
const std::string CDataFrameTrainBoostedTreeRegressionRunner::MSE{"mse"};
138+
const std::string CDataFrameTrainBoostedTreeRegressionRunner::MSLE{"msle"};
120139
// clang-format on
121140

122141
const std::string& CDataFrameTrainBoostedTreeRegressionRunnerFactory::name() const {

0 commit comments

Comments
 (0)