Skip to content

Commit aec1e08

Browse files
authored
[7.5][ML] Encapsulate encoding logic (elastic#689) (elastic#708)
This PR moves the logic of categorical encoding into a map and replaces the online computations in operator[] with a look-up. Backport of elastic#689.
1 parent 89b9723 commit aec1e08

9 files changed

+705
-489
lines changed

include/core/RestoreMacros.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ namespace core {
2323
continue; \
2424
}
2525

26+
#define RESTORE_NO_LOOP(tag, restore) \
27+
if (name == tag) { \
28+
if ((restore) == false) { \
29+
if (traverser.value().empty()) { \
30+
LOG_ERROR(<< "Failed to restore " #tag); \
31+
} else { \
32+
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
33+
} \
34+
return false; \
35+
} \
36+
}
37+
2638
#define RESTORE_BUILT_IN(tag, target) \
2739
if (name == tag) { \
2840
if (core::CStringUtils::stringToType(traverser.value(), target) == false) { \

include/maths/CBoostedTree.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
184184
//! Populate the object from serialized data.
185185
bool acceptRestoreTraverser(core::CStateRestoreTraverser& traverser);
186186

187+
CBoostedTree& operator=(const CBoostedTree&) = delete;
188+
187189
private:
188190
using TImplUPtr = std::unique_ptr<CBoostedTreeImpl>;
189191

include/maths/CDataFrameCategoryEncoder.h

Lines changed: 171 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <boost/unordered_set.hpp>
1919

2020
#include <cstdint>
21+
#include <memory>
2122
#include <utility>
2223
#include <vector>
2324

@@ -44,7 +45,7 @@ class MATHS_EXPORT CEncodedDataFrameRowRef final {
4445
CEncodedDataFrameRowRef(TRowRef row, const CDataFrameCategoryEncoder& encoder);
4546

4647
//! Get column \p i value.
47-
CFloatStorage operator[](std::size_t i) const;
48+
CFloatStorage operator[](std::size_t encodedColumnIndex) const;
4849

4950
//! Get the row's index.
5051
std::size_t index() const;
@@ -72,6 +73,14 @@ class MATHS_EXPORT CEncodedDataFrameRowRef final {
7273

7374
class CMakeDataFrameCategoryEncoder;
7475

76+
//! Encoding styles considered.
77+
enum EEncoding {
78+
E_OneHot = 0,
79+
E_Frequency,
80+
E_TargetMean,
81+
E_IdentityEncoding // This must stay at the end
82+
};
83+
7584
//! \brief Performs encoding of the categorical columns in a data frame.
7685
//!
7786
//! DESCRIPTION:\n
@@ -85,60 +94,125 @@ class CMakeDataFrameCategoryEncoder;
8594
//! the number of features we use in total based on the quantity of training data.
8695
class MATHS_EXPORT CDataFrameCategoryEncoder final {
8796
public:
88-
using TBoolVec = std::vector<bool>;
8997
using TDoubleVec = std::vector<double>;
90-
using TDoubleVecVec = std::vector<TDoubleVec>;
9198
using TSizeVec = std::vector<std::size_t>;
92-
using TSizeVecVec = std::vector<TSizeVec>;
9399
using TRowRef = core::CDataFrame::TRowRef;
94100

101+
//! \brief Base type of category encodings.
102+
class MATHS_EXPORT CEncoding {
103+
public:
104+
CEncoding(std::size_t inputColumnIndex, double mic);
105+
virtual ~CEncoding() = default;
106+
virtual EEncoding type() const = 0;
107+
virtual double encode(double value) const = 0;
108+
virtual std::uint64_t checksum() const = 0;
109+
virtual bool isBinary() const = 0;
110+
//! Return the encoding type as a string.
111+
virtual std::string typeString() const = 0;
112+
113+
std::size_t inputColumnIndex() const;
114+
double encode(const TRowRef& row) const;
115+
double mic() const;
116+
//! Persist by passing information to \p inserter.
117+
void acceptPersistInserter(core::CStatePersistInserter& inserter) const;
118+
//! Populate the object from serialized data.
119+
bool acceptRestoreTraverser(core::CStateRestoreTraverser& traverser);
120+
121+
protected:
122+
std::size_t m_InputColumnIndex;
123+
double m_Mic;
124+
125+
private:
126+
virtual void
127+
acceptPersistInserterForDerivedTypeState(core::CStatePersistInserter& inserter) const = 0;
128+
virtual bool
129+
acceptRestoreTraverserForDerivedTypeState(core::CStateRestoreTraverser& traverser) = 0;
130+
};
131+
132+
using TEncodingUPtr = std::unique_ptr<CEncoding>;
133+
using TEncodingUPtrVec = std::vector<TEncodingUPtr>;
134+
135+
//! \brief Returns the supplied value.
136+
class MATHS_EXPORT CIdentityEncoding : public CEncoding {
137+
public:
138+
CIdentityEncoding(std::size_t inputColumnIndex, double mic);
139+
EEncoding type() const override;
140+
double encode(double value) const override;
141+
bool isBinary() const override;
142+
std::uint64_t checksum() const override;
143+
std::string typeString() const override;
144+
145+
private:
146+
void acceptPersistInserterForDerivedTypeState(core::CStatePersistInserter& inserter) const override;
147+
bool acceptRestoreTraverserForDerivedTypeState(core::CStateRestoreTraverser& traverser) override;
148+
};
149+
150+
//! \brief One-hot encoding.
151+
class MATHS_EXPORT COneHotEncoding : public CEncoding {
152+
public:
153+
COneHotEncoding(std::size_t inputColumnIndex, double mic, std::size_t hotCategory);
154+
EEncoding type() const override;
155+
double encode(double value) const override;
156+
bool isBinary() const override;
157+
std::uint64_t checksum() const override;
158+
std::string typeString() const override;
159+
160+
private:
161+
void acceptPersistInserterForDerivedTypeState(core::CStatePersistInserter& inserter) const override;
162+
bool acceptRestoreTraverserForDerivedTypeState(core::CStateRestoreTraverser& traverser) override;
163+
164+
private:
165+
std::size_t m_HotCategory;
166+
};
167+
168+
//! \brief Looks up the encoding in a map.
169+
class MATHS_EXPORT CMappedEncoding : public CEncoding {
170+
public:
171+
CMappedEncoding(std::size_t inputColumnIndex,
172+
double mic,
173+
EEncoding encoding,
174+
const TDoubleVec& map,
175+
double fallback);
176+
EEncoding type() const override;
177+
double encode(double value) const override;
178+
bool isBinary() const override;
179+
std::uint64_t checksum() const override;
180+
std::string typeString() const override;
181+
182+
private:
183+
void acceptPersistInserterForDerivedTypeState(core::CStatePersistInserter& inserter) const override;
184+
bool acceptRestoreTraverserForDerivedTypeState(core::CStateRestoreTraverser& traverser) override;
185+
186+
private:
187+
EEncoding m_Encoding;
188+
TDoubleVec m_Map;
189+
double m_Fallback;
190+
bool m_Binary;
191+
};
192+
95193
public:
96-
CDataFrameCategoryEncoder(const CMakeDataFrameCategoryEncoder& parameters);
194+
CDataFrameCategoryEncoder(CMakeDataFrameCategoryEncoder parameters);
97195

98196
//! Initialize from serialized data.
99197
CDataFrameCategoryEncoder(core::CStateRestoreTraverser& traverser);
100198

199+
CDataFrameCategoryEncoder(const CDataFrameCategoryEncoder&) = delete;
200+
CDataFrameCategoryEncoder& operator=(const CDataFrameCategoryEncoder&) = delete;
201+
101202
//! Get a row reference which encodes the categories in \p row.
102203
CEncodedDataFrameRowRef encode(TRowRef row) const;
103204

104-
//! Check if \p feature is categorical.
105-
bool columnIsCategorical(std::size_t feature) const;
106-
107205
//! Get the MICs of the selected features.
108-
const TDoubleVec& featureMics() const;
206+
TDoubleVec encodedColumnMics() const;
109207

110208
//! Get the total number of dimensions in the feature vector.
111-
std::size_t numberFeatures() const;
112-
113-
//! Get the encoding offset in feature vector of \p index.
114-
std::size_t encoding(std::size_t index) const;
115-
116-
//! Get the data frame column of \p index into the feature vector.
117-
std::size_t column(std::size_t index) const;
118-
119-
//! Check if \p index is a binary encoded feature.
120-
bool isBinary(std::size_t index) const;
121-
122-
//! Get the number of one-hot encoded categories for \p feature.
123-
std::size_t numberOneHotEncodedCategories(std::size_t feature) const;
124-
125-
//! Check if \p category of \p feature uses one-hot encoding.
126-
bool usesOneHotEncoding(std::size_t feature, std::size_t category) const;
127-
128-
//! Check if feature with encoding \p encoding is one for \p category of \p feature.
129-
bool isHot(std::size_t encoding, std::size_t feature, std::size_t category) const;
130-
131-
//! Check if \p feature uses frequency encoding.
132-
bool usesFrequencyEncoding(std::size_t feature) const;
133-
134-
//! Check if \p category of \p feature is a rare category.
135-
bool isRareCategory(std::size_t feature, std::size_t category) const;
209+
std::size_t numberEncodedColumns() const;
136210

137-
//! Get the frequency of \p category of \p feature.
138-
double frequency(std::size_t feature, std::size_t category) const;
211+
//! Get the encoded feature at position \p encodedColumnIndex.
212+
const CEncoding& encoding(std::size_t encodedColumnIndex) const;
139213

140-
//! Get the mean value of the target variable for \p category of \p feature.
141-
double targetMeanValue(std::size_t feature, std::size_t category) const;
214+
//! Check if \p encodedColumnIndex is a binary encoded feature.
215+
bool isBinary(std::size_t encodedColumnIndex) const;
142216

143217
//! Get a checksum of the state of this object seeded with \p seed.
144218
std::uint64_t checksum(std::uint64_t seed = 0) const;
@@ -150,64 +224,21 @@ class MATHS_EXPORT CDataFrameCategoryEncoder final {
150224
bool acceptRestoreTraverser(core::CStateRestoreTraverser& traverser);
151225

152226
private:
153-
using TSizeDoublePr = std::pair<std::size_t, double>;
154-
using TSizeDoublePrVec = std::vector<TSizeDoublePr>;
155-
using TSizeDoublePrVecVec = std::vector<TSizeDoublePrVec>;
156-
using TSizeSizePr = std::pair<std::size_t, std::size_t>;
157-
using TSizeSizePrDoubleMap = std::map<TSizeSizePr, double>;
158-
using TSizeUSet = boost::unordered_set<std::size_t>;
159-
using TSizeUSetVec = std::vector<TSizeUSet>;
160-
161-
private:
162-
TSizeDoublePrVecVec mics(std::size_t numberThreads,
163-
const core::CDataFrame& frame,
164-
const CDataFrameUtils::CColumnValue& target,
165-
const core::CPackedBitVector& rowMask,
166-
const TSizeVec& metricColumnMask,
167-
const TSizeVec& categoricalColumnMask) const;
168-
void setupFrequencyEncoding(std::size_t numberThreads,
169-
const core::CDataFrame& frame,
170-
const core::CPackedBitVector& rowMask,
171-
const TSizeVec& categoricalColumnMask);
172-
void setupTargetMeanValueEncoding(std::size_t numberThreads,
173-
const core::CDataFrame& frame,
174-
const core::CPackedBitVector& rowMask,
175-
const TSizeVec& categoricalColumnMask,
176-
std::size_t targetColumn);
177-
TSizeSizePrDoubleMap selectFeatures(std::size_t numberThreads,
178-
const core::CDataFrame& frame,
179-
const core::CPackedBitVector& rowMask,
180-
TSizeVec metricColumnMask,
181-
TSizeVec categoricalColumnMask,
182-
std::size_t targetColumn);
183-
TSizeSizePrDoubleMap selectAllFeatures(const TSizeDoublePrVecVec& mics);
184-
void finishEncoding(std::size_t targetColumn, TSizeSizePrDoubleMap selectedFeatureMics);
185-
void discardNuisanceFeatures(TSizeDoublePrVecVec& mics) const;
186-
std::size_t numberAvailableFeatures(const TSizeDoublePrVecVec& mics) const;
227+
bool restoreEncodings(core::CStateRestoreTraverser& traverser);
228+
template<typename T, typename... Args>
229+
bool forwardRestoreEncodings(core::CStateRestoreTraverser& traverser, Args&&... args);
187230

188231
private:
189-
std::size_t m_MinimumRowsPerFeature;
190-
double m_MinimumFrequencyToOneHotEncode;
191-
double m_MinimumRelativeMicToSelectFeature;
192-
double m_RedundancyWeight;
193-
TBoolVec m_ColumnIsCategorical;
194-
TBoolVec m_ColumnUsesFrequencyEncoding;
195-
TSizeVecVec m_OneHotEncodedCategories;
196-
TSizeUSetVec m_RareCategories;
197-
TDoubleVecVec m_CategoryFrequencies;
198-
TDoubleVec m_MeanCategoryFrequencies;
199-
TDoubleVecVec m_CategoryTargetMeanValues;
200-
TDoubleVec m_MeanCategoryTargetMeanValues;
201-
TDoubleVec m_FeatureVectorMics;
202-
TSizeVec m_FeatureVectorColumnMap;
203-
TSizeVec m_FeatureVectorEncodingMap;
232+
TEncodingUPtrVec m_Encodings;
204233
};
205234

206235
//! \brief Implements the named parameter idiom for CDataFrameCategoryEncoder.
207236
class MATHS_EXPORT CMakeDataFrameCategoryEncoder {
208237
public:
238+
using TDoubleVec = std::vector<double>;
209239
using TSizeVec = std::vector<std::size_t>;
210240
using TOptionalDouble = boost::optional<double>;
241+
using TEncodingUPtrVec = CDataFrameCategoryEncoder::TEncodingUPtrVec;
211242

212243
public:
213244
//! The minimum number of training rows needed per feature used.
@@ -233,9 +264,6 @@ class MATHS_EXPORT CMakeDataFrameCategoryEncoder {
233264
const core::CDataFrame& frame,
234265
std::size_t targetColumn);
235266

236-
CMakeDataFrameCategoryEncoder(const CMakeDataFrameCategoryEncoder&) = delete;
237-
CMakeDataFrameCategoryEncoder& operator=(const CMakeDataFrameCategoryEncoder&) = delete;
238-
239267
//! Set the minimum number of training rows needed per feature used.
240268
CMakeDataFrameCategoryEncoder& minimumRowsPerFeature(std::size_t minimumRowsPerFeature);
241269

@@ -260,18 +288,69 @@ class MATHS_EXPORT CMakeDataFrameCategoryEncoder {
260288
//! Set a mask of the columns to include.
261289
CMakeDataFrameCategoryEncoder& columnMask(TSizeVec columnMask);
262290

291+
//! Make the encoding.
292+
TEncodingUPtrVec makeEncodings();
293+
294+
//! \name Test Methods
295+
//@{
296+
//! Get the encoding offset in feature vector of \p index.
297+
std::size_t encoding(std::size_t index) const;
298+
299+
//! Check if \p category of \p inputColumnIndex uses one-hot encoding.
300+
bool usesOneHotEncoding(std::size_t inputColumnIndex, std::size_t category) const;
301+
302+
//! Check if \p category of \p inputColumnIndex is a rare category.
303+
bool isRareCategory(std::size_t inputColumnIndex, std::size_t category) const;
304+
//@}
305+
263306
private:
307+
using TBoolVec = std::vector<bool>;
308+
using TDoubleVecVec = std::vector<TDoubleVec>;
309+
using TSizeVecVec = std::vector<TSizeVec>;
310+
using TSizeDoublePr = std::pair<std::size_t, double>;
311+
using TSizeDoublePrVec = std::vector<TSizeDoublePr>;
312+
using TSizeDoublePrVecVec = std::vector<TSizeDoublePrVec>;
313+
using TSizeSizePr = std::pair<std::size_t, std::size_t>;
314+
using TSizeSizePrDoubleMap = std::map<TSizeSizePr, double>;
315+
using TSizeUSet = boost::unordered_set<std::size_t>;
316+
using TSizeUSetVec = std::vector<TSizeUSet>;
317+
318+
private:
319+
TSizeDoublePrVecVec mics(const CDataFrameUtils::CColumnValue& target,
320+
const TSizeVec& metricColumnMask,
321+
const TSizeVec& categoricalColumnMask) const;
322+
void setupFrequencyEncoding(const TSizeVec& categoricalColumnMask);
323+
void setupTargetMeanValueEncoding(const TSizeVec& categoricalColumnMask);
324+
TSizeSizePrDoubleMap selectFeatures(TSizeVec metricColumnMask,
325+
const TSizeVec& categoricalColumnMask);
326+
TSizeSizePrDoubleMap selectAllFeatures(const TSizeDoublePrVecVec& mics);
327+
void finishEncoding(TSizeSizePrDoubleMap selectedFeatureMics);
328+
void discardNuisanceFeatures(TSizeDoublePrVecVec& mics) const;
329+
std::size_t numberAvailableFeatures(const TSizeDoublePrVecVec& mics) const;
330+
331+
private:
332+
// Begin parameters
264333
std::size_t m_MinimumRowsPerFeature = MINIMUM_ROWS_PER_FEATURE;
265334
double m_MinimumFrequencyToOneHotEncode = MINIMUM_FREQUENCY_TO_ONE_HOT_ENCODE;
266335
double m_MinimumRelativeMicToSelectFeature = MINIMUM_RELATIVE_MIC_TO_SELECT_FEATURE;
267336
double m_RedundancyWeight = REDUNDANCY_WEIGHT;
268337
std::size_t m_NumberThreads;
269-
const core::CDataFrame& m_Frame;
338+
const core::CDataFrame* m_Frame;
270339
core::CPackedBitVector m_RowMask;
271340
TSizeVec m_ColumnMask;
272341
std::size_t m_TargetColumn;
342+
// End parameters
273343

274-
friend CDataFrameCategoryEncoder;
344+
TBoolVec m_InputColumnUsesFrequencyEncoding;
345+
TSizeVecVec m_OneHotEncodedCategories;
346+
TSizeUSetVec m_RareCategories;
347+
TDoubleVecVec m_CategoryFrequencies;
348+
TDoubleVec m_MeanCategoryFrequencies;
349+
TDoubleVecVec m_CategoryTargetMeanValues;
350+
TDoubleVec m_MeanCategoryTargetMeanValues;
351+
TDoubleVec m_EncodedColumnMics;
352+
TSizeVec m_EncodedColumnInputColumnMap;
353+
TSizeVec m_EncodedColumnEncodingMap;
275354
};
276355
}
277356
}

lib/maths/CBoostedTreeFactory.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ void CBoostedTreeFactory::selectFeaturesAndEncodeCategories(const core::CDataFra
203203

204204
void CBoostedTreeFactory::determineFeatureDataTypes(const core::CDataFrame& frame) const {
205205

206-
TSizeVec columnMask(m_TreeImpl->m_Encoder->numberFeatures());
206+
TSizeVec columnMask(m_TreeImpl->m_Encoder->numberEncodedColumns());
207207
std::iota(columnMask.begin(), columnMask.end(), 0);
208208
columnMask.erase(std::remove_if(columnMask.begin(), columnMask.end(),
209209
[this](std::size_t index) {
@@ -220,7 +220,7 @@ bool CBoostedTreeFactory::initializeFeatureSampleDistribution() const {
220220

221221
// Compute feature sample probabilities.
222222

223-
TDoubleVec mics(m_TreeImpl->m_Encoder->featureMics());
223+
TDoubleVec mics(m_TreeImpl->m_Encoder->encodedColumnMics());
224224
LOG_TRACE(<< "candidate regressors MICe = " << core::CContainerPrinter::print(mics));
225225

226226
if (mics.size() > 0) {

0 commit comments

Comments
 (0)