Skip to content

Commit 1cf9de8

Browse files
authored
[ML] Correct logistic loss function (#1032)
1 parent 56550ee commit 1cf9de8

File tree

5 files changed

+28
-10
lines changed

5 files changed

+28
-10
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ the build from version 2.20 to 2.34. (See {ml-pull}1013[#1013].)
6565
* Account for the data frame's memory when estimating the peak memory used by classification
6666
and regression model training. (See {ml-pull}996[#996].)
6767

68+
== {es} version 7.6.2
69+
70+
=== Bug Fixes
71+
72+
* Fix a bug in the calculation of the minimum loss leaf values for classification.
73+
(See {ml-pull}1032[#1032].)
74+
6875
== {es} version 7.6.0
6976

7077
=== New Features

include/maths/CBoostedTreeLoss.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ class MATHS_EXPORT CArgMinLogisticImpl final : public CArgMinLossImpl {
9393
}
9494

9595
double bucketWidth() const {
96-
return m_PredictionMinMax.range() /
97-
static_cast<double>(m_BucketCategoryCounts.size());
96+
return m_PredictionMinMax.initialized()
97+
? m_PredictionMinMax.range() /
98+
static_cast<double>(m_BucketCategoryCounts.size())
99+
: 0.0;
98100
}
99101

100102
private:

lib/maths/CBoostedTreeImpl.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,9 @@ void CBoostedTreeImpl::refreshPredictionsAndLossDerivatives(core::CDataFrame& fr
948948
} while (nextPass());
949949

950950
for (std::size_t i = 0; i < tree.size(); ++i) {
951-
tree[i].value(eta * leafValues[i].value());
951+
if (tree[i].isLeaf()) {
952+
tree[i].value(eta * leafValues[i].value());
953+
}
952954
}
953955

954956
LOG_TRACE(<< "tree =\n" << root(tree).print(tree));

lib/maths/CBoostedTreeLoss.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,17 @@ CArgMinLogisticImpl::TDoubleVector CArgMinLogisticImpl::value() const {
144144
// case we only need one pass over the data and can compute the optimal
145145
// value from the counts of the two categories.
146146
if (this->bucketWidth() == 0.0) {
147-
objective = [this](double weight) {
147+
// This is the (unique) predicted value for the rows in leaf by the forest
148+
// so far (i.e. without the weight for the leaf we're about to add).
149+
double prediction{m_PredictionMinMax.initialized()
150+
? (m_PredictionMinMax.min() + m_PredictionMinMax.max()) / 2.0
151+
: 0.0};
152+
objective = [prediction, this](double weight) {
153+
double logOdds{prediction + weight};
148154
double c0{m_CategoryCounts(0)};
149155
double c1{m_CategoryCounts(1)};
150156
return this->lambda() * CTools::pow2(weight) -
151-
c0 * logOneMinusLogistic(weight) - c1 * logLogistic(weight);
157+
c0 * logOneMinusLogistic(logOdds) - c1 * logLogistic(logOdds);
152158
};
153159

154160
// Weight shrinkage means the optimal weight will be somewhere between
@@ -158,8 +164,8 @@ CArgMinLogisticImpl::TDoubleVector CArgMinLogisticImpl::value() const {
158164
double empiricalProbabilityC1{c1 / (c0 + c1)};
159165
double empiricalLogOddsC1{
160166
std::log(empiricalProbabilityC1 / (1.0 - empiricalProbabilityC1))};
161-
minWeight = empiricalProbabilityC1 < 0.5 ? empiricalLogOddsC1 : 0.0;
162-
maxWeight = empiricalProbabilityC1 < 0.5 ? 0.0 : empiricalLogOddsC1;
167+
minWeight = (empiricalProbabilityC1 < 0.5 ? empiricalLogOddsC1 : 0.0) - prediction;
168+
maxWeight = (empiricalProbabilityC1 < 0.5 ? 0.0 : empiricalLogOddsC1) - prediction;
163169

164170
} else {
165171
objective = [this](double weight) {
@@ -200,6 +206,7 @@ CArgMinLogisticImpl::TDoubleVector CArgMinLogisticImpl::value() const {
200206
return result;
201207
}
202208
}
209+
203210
namespace boosted_tree {
204211

205212
CArgMinLoss::CArgMinLoss(const CArgMinLoss& other)
@@ -338,4 +345,4 @@ const std::string& CBinomialLogistic::name() const {
338345
const std::string CBinomialLogistic::NAME{"binomial_logistic"};
339346
}
340347
}
341-
}
348+
}

lib/maths/unittest/CBoostedTreeTest.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,7 +1214,7 @@ BOOST_AUTO_TEST_CASE(testLogisticRegression) {
12141214
LOG_DEBUG(<< "log relative error = "
12151215
<< maths::CBasicStatistics::mean(logRelativeError));
12161216

1217-
BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.7);
1217+
BOOST_TEST_REQUIRE(maths::CBasicStatistics::mean(logRelativeError) < 0.71);
12181218
meanLogRelativeError.add(maths::CBasicStatistics::mean(logRelativeError));
12191219
}
12201220

@@ -1307,7 +1307,7 @@ BOOST_AUTO_TEST_CASE(testImbalancedClasses) {
13071307
LOG_DEBUG(<< "recalls = " << core::CContainerPrinter::print(recalls));
13081308

13091309
BOOST_TEST_REQUIRE(std::fabs(precisions[0] - precisions[1]) < 0.1);
1310-
BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.15);
1310+
BOOST_TEST_REQUIRE(std::fabs(recalls[0] - recalls[1]) < 0.13);
13111311
}
13121312

13131313
BOOST_AUTO_TEST_CASE(testEstimateMemoryUsedByTrain) {

0 commit comments

Comments
 (0)