Skip to content

Commit e527aaa

Browse files
authored
[ML] Allow configuration of prediction column name in data_frame_analyzer (#587)
1 parent e237ef2 commit e527aaa

File tree

3 files changed

+27
-24
lines changed

3 files changed

+27
-24
lines changed

include/api/CDataFrameBoostedTreeRunner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
5454
private:
5555
// Note custom config is written directly to the factory object.
5656

57-
std::string m_DependentVariable;
57+
std::string m_DependentVariableFieldName;
58+
std::string m_PredictionFieldName;
5859
TBoostedTreeFactoryUPtr m_BoostedTreeFactory;
5960
TBoostedTreeUPtr m_BoostedTree;
6061
};

lib/api/CDataFrameBoostedTreeRunner.cc

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace ml {
2323
namespace api {
2424
namespace {
2525
// Configuration
26-
const std::string DEPENDENT_VARIABLE{"dependent_variable"};
26+
const std::string DEPENDENT_VARIABLE_NAME{"dependent_variable"};
27+
const std::string PREDICTION_FIELD_NAME{"prediction_field_name"};
2728
const std::string LAMBDA{"lambda"};
2829
const std::string GAMMA{"gamma"};
2930
const std::string ETA{"eta"};
@@ -32,8 +33,10 @@ const std::string FEATURE_BAG_FRACTION{"feature_bag_fraction"};
3233

3334
const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
3435
CDataFrameAnalysisConfigReader theReader;
35-
theReader.addParameter(DEPENDENT_VARIABLE,
36+
theReader.addParameter(DEPENDENT_VARIABLE_NAME,
3637
CDataFrameAnalysisConfigReader::E_RequiredParameter);
38+
theReader.addParameter(PREDICTION_FIELD_NAME,
39+
CDataFrameAnalysisConfigReader::E_OptionalParameter);
3740
// TODO objective function, support train and predict.
3841
theReader.addParameter(LAMBDA, CDataFrameAnalysisConfigReader::E_OptionalParameter);
3942
theReader.addParameter(GAMMA, CDataFrameAnalysisConfigReader::E_OptionalParameter);
@@ -44,9 +47,6 @@ const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
4447
CDataFrameAnalysisConfigReader::E_OptionalParameter);
4548
return theReader;
4649
}()};
47-
48-
// Output
49-
const std::string PREDICTION{"prediction"};
5050
}
5151

