Skip to content

Commit 0a95ace

Browse files
authored
Introduce classification analysis runner. (#701)
1 parent 55b9f3e commit 0a95ace

22 files changed

+610
-109
lines changed

include/api/CDataFrameAnalysisRunner.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ class CMemoryUsageEstimationResultJsonWriter;
5959
//! early to determine how to implement a good cooperative interrupt scheme.
6060
class API_EXPORT CDataFrameAnalysisRunner {
6161
public:
62+
using TBoolVec = std::vector<bool>;
6263
using TStrVec = std::vector<std::string>;
64+
using TStrVecVec = std::vector<TStrVec>;
6365
using TRowRef = core::data_frame_detail::CRowRef;
6466
using TProgressRecorder = std::function<void(double)>;
6567

@@ -98,6 +100,9 @@ class API_EXPORT CDataFrameAnalysisRunner {
98100
//! \return The number of columns this analysis appends.
99101
virtual std::size_t numberExtraColumns() const = 0;
100102

103+
//! \return Indicator of columns for which an empty value should be treated as missing.
104+
virtual TBoolVec columnsForWhichEmptyIsMissing(const TStrVec& fieldNames) const;
105+
101106
//! Write the extra columns of \p row added by the analysis to \p writer.
102107
//!
103108
//! This should create a new object of the form:
@@ -114,6 +119,7 @@ class API_EXPORT CDataFrameAnalysisRunner {
114119
//! \param[in] row The row to write the columns added by this analysis.
115120
//! \param[in,out] writer The stream to which to write the extra columns.
116121
virtual void writeOneRow(const TStrVec& featureNames,
122+
const TStrVecVec& categoricalFieldValues,
117123
TRowRef row,
118124
core::CRapidJsonConcurrentLineWriter& writer) const = 0;
119125

@@ -189,12 +195,12 @@ class API_EXPORT CDataFrameAnalysisRunnerFactory {
189195

190196
TRunnerUPtr make(const CDataFrameAnalysisSpecification& spec) const;
191197
TRunnerUPtr make(const CDataFrameAnalysisSpecification& spec,
192-
const rapidjson::Value& params) const;
198+
const rapidjson::Value& jsonParameters) const;
193199

194200
private:
195201
virtual TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec) const = 0;
196202
virtual TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec,
197-
const rapidjson::Value& params) const = 0;
203+
const rapidjson::Value& jsonParameters) const = 0;
198204
};
199205
}
200206
}

