Skip to content

[ML] Correct missing category handling #1042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/core/CDataFrame.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <vector>

Expand Down Expand Up @@ -487,7 +488,9 @@ class CORE_EXPORT CDataFrame final {
std::size_t numberColumns);

//! Get the value to use for a missing element in a data frame.
static double valueOfMissing();
static constexpr double valueOfMissing() {
    // Missing elements are represented by a quiet NaN.
    using TDoubleLimits = std::numeric_limits<double>;
    return TDoubleLimits::quiet_NaN();
}

private:
using TStrSizeUMap = boost::unordered_map<std::string, std::size_t>;
Expand Down
18 changes: 18 additions & 0 deletions include/maths/CDataFrameUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,15 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
//! Forward \p column to the CColumnValue base and remember \p category as
//! the category this one-hot value tests for.
COneHotCategoricalColumnValue(std::size_t column, std::size_t category)
: CColumnValue{column}, m_Category{category} {}
//! One-hot encode this column's entry of \p row.
//!
//! \return The data frame's missing value if the entry is missing, otherwise
//! 1.0 if the entry cast to std::size_t equals the stored category and 0.0
//! if it doesn't.
double operator()(const TRowRef& row) const override {
    const auto col = this->column();
    if (isMissing(row[col])) {
        // Propagate "missing" rather than casting it to a category index.
        return core::CDataFrame::valueOfMissing();
    }
    if (static_cast<std::size_t>(row[col]) == m_Category) {
        return 1.0;
    }
    return 0.0;
}
//! One-hot encode this column's entry of the float vector \p row.
//!
//! \return The data frame's missing value if the entry is missing, otherwise
//! 1.0 if the entry cast to std::size_t equals the stored category and 0.0
//! if it doesn't.
double operator()(const TFloatVec& row) const override {
    const auto col = this->column();
    if (isMissing(row[col])) {
        // Propagate "missing" rather than casting it to a category index.
        return core::CDataFrame::valueOfMissing();
    }
    if (static_cast<std::size_t>(row[col]) == m_Category) {
        return 1.0;
    }
    return 0.0;
}
//! The hash of a one-hot value is simply the category it tests for.
std::size_t hash() const override {
    return m_Category;
}
Expand All @@ -140,10 +146,16 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
//! Forward \p column to the CColumnValue base and keep a pointer to
//! \p frequencies.
//!
//! \note Only the address of \p frequencies is stored, not a copy, so the
//! referenced vector has to remain valid while this object is used.
CFrequencyCategoricalColumnValue(std::size_t column, const TDoubleVec& frequencies)
: CColumnValue{column}, m_Frequencies{&frequencies} {}
//! Frequency encode this column's entry of \p row.
//!
//! \return The data frame's missing value if the entry is missing, otherwise
//! the element of the frequencies vector indexed by the entry cast to
//! std::size_t.
double operator()(const TRowRef& row) const override {
    const auto col = this->column();
    if (isMissing(row[col])) {
        // A missing entry has no category to look up.
        return core::CDataFrame::valueOfMissing();
    }
    const auto& frequencies = *m_Frequencies;
    return frequencies[static_cast<std::size_t>(row[col])];
}
//! Frequency encode this column's entry of the float vector \p row.
//!
//! \return The data frame's missing value if the entry is missing, otherwise
//! the element of the frequencies vector indexed by the entry cast to
//! std::size_t.
double operator()(const TFloatVec& row) const override {
    const auto col = this->column();
    if (isMissing(row[col])) {
        // A missing entry has no category to look up.
        return core::CDataFrame::valueOfMissing();
    }
    const auto& frequencies = *m_Frequencies;
    return frequencies[static_cast<std::size_t>(row[col])];
}
Expand All @@ -166,10 +178,16 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
: CColumnValue{column}, m_RareCategories{&rareCategories}, m_TargetMeanValues{&targetMeanValues} {
}
//! Target mean encode this column's entry of \p row.
//!
//! \return The data frame's missing value if the entry is missing, 0.0 if
//! isRare reports the entry's category as rare, and otherwise the element of
//! the target mean values indexed by that category.
double operator()(const TRowRef& row) const override {
    const auto col = this->column();
    if (isMissing(row[col])) {
        // A missing entry has no category to look up.
        return core::CDataFrame::valueOfMissing();
    }
    const auto category = static_cast<std::size_t>(row[col]);
    if (this->isRare(category)) {
        return 0.0;
    }
    return (*m_TargetMeanValues)[category];
}
//! Target mean encode this column's entry of the float vector \p row.
//!
//! \return The data frame's missing value if the entry is missing, 0.0 if
//! isRare reports the entry's category as rare, and otherwise the element of
//! the target mean values indexed by that category.
double operator()(const TFloatVec& row) const override {
    const auto col = this->column();
    if (isMissing(row[col])) {
        // A missing entry has no category to look up.
        return core::CDataFrame::valueOfMissing();
    }
    const auto category = static_cast<std::size_t>(row[col]);
    if (this->isRare(category)) {
        return 0.0;
    }
    return (*m_TargetMeanValues)[category];
}
Expand Down
4 changes: 0 additions & 4 deletions lib/core/CDataFrame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,6 @@ std::size_t CDataFrame::estimateMemoryUsage(bool inMainMemory,
return inMainMemory ? numberRows * numberColumns * sizeof(float) : 0;
}

