Skip to content

Commit 1813075

Browse files
committed
[ML] Correct missing category handling (elastic#1042)
1 parent 0ae8ab9 commit 1813075

File tree

6 files changed

+243
-23
lines changed

6 files changed

+243
-23
lines changed

include/core/CDataFrame.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <cstdint>
2121
#include <functional>
2222
#include <iterator>
23+
#include <limits>
2324
#include <memory>
2425
#include <vector>
2526

@@ -490,7 +491,9 @@ class CORE_EXPORT CDataFrame final {
490491
std::size_t numberColumns);
491492

492493
//! Get the value to use for a missing element in a data frame.
493-
static double valueOfMissing();
494+
static constexpr double valueOfMissing() {
495+
return std::numeric_limits<double>::quiet_NaN();
496+
}
494497

495498
private:
496499
using TStrSizeUMap = boost::unordered_map<std::string, std::size_t>;

include/maths/CDataFrameUtils.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,15 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
122122
COneHotCategoricalColumnValue(std::size_t column, std::size_t category)
123123
: CColumnValue{column}, m_Category{category} {}
124124
double operator()(const TRowRef& row) const override {
125+
if (isMissing(row[this->column()])) {
126+
return core::CDataFrame::valueOfMissing();
127+
}
125128
return static_cast<std::size_t>(row[this->column()]) == m_Category ? 1.0 : 0.0;
126129
}
127130
double operator()(const TFloatVec& row) const override {
131+
if (isMissing(row[this->column()])) {
132+
return core::CDataFrame::valueOfMissing();
133+
}
128134
return static_cast<std::size_t>(row[this->column()]) == m_Category ? 1.0 : 0.0;
129135
}
130136
std::size_t hash() const override { return m_Category; }
@@ -140,10 +146,16 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
140146
CFrequencyCategoricalColumnValue(std::size_t column, const TDoubleVec& frequencies)
141147
: CColumnValue{column}, m_Frequencies{&frequencies} {}
142148
double operator()(const TRowRef& row) const override {
149+
if (isMissing(row[this->column()])) {
150+
return core::CDataFrame::valueOfMissing();
151+
}
143152
std::size_t category{static_cast<std::size_t>(row[this->column()])};
144153
return (*m_Frequencies)[category];
145154
}
146155
double operator()(const TFloatVec& row) const override {
156+
if (isMissing(row[this->column()])) {
157+
return core::CDataFrame::valueOfMissing();
158+
}
147159
std::size_t category{static_cast<std::size_t>(row[this->column()])};
148160
return (*m_Frequencies)[category];
149161
}
@@ -166,10 +178,16 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
166178
: CColumnValue{column}, m_RareCategories{&rareCategories}, m_TargetMeanValues{&targetMeanValues} {
167179
}
168180
double operator()(const TRowRef& row) const override {
181+
if (isMissing(row[this->column()])) {
182+
return core::CDataFrame::valueOfMissing();
183+
}
169184
std::size_t category{static_cast<std::size_t>(row[this->column()])};
170185
return this->isRare(category) ? 0.0 : (*m_TargetMeanValues)[category];
171186
}
172187
double operator()(const TFloatVec& row) const override {
188+
if (isMissing(row[this->column()])) {
189+
return core::CDataFrame::valueOfMissing();
190+
}
173191
std::size_t category{static_cast<std::size_t>(row[this->column()])};
174192
return this->isRare(category) ? 0.0 : (*m_TargetMeanValues)[category];
175193
}

lib/core/CDataFrame.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,6 @@ std::size_t CDataFrame::estimateMemoryUsage(bool inMainMemory,
413413
return inMainMemory ? numberRows * numberColumns * sizeof(float) : 0;
414414
}
415415

416-
double CDataFrame::valueOfMissing() {
417-
return std::numeric_limits<double>::quiet_NaN();
418-
}
419-
420416
CDataFrame::TRowFuncVecBoolPr
421417
CDataFrame::parallelApplyToAllRows(std::size_t numberThreads,
422418
std::size_t beginRows,

lib/maths/CDataFrameCategoryEncoder.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <maths/CDataFrameCategoryEncoder.h>
88

99
#include <core/CContainerPrinter.h>
10+
#include <core/CDataFrame.h>
1011
#include <core/CLogger.h>
1112
#include <core/CPackedBitVector.h>
1213
#include <core/CPersistUtils.h>
@@ -452,7 +453,9 @@ EEncoding CDataFrameCategoryEncoder::COneHotEncoding::type() const {
452453
}
453454

454455
double CDataFrameCategoryEncoder::COneHotEncoding::encode(double value) const {
455-
return static_cast<std::size_t>(value) == m_HotCategory;
456+
return CDataFrameUtils::isMissing(value)
457+
? core::CDataFrame::valueOfMissing()
458+
: static_cast<std::size_t>(value) == m_HotCategory;
456459
}
457460

458461
bool CDataFrameCategoryEncoder::COneHotEncoding::isBinary() const {
@@ -503,6 +506,9 @@ EEncoding CDataFrameCategoryEncoder::CMappedEncoding::type() const {
503506
}
504507

505508
double CDataFrameCategoryEncoder::CMappedEncoding::encode(double value) const {
509+
if (CDataFrameUtils::isMissing(value)) {
510+
return core::CDataFrame::valueOfMissing();
511+
}
506512
std::size_t category{static_cast<std::size_t>(value)};
507513
return category < m_Map.size() ? m_Map[category] : m_Fallback;
508514
}

lib/maths/CDataFrameUtils.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -824,9 +824,11 @@ CDataFrameUtils::TSizeDoublePrVecVecVec CDataFrameUtils::categoricalMicWithColum
824824
1, 0, frame.numberRows(),
825825
[&](TRowItr beginRows, TRowItr endRows) {
826826
for (auto row = beginRows; row != endRows; ++row) {
827+
if (isMissing((*row)[i]) || isMissing(target(*row))) {
828+
continue;
829+
}
827830
std::size_t category{static_cast<std::size_t>((*row)[i])};
828-
if (frequencies[i][category] >= minimumFrequency &&
829-
isMissing(target(*row)) == false) {
831+
if (frequencies[i][category] >= minimumFrequency) {
830832
sampler.sample(*row);
831833
}
832834
}
@@ -905,6 +907,9 @@ CDataFrameUtils::TSizeDoublePrVecVecVec CDataFrameUtils::categoricalMicWithColum
905907

906908
encoders.clear();
907909
for (const auto& sample : samples) {
910+
if (isMissing(sample[i])) {
911+
continue;
912+
}
908913
std::size_t category{static_cast<std::size_t>(sample[i])};
909914
if (frequencies[i][category] >= minimumFrequency) {
910915
auto encoder = makeEncoder(i, i, category);

0 commit comments

Comments
 (0)