Skip to content

Commit e06ef9d

Browse files
authored
[7.6][ML] Correct logistic loss function (#1059)
Backport #1032.
1 parent 592d61b commit e06ef9d

File tree

5 files changed

+26
-9
lines changed

5 files changed

+26
-9
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@
2828

2929
//=== Regressions
3030

31+
== {es} version 7.6.2
32+
33+
=== Bug Fixes
34+
35+
* Fix a bug in the calculation of the minimum-loss leaf values for classification.
36+
(See {ml-pull}1032[#1032].)
37+
3138
== {es} version 7.6.0
3239

3340
=== New Features

include/maths/CBoostedTree.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,10 @@ class MATHS_EXPORT CArgMinLogisticImpl final : public CArgMinLossImpl {
9797
}
9898

9999
double bucketWidth() const {
100-
return m_PredictionMinMax.range() /
101-
static_cast<double>(m_BucketCategoryCounts.size());
100+
return m_PredictionMinMax.initialized()
101+
? m_PredictionMinMax.range() /
102+
static_cast<double>(m_BucketCategoryCounts.size())
103+
: 0.0;
102104
}
103105

104106
private:

lib/maths/CBoostedTree.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,17 @@ double CArgMinLogisticImpl::value() const {
154154
// case we only need one pass over the data and can compute the optimal
155155
// value from the counts of the two categories.
156156
if (this->bucketWidth() == 0.0) {
157-
objective = [this](double weight) {
157+
// This is the (unique) value predicted for the rows in this leaf by the forest
158+
// so far (i.e. without the weight for the leaf we're about to add).
159+
double prediction{m_PredictionMinMax.initialized()
160+
? (m_PredictionMinMax.min() + m_PredictionMinMax.max()) / 2.0
161+
: 0.0};
162+
objective = [prediction, this](double weight) {
163+
double logOdds{prediction + weight};
158164
double c0{m_CategoryCounts(0)};
159165
double c1{m_CategoryCounts(1)};
160166
return this->lambda() * CTools::pow2(weight) -
161-
c0 * logOneMinusLogistic(weight) - c1 * logLogistic(weight);
167+
c0 * logOneMinusLogistic(logOdds) - c1 * logLogistic(logOdds);
162168
};
163169

164170
// Weight shrinkage means the optimal weight will be somewhere between
@@ -168,8 +174,8 @@ double CArgMinLogisticImpl::value() const {
168174
double empiricalProbabilityC1{c1 / (c0 + c1)};
169175
double empiricalLogOddsC1{
170176
std::log(empiricalProbabilityC1 / (1.0 - empiricalProbabilityC1))};
171-
minWeight = empiricalProbabilityC1 < 0.5 ? empiricalLogOddsC1 : 0.0;
172-
maxWeight = empiricalProbabilityC1 < 0.5 ? 0.0 : empiricalLogOddsC1;
177+
minWeight = (empiricalProbabilityC1 < 0.5 ? empiricalLogOddsC1 : 0.0) - prediction;
178+
maxWeight = (empiricalProbabilityC1 < 0.5 ? 0.0 : empiricalLogOddsC1) - prediction;
173179

174180
} else {
175181
objective = [this](double weight) {

lib/maths/CBoostedTreeImpl.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1241,7 +1241,9 @@ void CBoostedTreeImpl::refreshPredictionsAndLossDerivatives(core::CDataFrame& fr
12411241
} while (nextPass());
12421242

12431243
for (std::size_t i = 0; i < tree.size(); ++i) {
1244-
tree[i].value(eta * leafValues[i].value());
1244+
if (tree[i].isLeaf()) {
1245+
tree[i].value(eta * leafValues[i].value());
1246+
}
12451247
}
12461248

12471249
LOG_TRACE(<< "tree =\n" << root(tree).print(tree));

lib/maths/unittest/CBoostedTreeTest.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,7 @@ BOOST_AUTO_TEST_CASE(testLogisticRegression) {
11231123
LOG_DEBUG(<< "log relative error = "
11241124
<< maths::CBasicStatistics::mean(logRelativeError));
11251125

1126-
BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.75);
1126+
BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.71);
11271127
meanLogRelativeError.add(maths::CBasicStatistics::mean(logRelativeError));
11281128
}
11291129

@@ -1215,7 +1215,7 @@ BOOST_AUTO_TEST_CASE(testImbalancedClasses) {
12151215
LOG_DEBUG(<< "recalls = " << core::CContainerPrinter::print(recalls));
12161216

12171217
BOOST_TEST_REQUIRE(std::fabs(precisions[0] - precisions[1]) < 0.1);
1218-
BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.15);
1218+
BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.13);
12191219
}
12201220

12211221
BOOST_AUTO_TEST_CASE(testEstimateMemoryUsedByTrain) {

0 commit comments

Comments
 (0)