Skip to content

Commit 12b5014

Browse files
authored
[ML] Reduce variability in regression and classification results across our target platforms (elastic#1127)
1 parent 5cc1d26 commit 12b5014

16 files changed

+252
-147
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
* Improve robustness of anomaly detection to bad input data. (See {ml-pull}1114[#1114].)
5050
* Adds new `num_matches` and `preferred_to_categories` fields to category output.
5151
(See {ml-pull}1062[#1062].)
52+
* Reduce variability of classification and regression results across our target operating systems.
53+
(See {ml-pull}1127[#1127].)
5254
* Switched data frame analytics model memory estimates from kilobytes to megabytes.
5355
(See {ml-pull}1126[#1126], issue: {issue}54506[#54506].)
5456

include/core/CIEEE754.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
#include <core/ImportExport.h>
1111

12-
#include <stdint.h>
13-
#include <string.h>
12+
#include <cstdint>
13+
#include <cstring>
1414

1515
namespace ml {
1616
namespace core {
@@ -40,14 +40,14 @@ class CORE_EXPORT CIEEE754 {
4040
//! as an integer.
4141
//! \note The actual "exponent" is "exponent - 1022" in two's complement.
4242
struct SDoubleRep {
43-
#ifdef __sparc // Add any other big endian architectures
44-
uint64_t s_Sign : 1; // sign bit
45-
uint64_t s_Exponent : 11; // exponent
46-
uint64_t s_Mantissa : 52; // mantissa
43+
#ifdef __sparc // Add any other big endian architectures
44+
std::uint64_t s_Sign : 1; // sign bit
45+
std::uint64_t s_Exponent : 11; // exponent
46+
std::uint64_t s_Mantissa : 52; // mantissa
4747
#else
48-
uint64_t s_Mantissa : 52; // mantissa
49-
uint64_t s_Exponent : 11; // exponent
50-
uint64_t s_Sign : 1; // sign bit
48+
std::uint64_t s_Mantissa : 52; // mantissa
49+
std::uint64_t s_Exponent : 11; // exponent
50+
std::uint64_t s_Sign : 1; // sign bit
5151
#endif
5252
};
5353

@@ -57,15 +57,19 @@ class CORE_EXPORT CIEEE754 {
5757
//!
5858
//! \note This is closely related to std::frexp for double but returns
5959
//! the mantissa interpreted as an integer.
60-
static void decompose(double value, uint64_t& mantissa, int& exponent) {
60+
static void decompose(double value, std::uint64_t& mantissa, int& exponent) {
6161
SDoubleRep parsed;
6262
static_assert(sizeof(double) == sizeof(SDoubleRep),
6363
"SDoubleRep definition unsuitable for memcpy to double");
6464
// Use memcpy() rather than union to adhere to strict aliasing rules
65-
::memcpy(&parsed, &value, sizeof(double));
65+
std::memcpy(&parsed, &value, sizeof(double));
6666
exponent = static_cast<int>(parsed.s_Exponent) - 1022;
6767
mantissa = parsed.s_Mantissa;
6868
}
69+
70+
//! Drop \p bits trailing bits from the mantissa of \p value in a stable way
71+
//! for different operating systems.
72+
static double dropbits(double value, int bits);
6973
};
7074
}
7175
}

include/maths/CBoostedTreeLeafNodeStatistics.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ class MATHS_EXPORT CBoostedTreeLeafNodeStatistics final {
459459

460460
bool operator<(const SSplitStatistics& rhs) const {
461461
return COrderings::lexicographical_compare(
462-
s_Gain, s_Curvature, s_Feature, rhs.s_Gain, rhs.s_Curvature, rhs.s_Feature);
462+
s_Gain, s_Curvature, s_Feature, s_SplitAt, // <- lhs
463+
rhs.s_Gain, rhs.s_Curvature, rhs.s_Feature, rhs.s_SplitAt);
463464
}
464465

465466
std::string print() const {

include/maths/CLbfgs.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define INCLUDED_ml_maths_CLbfgs_h
99

1010
#include <maths/CLinearAlgebraShims.h>
11+
#include <maths/CTools.h>
1112

1213
#include <boost/circular_buffer.hpp>
1314

@@ -302,8 +303,9 @@ class CLbfgs {
302303
return m_BacktrackingMinDecrease * s * las::inner(m_Gx, m_P) / las::norm(m_P);
303304
}
304305

305-
constexpr double minimumStepSize() const {
306-
return std::pow(m_StepScale, static_cast<double>(MAXIMUM_BACK_TRACKING_ITERATIONS));
306+
double minimumStepSize() const {
307+
return CTools::stable(std::pow(
308+
m_StepScale, static_cast<double>(MAXIMUM_BACK_TRACKING_ITERATIONS)));
307309
}
308310

309311
private:

include/maths/CTools.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include <array>
2525
#include <cmath>
26+
#include <cstdint>
2627
#include <cstring>
2728
#include <iosfwd>
2829
#include <limits>
@@ -456,7 +457,7 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
456457
// (interpreted as an integer) to the corresponding
457458
// double value and fastLog uses the same approach
458459
// to extract the mantissa.
459-
uint64_t dx = 0x10000000000000ull / BINS;
460+
std::uint64_t dx = 0x10000000000000ull / BINS;
460461
core::CIEEE754::SDoubleRep x;
461462
x.s_Sign = 0;
462463
x.s_Mantissa = (dx / 2) & core::CIEEE754::IEEE754_MANTISSA_MASK;
@@ -469,12 +470,12 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
469470
// Use memcpy() rather than union to adhere to strict
470471
// aliasing rules
471472
std::memcpy(&value, &x, sizeof(double));
472-
m_Table[i] = std::log2(value);
473+
m_Table[i] = stable(std::log2(value));
473474
}
474475
}
475476

476477
//! Lookup log2 for a given mantissa.
477-
const double& operator[](uint64_t mantissa) const {
478+
const double& operator[](std::uint64_t mantissa) const {
478479
return m_Table[mantissa >> FAST_LOG_SHIFT];
479480
}
480481

@@ -494,7 +495,7 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
494495
//! \note This is taken from the approach given in
495496
//! http://www.icsi.berkeley.edu/pubs/techreports/TR-07-002.pdf
496497
static double fastLog(double x) {
497-
uint64_t mantissa;
498+
std::uint64_t mantissa;
498499
int log2;
499500
core::CIEEE754::decompose(x, mantissa, log2);
500501
return 0.693147180559945 * (FAST_LOG_TABLE[mantissa] + log2);
@@ -669,6 +670,15 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
669670
//! Compute \f$x^2\f$.
670671
static double pow2(double x) { return x * x; }
671672

673+
//! Compute a value from \p x which will be stable across platforms.
674+
static double stable(double x) { return core::CIEEE754::dropbits(x, 1); }
675+
676+
//! A version of std::log which is stable across platforms.
677+
static double stableLog(double x) { return stable(std::log(x)); }
678+
679+
//! A version of std::exp which is stable across platforms.
680+
static double stableExp(double x) { return stable(std::exp(x)); }
681+
672682
//! Sigmoid function of \p p.
673683
static double sigmoid(double p) { return 1.0 / (1.0 + 1.0 / p); }
674684

@@ -682,7 +692,7 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
682692
//! \param[in] sign Determines whether it's a step up or down.
683693
static double
684694
logisticFunction(double x, double width = 1.0, double x0 = 0.0, double sign = 1.0) {
685-
return sigmoid(std::exp(std::copysign(1.0, sign) * (x - x0) / width));
695+
return sigmoid(stableExp(std::copysign(1.0, sign) * (x - x0) / width));
686696
}
687697

688698
//! Compute the softmax for the multinomial logit values \p logit.
@@ -696,7 +706,7 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
696706
double Z{0.0};
697707
double zmax{*std::max_element(z.begin(), z.end())};
698708
for (auto& zi : z) {
699-
zi = std::exp(zi - zmax);
709+
zi = stableExp(zi - zmax);
700710
Z += zi;
701711
}
702712
for (auto& zi : z) {

include/maths/CToolsDetail.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ void CTools::inplaceLogSoftmax(CDenseVector<SCALAR>& z) {
316316
double zmax{z.maxCoeff()};
317317
z.array() -= zmax;
318318
double Z{z.array().exp().sum()};
319-
z.array() -= std::log(Z);
319+
z.array() -= stableLog(Z);
320320
}
321321
}
322322
}

lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,8 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) {
485485
double c1{readShapValue(result, "c1")};
486486
double prediction{
487487
result["row_results"]["results"]["ml"]["target_prediction"].GetDouble()};
488-
// c1 explains 95% of the prediction value, i.e. the difference from the prediction is less than 2%.
489-
BOOST_REQUIRE_CLOSE(c1, prediction, 5.0);
488+
// c1 explains 94% of the prediction value, i.e. the difference from the prediction is less than 6%.
489+
BOOST_REQUIRE_CLOSE(c1, prediction, 6.0);
490490
for (const auto& feature : {"c2", "c3", "c4"}) {
491491
double c = readShapValue(result, feature);
492492
BOOST_REQUIRE_SMALL(c, 2.0);

lib/core/CIEEE754.cc

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
#include <core/CIEEE754.h>
88

99
#include <cmath>
10+
#include <cstring>
1011

1112
namespace ml {
1213
namespace core {
1314

1415
double CIEEE754::round(double value, EPrecision precision) {
15-
// This first decomposes the value into the mantissa
16-
// and exponent to avoid the problem with overflow if
17-
// the values are close to max double.
16+
// First decompose the value into the mantissa and exponent to avoid the
17+
// problem with overflow if the values are close to max double.
1818

1919
int exponent;
2020
double mantissa = std::frexp(value, &exponent);
@@ -39,5 +39,15 @@ double CIEEE754::round(double value, EPrecision precision) {
3939

4040
return std::ldexp(mantissa, exponent);
4141
}
42+
43+
double CIEEE754::dropbits(double value, int bits) {
44+
SDoubleRep parsed;
45+
static_assert(sizeof(double) == sizeof(SDoubleRep),
46+
"SDoubleRep definition unsuitable for memcpy to double");
47+
std::memcpy(&parsed, &value, sizeof(double));
48+
parsed.s_Mantissa &= ((IEEE754_MANTISSA_MASK << bits) & IEEE754_MANTISSA_MASK);
49+
std::memcpy(&value, &parsed, sizeof(double));
50+
return value;
51+
}
4252
}
4353
}

0 commit comments

Comments
 (0)