include/api/CDataFrameAnalysisSpecification.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ namespace api {
4343
//! performance statistics.
4444
class API_EXPORT CDataFrameAnalysisSpecification {
4545
public:
46+
using TBoolVec = std::vector<bool>;
4647
using TStrVec = std::vector<std::string>;
4748
using TDataFrameUPtr = std::unique_ptr<core::CDataFrame>;
4849
using TTemporaryDirectoryPtr = std::shared_ptr<core::CTemporaryDirectory>;
@@ -173,6 +174,9 @@ class API_EXPORT CDataFrameAnalysisSpecification {
173174
//! 2. disk is used (only one partition needs to be loaded to main memory)
174175
void estimateMemoryUsage(CMemoryUsageEstimationResultJsonWriter& writer) const;
175176

177+
//! \return Indicator of columns for which an empty value should be treated as missing.
178+
TBoolVec columnsForWhichEmptyIsMissing(const TStrVec& fieldNames) const;
179+
176180
//! \return shared pointer to the persistence stream.
177181
TDataAdderUPtr persister() const;
178182

include/api/CDataFrameAnalyzer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class CDataFrameAnalysisSpecification;
3434
//!
3535
class API_EXPORT CDataFrameAnalyzer {
3636
public:
37+
using TBoolVec = std::vector<bool>;
3738
using TStrVec = std::vector<std::string>;
3839
using TJsonOutputStreamWrapperUPtr = std::unique_ptr<core::CJsonOutputStreamWrapper>;
3940
using TJsonOutputStreamWrapperUPtrSupplier =
@@ -104,6 +105,7 @@ class API_EXPORT CDataFrameAnalyzer {
104105
TDataFrameAnalysisSpecificationUPtr m_AnalysisSpecification;
105106
TStrVec m_CategoricalFieldNames;
106107
TStrSizeUMapVec m_CategoricalFieldValues;
108+
TBoolVec m_EmptyAsMissing;
107109
TDataFrameUPtr m_DataFrame;
108110
TStrVec m_FieldNames;
109111
TTemporaryDirectoryPtr m_DataFrameDirectory;

include/api/CDataFrameBoostedTreeRunner.h

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <core/CDataSearcher.h>
1111

12+
#include <api/CDataFrameAnalysisConfigReader.h>
1213
#include <api/CDataFrameAnalysisRunner.h>
1314
#include <api/CDataFrameAnalysisSpecification.h>
1415
#include <api/ImportExport.h>
@@ -25,11 +26,11 @@ class CBoostedTreeFactory;
2526
namespace api {
2627

2728
//! \brief Runs boosted tree regression on a core::CDataFrame.
28-
class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRunner {
29+
class API_EXPORT CDataFrameBoostedTreeRunner : public CDataFrameAnalysisRunner {
2930
public:
3031
//! This is not intended to be called directly: use CDataFrameBoostedTreeRunnerFactory.
3132
CDataFrameBoostedTreeRunner(const CDataFrameAnalysisSpecification& spec,
32-
const rapidjson::Value& jsonParameters);
33+
const CDataFrameAnalysisConfigReader::CParameters& parameters);
3334

3435
//! This is not intended to be called directly: use CDataFrameBoostedTreeRunnerFactory.
3536
CDataFrameBoostedTreeRunner(const CDataFrameAnalysisSpecification& spec);
@@ -39,13 +40,20 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
3940
//! \return The number of columns this adds to the data frame.
4041
std::size_t numberExtraColumns() const override;
4142

42-
//! Write the prediction for \p row to \p writer.
43-
void writeOneRow(const TStrVec& featureNames,
44-
TRowRef row,
45-
core::CRapidJsonConcurrentLineWriter& writer) const override;
43+
protected:
44+
using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
45+
46+
protected:
47+
//! Parameter reader handling parameters that are shared by subclasses.
48+
static CDataFrameAnalysisConfigReader getParameterReader();
49+
//! Name of dependent variable field.
50+
const std::string& dependentVariableFieldName() const;
51+
//! Name of prediction field.
52+
const std::string& predictionFieldName() const;
53+
//! Underlying boosted tree.
54+
const maths::CBoostedTree& boostedTree() const;
4655

4756
private:
48-
using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
4957
using TBoostedTreeFactoryUPtr = std::unique_ptr<maths::CBoostedTreeFactory>;
5058
using TDataSearcherUPtr = CDataFrameAnalysisSpecification::TDataSearcherUPtr;
5159
using TMemoryEstimator = std::function<void(std::int64_t)>;
@@ -71,20 +79,6 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
7179
TBoostedTreeUPtr m_BoostedTree;
7280
std::atomic<std::int64_t> m_Memory;
7381
};
74-
75-
//! \brief Makes a core::CDataFrame boosted tree regression runner.
76-
class API_EXPORT CDataFrameBoostedTreeRunnerFactory final : public CDataFrameAnalysisRunnerFactory {
77-
public:
78-
const std::string& name() const override;
79-
80-
private:
81-
static const std::string NAME;
82-
83-
private:
84-
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec) const override;
85-
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec,
86-
const rapidjson::Value& params) const override;
87-
};
8882
}
8983
}
9084

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
#ifndef INCLUDED_ml_api_CDataFrameClassificationRunner_h
8+
#define INCLUDED_ml_api_CDataFrameClassificationRunner_h
9+
10+
#include <core/CDataSearcher.h>
11+
12+
#include <api/CDataFrameAnalysisConfigReader.h>
13+
#include <api/CDataFrameAnalysisRunner.h>
14+
#include <api/CDataFrameAnalysisSpecification.h>
15+
#include <api/CDataFrameBoostedTreeRunner.h>
16+
#include <api/ImportExport.h>
17+
18+
#include <rapidjson/fwd.h>
19+
20+
#include <atomic>
21+
22+
namespace ml {
23+
namespace api {
24+
25+
//! \brief Runs boosted tree classification on a core::CDataFrame.
26+
class API_EXPORT CDataFrameClassificationRunner final : public CDataFrameBoostedTreeRunner {
27+
public:
28+
static const CDataFrameAnalysisConfigReader getParameterReader();
29+
30+
//! This is not intended to be called directly: use CDataFrameClassificationRunnerFactory.
31+
CDataFrameClassificationRunner(const CDataFrameAnalysisSpecification& spec,
32+
const CDataFrameAnalysisConfigReader::CParameters& parameters);
33+
34+
//! This is not intended to be called directly: use CDataFrameClassificationRunnerFactory.
35+
CDataFrameClassificationRunner(const CDataFrameAnalysisSpecification& spec);
36+
37+
//! \return Indicator of columns for which an empty value should be treated as missing.
38+
TBoolVec columnsForWhichEmptyIsMissing(const TStrVec& fieldNames) const override;
39+
40+
//! Write the prediction for \p row to \p writer.
41+
void writeOneRow(const TStrVec& featureNames,
42+
const TStrVecVec& categoricalFieldValues,
43+
TRowRef row,
44+
core::CRapidJsonConcurrentLineWriter& writer) const override;
45+
46+
private:
47+
std::size_t m_NumTopClasses;
48+
};
49+
50+
//! \brief Makes a core::CDataFrame boosted tree classification runner.
51+
class API_EXPORT CDataFrameClassificationRunnerFactory final
52+
: public CDataFrameAnalysisRunnerFactory {
53+
public:
54+
const std::string& name() const override;
55+
56+
private:
57+
static const std::string NAME;
58+
59+
private:
60+
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec) const override;
61+
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec,
62+
const rapidjson::Value& jsonParameters) const override;
63+
};
64+
}
65+
}
66+
67+
#endif // INCLUDED_ml_api_CDataFrameClassificationRunner_h