5252
CDataFrameBoostedTreeRunner::CDataFrameBoostedTreeRunner(const CDataFrameAnalysisSpecification& spec,
@@ -55,7 +55,10 @@ CDataFrameBoostedTreeRunner::CDataFrameBoostedTreeRunner(const CDataFrameAnalysi
5555

5656
auto parameters = PARAMETER_READER.read(jsonParameters);
5757

58-
m_DependentVariable = parameters[DEPENDENT_VARIABLE].as<std::string>();
58+
m_DependentVariableFieldName = parameters[DEPENDENT_VARIABLE_NAME].as<std::string>();
59+
60+
m_PredictionFieldName = parameters[PREDICTION_FIELD_NAME].fallback(
61+
m_DependentVariableFieldName + "_prediction");
5962

6063
std::size_t maximumNumberTrees{
6164
parameters[MAXIMUM_NUMBER_TREES].fallback(std::size_t{0})};
@@ -117,21 +120,20 @@ void CDataFrameBoostedTreeRunner::writeOneRow(const TStrVec&,
117120
HANDLE_FATAL(<< "Internal error: boosted tree object missing. Please report this error.");
118121
} else {
119122
writer.StartObject();
120-
writer.Key(PREDICTION);
123+
writer.Key(m_PredictionFieldName);
121124
writer.Double(row[m_BoostedTree->columnHoldingPrediction(row.numberColumns())]);
122125
writer.EndObject();
123126
}
124127
}
125128

126129
void CDataFrameBoostedTreeRunner::runImpl(const TStrVec& featureNames,
127130
core::CDataFrame& frame) {
128-
auto dependentVariableColumn =
129-
std::find(featureNames.begin(), featureNames.end(), m_DependentVariable);
131+
auto dependentVariableColumn = std::find(
132+
featureNames.begin(), featureNames.end(), m_DependentVariableFieldName);
130133
if (dependentVariableColumn == featureNames.end()) {
131134
HANDLE_FATAL(<< "Input error: supplied variable to predict '"
132-
<< m_DependentVariable << "' is missing from training data "
133-
<< core::CContainerPrinter::print(featureNames)
134-
<< ". Please report this problem.");
135+
<< m_DependentVariableFieldName << "' is missing from training"
136+
<< " data " << core::CContainerPrinter::print(featureNames));
135137
} else {
136138
m_BoostedTree = m_BoostedTreeFactory->buildFor(
137139
frame, dependentVariableColumn - featureNames.begin());

lib/api/unittest/CDataFrameAnalyzerTest.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ regressionSpec(std::string dependentVariable,
117117
rows, 5, memoryLimit, 1, categoricalFieldNames, true,
118118
test::CTestTmpDir::tmpDir(), "ml", "regression", parameters)};
119119

120-
LOG_DEBUG(<< "spec =\n" << spec);
120+
LOG_TRACE(<< "spec =\n" << spec);
121121

122122
return std::make_unique<api::CDataFrameAnalysisSpecification>(spec);
123123
}
@@ -294,7 +294,7 @@ void CDataFrameAnalyzerTest::testWithoutControlMessages() {
294294
analyzer.run();
295295

296296
rapidjson::Document results;
297-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
297+
rapidjson::ParseResult ok(results.Parse(output.str()));
298298
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
299299

300300
auto expectedScore = expectedScores.begin();
@@ -340,7 +340,7 @@ void CDataFrameAnalyzerTest::testRunOutlierDetection() {
340340
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
341341

342342
rapidjson::Document results;
343-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
343+
rapidjson::ParseResult ok(results.Parse(output.str()));
344344
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
345345

346346
auto expectedScore = expectedScores.begin();
@@ -394,7 +394,7 @@ void CDataFrameAnalyzerTest::testRunOutlierDetectionPartitioned() {
394394
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
395395

396396
rapidjson::Document results;
397-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
397+
rapidjson::ParseResult ok(results.Parse(output.str()));
398398
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
399399

400400
auto expectedScore = expectedScores.begin();
@@ -441,7 +441,7 @@ void CDataFrameAnalyzerTest::testRunOutlierFeatureInfluences() {
441441
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
442442

443443
rapidjson::Document results;
444-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
444+
rapidjson::ParseResult ok(results.Parse(output.str()));
445445
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
446446

447447
auto expectedFeatureInfluence = expectedFeatureInfluences.begin();
@@ -492,7 +492,7 @@ void CDataFrameAnalyzerTest::testRunOutlierDetectionWithParams() {
492492
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
493493

494494
rapidjson::Document results;
495-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
495+
rapidjson::ParseResult ok(results.Parse(output.str()));
496496
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
497497

498498
auto expectedScore = expectedScores.begin();
@@ -530,7 +530,7 @@ void CDataFrameAnalyzerTest::testRunBoostedTreeTraining() {
530530
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
531531

532532
rapidjson::Document results;
533-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
533+
rapidjson::ParseResult ok(results.Parse(output.str()));
534534
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
535535

536536
auto expectedPrediction = expectedPredictions.begin();
@@ -540,7 +540,7 @@ void CDataFrameAnalyzerTest::testRunBoostedTreeTraining() {
540540
CPPUNIT_ASSERT(expectedPrediction != expectedPredictions.end());
541541
CPPUNIT_ASSERT_DOUBLES_EQUAL(
542542
*expectedPrediction,
543-
result["row_results"]["results"]["ml"]["prediction"].GetDouble(),
543+
result["row_results"]["results"]["ml"]["c5_prediction"].GetDouble(),
544544
1e-4 * std::fabs(*expectedPrediction));
545545
++expectedPrediction;
546546
CPPUNIT_ASSERT(result.HasMember("progress_percent") == false);
@@ -584,7 +584,7 @@ void CDataFrameAnalyzerTest::testRunBoostedTreeTrainingWithParams() {
584584
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
585585

586586
rapidjson::Document results;
587-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
587+
rapidjson::ParseResult ok(results.Parse(output.str()));
588588
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
589589

590590
auto expectedPrediction = expectedPredictions.begin();
@@ -594,7 +594,7 @@ void CDataFrameAnalyzerTest::testRunBoostedTreeTrainingWithParams() {
594594
CPPUNIT_ASSERT(expectedPrediction != expectedPredictions.end());
595595
CPPUNIT_ASSERT_DOUBLES_EQUAL(
596596
*expectedPrediction,
597-
result["row_results"]["results"]["ml"]["prediction"].GetDouble(),
597+
result["row_results"]["results"]["ml"]["c5_prediction"].GetDouble(),
598598
1e-4 * std::fabs(*expectedPrediction));
599599
++expectedPrediction;
600600
CPPUNIT_ASSERT(result.HasMember("progress_percent") == false);
@@ -760,7 +760,7 @@ void CDataFrameAnalyzerTest::testRoundTripDocHashes() {
760760
{"", "", "", "", "", "", "$"});
761761

762762
rapidjson::Document results;
763-
rapidjson::ParseResult ok(results.Parse(output.str().c_str()));
763+
rapidjson::ParseResult ok(results.Parse(output.str()));
764764
CPPUNIT_ASSERT(static_cast<bool>(ok) == true);
765765

766766
int expectedHash{0};

0 commit comments

Comments
 (0)