Skip to content

Commit 854ed5d

Browse files
authored
[ML] Fix handling of numerical precision loss in logistic loss gradient and curvature (#1041)
1 parent 5102588 commit 854ed5d

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

lib/maths/CBoostedTreeLoss.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,9 @@ void CBinomialLogistic::gradient(const TMemoryMappedFloatVector& prediction,
308308
double weight) const {
309309
if (prediction(0) > -LOG_EPSILON && actual == 1.0) {
310310
writer(0, -weight * std::exp(-prediction(0)));
311+
} else {
312+
writer(0, weight * (CTools::logisticFunction(prediction(0)) - actual));
311313
}
312-
writer(0, weight * (CTools::logisticFunction(prediction(0)) - actual));
313314
}
314315

315316
void CBinomialLogistic::curvature(const TMemoryMappedFloatVector& prediction,
@@ -318,9 +319,10 @@ void CBinomialLogistic::curvature(const TMemoryMappedFloatVector& prediction,
318319
double weight) const {
319320
if (prediction(0) > -LOG_EPSILON) {
320321
writer(0, weight * std::exp(-prediction(0)));
322+
} else {
323+
double probability{CTools::logisticFunction(prediction(0))};
324+
writer(0, weight * probability * (1.0 - probability));
321325
}
322-
double probability{CTools::logisticFunction(prediction(0))};
323-
writer(0, weight * probability * (1.0 - probability));
324326
}
325327

326328
bool CBinomialLogistic::isCurvatureConstant() const {

0 commit comments

Comments
 (0)