include/api/CDataFrameOutliersRunner.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#ifndef INCLUDED_ml_api_CDataFrameOutliersRunner_h
88
#define INCLUDED_ml_api_CDataFrameOutliersRunner_h
99

10+
#include <api/CDataFrameAnalysisConfigReader.h>
1011
#include <api/CDataFrameAnalysisRunner.h>
1112

1213
#include <api/ImportExport.h>
@@ -21,7 +22,7 @@ class API_EXPORT CDataFrameOutliersRunner final : public CDataFrameAnalysisRunne
2122
public:
2223
//! This is not intended to be called directly: use CDataFrameOutliersRunnerFactory.
2324
CDataFrameOutliersRunner(const CDataFrameAnalysisSpecification& spec,
24-
const rapidjson::Value& jsonParameters);
25+
const CDataFrameAnalysisConfigReader::CParameters& parameters);
2526

2627
//! This is not intended to be called directly: use CDataFrameOutliersRunnerFactory.
2728
CDataFrameOutliersRunner(const CDataFrameAnalysisSpecification& spec);
@@ -31,6 +32,7 @@ class API_EXPORT CDataFrameOutliersRunner final : public CDataFrameAnalysisRunne
3132

3233
//! Write the extra columns of \p row added by outlier analysis to \p writer.
3334
void writeOneRow(const TStrVec& featureNames,
35+
const TStrVecVec& categoricalFieldValues,
3436
TRowRef row,
3537
core::CRapidJsonConcurrentLineWriter& writer) const override;
3638

@@ -79,7 +81,7 @@ class API_EXPORT CDataFrameOutliersRunnerFactory final : public CDataFrameAnalys
7981
private:
8082
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec) const override;
8183
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec,
82-
const rapidjson::Value& params) const override;
84+
const rapidjson::Value& jsonParameters) const override;
8385
};
8486
}
8587
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
#ifndef INCLUDED_ml_api_CDataFrameRegressionRunner_h
8+
#define INCLUDED_ml_api_CDataFrameRegressionRunner_h
9+
10+
#include <core/CDataSearcher.h>
11+
12+
#include <api/CDataFrameAnalysisConfigReader.h>
13+
#include <api/CDataFrameAnalysisSpecification.h>
14+
#include <api/CDataFrameBoostedTreeRunner.h>
15+
#include <api/ImportExport.h>
16+
17+
#include <rapidjson/fwd.h>
18+
19+
#include <atomic>
20+
21+
namespace ml {
22+
namespace api {
23+
24+
//! \brief Runs boosted tree regression on a core::CDataFrame.
25+
class API_EXPORT CDataFrameRegressionRunner final : public CDataFrameBoostedTreeRunner {
26+
public:
27+
static const CDataFrameAnalysisConfigReader getParameterReader();
28+
29+
//! This is not intended to be called directly: use CDataFrameRegressionRunnerFactory.
30+
CDataFrameRegressionRunner(const CDataFrameAnalysisSpecification& spec,
31+
const CDataFrameAnalysisConfigReader::CParameters& parameters);
32+
33+
//! This is not intended to be called directly: use CDataFrameRegressionRunnerFactory.
34+
CDataFrameRegressionRunner(const CDataFrameAnalysisSpecification& spec);
35+
36+
//! Write the prediction for \p row to \p writer.
37+
void writeOneRow(const TStrVec& featureNames,
38+
const TStrVecVec& categoricalFieldValues,
39+
TRowRef row,
40+
core::CRapidJsonConcurrentLineWriter& writer) const override;
41+
};
42+
43+
//! \brief Makes a core::CDataFrame boosted tree regression runner.
44+
class API_EXPORT CDataFrameRegressionRunnerFactory final : public CDataFrameAnalysisRunnerFactory {
45+
public:
46+
const std::string& name() const override;
47+
48+
private:
49+
static const std::string NAME;
50+
51+
private:
52+
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec) const override;
53+
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec,
54+
const rapidjson::Value& jsonParameters) const override;
55+
};
56+
}
57+
}
58+
59+
#endif // INCLUDED_ml_api_CDataFrameRegressionRunner_h