//! \return The quiet NaN used to mark a missing element.
double CDataFrame::valueOfMissing() {
    const double missing{std::numeric_limits<double>::quiet_NaN()};
    return missing;
}

CDataFrame::TRowFuncVecBoolPr
CDataFrame::parallelApplyToAllRows(std::size_t numberThreads,
std::size_t beginRows,
Expand Down
8 changes: 7 additions & 1 deletion lib/maths/CDataFrameCategoryEncoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <maths/CDataFrameCategoryEncoder.h>

#include <core/CContainerPrinter.h>
#include <core/CDataFrame.h>
#include <core/CLogger.h>
#include <core/CPackedBitVector.h>
#include <core/CPersistUtils.h>
Expand Down Expand Up @@ -452,7 +453,9 @@ EEncoding CDataFrameCategoryEncoder::COneHotEncoding::type() const {
}

//! Encode \p value as an indicator of whether its category is the hot one.
double CDataFrameCategoryEncoder::COneHotEncoding::encode(double value) const {
// NOTE(review): this listing is a diff view whose +/- markers were lost. The
// return immediately below looks like the pre-change line which the three
// line return after it replaces - confirm against the merged file.
return static_cast<std::size_t>(value) == m_HotCategory;
// Missing values are passed through as the data frame's missing value. In
// the other branch the bool result of the comparison is converted to double
// by the conditional expression.
return CDataFrameUtils::isMissing(value)
? core::CDataFrame::valueOfMissing()
: static_cast<std::size_t>(value) == m_HotCategory;
}

bool CDataFrameCategoryEncoder::COneHotEncoding::isBinary() const {
Expand Down Expand Up @@ -503,6 +506,9 @@ EEncoding CDataFrameCategoryEncoder::CMappedEncoding::type() const {
}

//! Encode \p value by looking its category up in the map.
//!
//! \return The data frame's missing value if \p value is missing, the map
//! entry for the category if it is inside the map, and the fallback value
//! otherwise.
double CDataFrameCategoryEncoder::CMappedEncoding::encode(double value) const {
    if (CDataFrameUtils::isMissing(value)) {
        return core::CDataFrame::valueOfMissing();
    }
    const auto index = static_cast<std::size_t>(value);
    if (index < m_Map.size()) {
        return m_Map[index];
    }
    // Categories beyond the end of the map use the fallback.
    return m_Fallback;
}
Expand Down
9 changes: 7 additions & 2 deletions lib/maths/CDataFrameUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -824,9 +824,11 @@ CDataFrameUtils::TSizeDoublePrVecVecVec CDataFrameUtils::categoricalMicWithColum
1, 0, frame.numberRows(),
[&](TRowItr beginRows, TRowItr endRows) {
for (auto row = beginRows; row != endRows; ++row) {
if (isMissing((*row)[i]) || isMissing(target(*row))) {
continue;
}
std::size_t category{static_cast<std::size_t>((*row)[i])};
if (frequencies[i][category] >= minimumFrequency &&
isMissing(target(*row)) == false) {
if (frequencies[i][category] >= minimumFrequency) {
sampler.sample(*row);
}
}
Expand Down Expand Up @@ -905,6 +907,9 @@ CDataFrameUtils::TSizeDoublePrVecVecVec CDataFrameUtils::categoricalMicWithColum

encoders.clear();
for (const auto& sample : samples) {
if (isMissing(sample[i])) {
continue;
}
std::size_t category{static_cast<std::size_t>(sample[i])};
if (frequencies[i][category] >= minimumFrequency) {
auto encoder = makeEncoder(i, i, category);
Expand Down
Loading