Skip to content

Commit 9957f81

Browse files
authored
[7.5][ML] Logistic regression loss function for boosted tree training (#730)
Backport #713.
1 parent 874eb2e commit 9957f81

12 files changed

+861
-280
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ boosted tree training. Hard depth based regularization is often the strategy of
4343
choice to prevent over fitting for XGBoost. By smoothing we can make better tradeoffs.
4444
Also, the parameters of the penalty function are more suited to optimising with our
4545
Bayesian optimisation based hyperparameter search. (See {ml-pull}698[#698].)
46+
* Binomial logistic regression targeting cross entropy. (See {ml-pull}713[#713].)
4647
* Improvements to count and sum anomaly detection for sparse data. This primarily
4748
aims to improve handling of data which are predictably present: detecting when they
4849
are unexpectedly missing. (See {ml-pull}721[#721].)

include/maths/CBoostedTree.h

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313

1414
#include <maths/CBasicStatistics.h>
1515
#include <maths/CDataFrameRegressionModel.h>
16+
#include <maths/CLinearAlgebra.h>
1617
#include <maths/ImportExport.h>
1718

1819
#include <cstddef>
1920
#include <memory>
21+
#include <string>
22+
#include <vector>
2023

2124
namespace ml {
2225
namespace core {
@@ -29,18 +32,29 @@ class CEncodedDataFrameRowRef;
2932
namespace boosted_tree_detail {
3033
class MATHS_EXPORT CArgMinLossImpl {
3134
public:
35+
CArgMinLossImpl(double lambda);
3236
virtual ~CArgMinLossImpl() = default;
3337

3438
virtual std::unique_ptr<CArgMinLossImpl> clone() const = 0;
39+
virtual bool nextPass() = 0;
3540
virtual void add(double prediction, double actual) = 0;
3641
virtual void merge(const CArgMinLossImpl& other) = 0;
3742
virtual double value() const = 0;
43+
44+
protected:
45+
double lambda() const;
46+
47+
private:
48+
double m_Lambda;
3849
};
3950

40-
//! \brief Finds the value to add to a set of predictions which minimises the MSE.
51+
//! \brief Finds the value to add to a set of predictions which minimises the
52+
//! regularized MSE w.r.t. the actual values.
4153
class MATHS_EXPORT CArgMinMseImpl final : public CArgMinLossImpl {
4254
public:
55+
CArgMinMseImpl(double lambda);
4356
std::unique_ptr<CArgMinLossImpl> clone() const override;
57+
bool nextPass() override;
4458
void add(double prediction, double actual) override;
4559
void merge(const CArgMinLossImpl& other) override;
4660
double value() const override;
@@ -51,6 +65,46 @@ class MATHS_EXPORT CArgMinMseImpl final : public CArgMinLossImpl {
5165
private:
5266
TMeanAccumulator m_MeanError;
5367
};
68+
69+
//! \brief Finds the value to add to a set of predicted log-odds which minimises
70+
//! regularised cross entropy loss w.r.t. the actual categories.
71+
class MATHS_EXPORT CArgMinLogisticImpl final : public CArgMinLossImpl {
72+
public:
73+
CArgMinLogisticImpl(double lambda);
74+
std::unique_ptr<CArgMinLossImpl> clone() const override;
75+
bool nextPass() override;
76+
void add(double prediction, double actual) override;
77+
void merge(const CArgMinLossImpl& other) override;
78+
double value() const override;
79+
80+
private:
81+
using TMinMaxAccumulator = CBasicStatistics::CMinMax<double>;
82+
using TSizeVector = CVectorNx1<std::size_t, 2>;
83+
using TSizeVectorVec = std::vector<TSizeVector>;
84+
85+
private:
86+
std::size_t bucket(double prediction) const {
87+
double bucket{(prediction - m_PredictionMinMax.min()) / this->bucketWidth()};
88+
return std::min(static_cast<std::size_t>(bucket),
89+
m_BucketCategoryCounts.size() - 1);
90+
}
91+
92+
double bucketCentre(std::size_t bucket) const {
93+
return m_PredictionMinMax.min() +
94+
(static_cast<double>(bucket) + 0.5) * this->bucketWidth();
95+
}
96+
97+
double bucketWidth() const {
98+
return m_PredictionMinMax.range() /
99+
static_cast<double>(m_BucketCategoryCounts.size());
100+
}
101+
102+
private:
103+
std::size_t m_CurrentPass = 0;
104+
TMinMaxAccumulator m_PredictionMinMax;
105+
TSizeVector m_CategoryCounts;
106+
TSizeVectorVec m_BucketCategoryCounts;
107+
};
54108
}
55109

56110
namespace boosted_tree {
@@ -64,6 +118,11 @@ class MATHS_EXPORT CArgMinLoss {
64118
CArgMinLoss& operator=(const CArgMinLoss& other);
65119
CArgMinLoss& operator=(CArgMinLoss&& other) = default;
66120

121+
//! Start another pass over the predictions and actuals.
122+
//!
123+
//! \return True if we need to perform another pass to compute value().
124+
bool nextPass() const;
125+
67126
//! Update with a point prediction and actual value.
68127
void add(double prediction, double actual);
69128

@@ -94,6 +153,8 @@ class MATHS_EXPORT CArgMinLoss {
94153
class MATHS_EXPORT CLoss {
95154
public:
96155
virtual ~CLoss() = default;
156+
//! Clone the loss.
157+
virtual std::unique_ptr<CLoss> clone() const = 0;
97158
//! The value of the loss function.
98159
virtual double value(double prediction, double actual) const = 0;
99160
//! The slope of the loss function.
@@ -103,7 +164,7 @@ class MATHS_EXPORT CLoss {
103164
//! Returns true if the loss curvature is constant.
104165
virtual bool isCurvatureConstant() const = 0;
105166
//! Get an object which computes the leaf value that minimises loss.
106-
virtual CArgMinLoss minimizer() const = 0;
167+
virtual CArgMinLoss minimizer(double lambda) const = 0;
107168
//! Get the name of the loss function
108169
virtual const std::string& name() const = 0;
109170

@@ -114,11 +175,34 @@ class MATHS_EXPORT CLoss {
114175
//! \brief The MSE loss function.
115176
class MATHS_EXPORT CMse final : public CLoss {
116177
public:
178+
std::unique_ptr<CLoss> clone() const override;
117179
double value(double prediction, double actual) const override;
118180
double gradient(double prediction, double actual) const override;
119181
double curvature(double prediction, double actual) const override;
120182
bool isCurvatureConstant() const override;
121-
CArgMinLoss minimizer() const override;
183+
CArgMinLoss minimizer(double lambda) const override;
184+
const std::string& name() const override;
185+
186+
public:
187+
static const std::string NAME;
188+
};
189+
190+
//! \brief Implements loss for binomial logistic regression.
191+
//!
192+
//! DESCRIPTION:\n
193+
//! This targets the cross entropy loss using the tree to predict class log-odds:
194+
//! <pre class="fragment">
195+
//! \f$\displaystyle l_i(p) = -(1 - a_i) \log(1 - S(p)) - a_i \log(S(p))\f$
196+
//! </pre>
197+
//! where \f$a_i\f$ denotes the actual class of the i'th example, \f$p\f$ is the
198+
//! prediction and \f$S(\cdot)\f$ denotes the logistic function.
199+
class MATHS_EXPORT CLogistic final : public CLoss {
200+
std::unique_ptr<CLoss> clone() const override;
201+
double value(double prediction, double actual) const override;
202+
double gradient(double prediction, double actual) const override;
203+
double curvature(double prediction, double actual) const override;
204+
bool isCurvatureConstant() const override;
205+
CArgMinLoss minimizer(double lambda) const override;
122206
const std::string& name() const override;
123207

124208
public:
@@ -248,6 +332,7 @@ class MATHS_EXPORT CBoostedTreeNode final {
248332
//! proposed by Reshef for this purpose. See CDataFrameCategoryEncoder for more details.
249333
class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
250334
public:
335+
using TStrVec = std::vector<std::string>;
251336
using TRowRef = core::CDataFrame::TRowRef;
252337
using TLossFunctionUPtr = std::unique_ptr<boosted_tree::CLoss>;
253338
using TDataFramePtr = core::CDataFrame*;
@@ -285,6 +370,16 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
285370
//! Get the model produced by training if it has been run.
286371
const TNodeVecVec& trainedModel() const;
287372

373+
//! The name of the object holding the best hyperparameters in the state document.
374+
static const std::string& bestHyperparametersName();
375+
376+
//! The name of the object holding the best regularisation hyperparameters in the
377+
//! state document.
378+
static const std::string& bestRegularizationHyperparametersName();
379+
380+
//! A list of the names of the best individual hyperparameters in the state document.
381+
static TStrVec bestHyperparameterNames();
382+
288383
//! Persist by passing information to \p inserter.
289384
void acceptPersistInserter(core::CStatePersistInserter& inserter) const;
290385

include/maths/CBoostedTreeFactory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ class MATHS_EXPORT CBoostedTreeFactory final {
177177
TOptionalDouble m_MinimumFrequencyToOneHotEncode;
178178
TOptionalSize m_BayesianOptimisationRestarts;
179179
bool m_Restored = false;
180+
std::size_t m_NumberThreads;
181+
TLossFunctionUPtr m_Loss;
180182
TBoostedTreeImplUPtr m_TreeImpl;
181183
TVector m_LogDepthPenaltyMultiplierSearchInterval;
182184
TVector m_LogTreeSizePenaltyMultiplierSearchInterval;

include/maths/CBoostedTreeImpl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ inline std::size_t predictionColumn(std::size_t numberColumns) {
4848
class MATHS_EXPORT CBoostedTreeImpl final {
4949
public:
5050
using TDoubleVec = std::vector<double>;
51+
using TStrVec = std::vector<std::string>;
5152
using TMeanAccumulator = CBasicStatistics::SSampleMean<double>::TAccumulator;
5253
using TMeanVarAccumulator = CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
5354
using TBayesinOptimizationUPtr = std::unique_ptr<maths::CBayesianOptimisation>;
@@ -101,6 +102,16 @@ class MATHS_EXPORT CBoostedTreeImpl final {
101102
//! frame with \p numberRows row and \p numberColumns columns will use.
102103
std::size_t estimateMemoryUsage(std::size_t numberRows, std::size_t numberColumns) const;
103104

105+
//! The name of the object holding the best hyperparameters in the state document.
106+
static const std::string& bestHyperparametersName();
107+
108+
//! The name of the object holding the best regularisation hyperparameters in the
109+
//! state document.
110+
static const std::string& bestRegularizationHyperparametersName();
111+
112+
//! A list of the names of the best individual hyperparameters in the state document.
113+
static TStrVec bestHyperparameterNames();
114+
104115
//! Persist by passing information to \p inserter.
105116
void acceptPersistInserter(core::CStatePersistInserter& inserter) const;
106117

include/maths/CTools.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,8 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
678678
//! \param[in] width The step width.
679679
//! \param[in] x0 The centre of the step.
680680
//! \param[in] sign Determines whether it's a step up or down.
681-
static double logisticFunction(double x, double width, double x0 = 0.0, double sign = 1.0) {
681+
static double
682+
logisticFunction(double x, double width = 1.0, double x0 = 0.0, double sign = 1.0) {
682683
return sigmoid(std::exp(std::copysign(1.0, sign) * (x - x0) / width));
683684
}
684685

0 commit comments

Comments
 (0)