lib/api/CDataFrameAnalysisRunner.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
namespace ml {
2525
namespace api {
2626
namespace {
27+
using TBoolVec = std::vector<bool>;
28+
2729
std::size_t maximumNumberPartitions(const CDataFrameAnalysisSpecification& spec) {
2830
// We limit the maximum number of partitions to rows^(1/2) because very
2931
// large numbers of partitions are going to be slow and it is better to tell
@@ -43,6 +45,10 @@ CDataFrameAnalysisRunner::~CDataFrameAnalysisRunner() {
4345
this->waitToFinish();
4446
}
4547

48+
TBoolVec CDataFrameAnalysisRunner::columnsForWhichEmptyIsMissing(const TStrVec& fieldNames) const {
49+
return TBoolVec(fieldNames.size(), false);
50+
}
51+
4652
void CDataFrameAnalysisRunner::estimateMemoryUsage(CMemoryUsageEstimationResultJsonWriter& writer) const {
4753
std::size_t numberRows{m_Spec.numberRows()};
4854
std::size_t numberColumns{m_Spec.numberColumns() + this->numberExtraColumns()};
@@ -223,8 +229,8 @@ CDataFrameAnalysisRunnerFactory::make(const CDataFrameAnalysisSpecification& spe
223229

224230
CDataFrameAnalysisRunnerFactory::TRunnerUPtr
225231
CDataFrameAnalysisRunnerFactory::make(const CDataFrameAnalysisSpecification& spec,
226-
const rapidjson::Value& params) const {
227-
auto result = this->makeImpl(spec, params);
232+
const rapidjson::Value& jsonParameters) const {
233+
auto result = this->makeImpl(spec, jsonParameters);
228234
result->computeAndSaveExecutionStrategy();
229235
return result;
230236
}

lib/api/CDataFrameAnalysisSpecification.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
#include <core/CLogger.h>
1111

1212
#include <api/CDataFrameAnalysisConfigReader.h>
13-
#include <api/CDataFrameBoostedTreeRunner.h>
13+
#include <api/CDataFrameClassificationRunner.h>
1414
#include <api/CDataFrameOutliersRunner.h>
15+
#include <api/CDataFrameRegressionRunner.h>
1516
#include <api/CMemoryUsageEstimationResultJsonWriter.h>
1617

1718
#include <rapidjson/document.h>
@@ -41,11 +42,13 @@ const std::string CDataFrameAnalysisSpecification::NAME("name");
4142
const std::string CDataFrameAnalysisSpecification::PARAMETERS("parameters");
4243

4344
namespace {
45+
using TBoolVec = std::vector<bool>;
4446
using TRunnerFactoryUPtrVec = ml::api::CDataFrameAnalysisSpecification::TRunnerFactoryUPtrVec;
4547

4648
TRunnerFactoryUPtrVec analysisFactories() {
4749
TRunnerFactoryUPtrVec factories;
48-
factories.push_back(std::make_unique<ml::api::CDataFrameBoostedTreeRunnerFactory>());
50+
factories.push_back(std::make_unique<ml::api::CDataFrameRegressionRunnerFactory>());
51+
factories.push_back(std::make_unique<ml::api::CDataFrameClassificationRunnerFactory>());
4952
factories.push_back(std::make_unique<ml::api::CDataFrameOutliersRunnerFactory>());
5053
// Add new analysis types here.
5154
return factories;
@@ -206,6 +209,14 @@ void CDataFrameAnalysisSpecification::estimateMemoryUsage(CMemoryUsageEstimation
206209
m_Runner->estimateMemoryUsage(writer);
207210
}
208211

212+
TBoolVec CDataFrameAnalysisSpecification::columnsForWhichEmptyIsMissing(const TStrVec& fieldNames) const {
213+
if (m_Runner == nullptr) {
214+
HANDLE_FATAL(<< "Internal error: no runner available. Please report this problem.");
215+
return TBoolVec(fieldNames.size(), false);
216+
}
217+
return m_Runner->columnsForWhichEmptyIsMissing(fieldNames);
218+
}
219+
209220
void CDataFrameAnalysisSpecification::initializeRunner(const rapidjson::Value& jsonAnalysis) {
210221
// We pass off the interpretation of the parameters object to the appropriate
211222
// analysis runner.

0 commit comments

Comments
 (0)