Commit 42d2ee6

Merge branch 'master' into controller_responses
2 parents: d60d9d8 + 499338d

15 files changed: +1505 -1592 lines

include/api/CInferenceModelMetadata.h

+12 -3
@@ -21,6 +21,8 @@ namespace api {
 //! (such as total feature importance) into JSON format.
 class API_EXPORT CInferenceModelMetadata {
 public:
+    static const std::string JSON_BASELINE_TAG;
+    static const std::string JSON_FEATURE_IMPORTANCE_BASELINE_TAG;
     static const std::string JSON_CLASS_NAME_TAG;
     static const std::string JSON_CLASSES_TAG;
     static const std::string JSON_FEATURE_NAME_TAG;
@@ -48,19 +50,26 @@ class API_EXPORT CInferenceModelMetadata {
     //! Add importances \p values to the feature with index \p i to calculate total feature importance.
     //! Total feature importance is the mean of the magnitudes of importances for individual data points.
     void addToFeatureImportance(std::size_t i, const TVector& values);
+    //! Set the feature importance baseline (the individual feature importances are additive corrections
+    //! to the baseline value).
+    void featureImportanceBaseline(TVector&& baseline);
 
 private:
-    using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<TVector>::TAccumulator;
+    using TMeanAccumulator =
+        std::vector<maths::CBasicStatistics::SSampleMean<double>::TAccumulator>;
     using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
-    using TSizeMeanVarAccumulatorUMap = std::unordered_map<std::size_t, TMeanVarAccumulator>;
+    using TSizeMeanAccumulatorUMap = std::unordered_map<std::size_t, TMeanAccumulator>;
     using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;
+    using TOptionalVector = boost::optional<TVector>;
 
 private:
     void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;
+    void writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const;
 
 private:
-    TSizeMeanVarAccumulatorUMap m_TotalShapValuesMeanVar;
+    TSizeMeanAccumulatorUMap m_TotalShapValuesMean;
     TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
+    TOptionalVector m_ShapBaseline;
     TStrVec m_ColumnNames;
     TStrVec m_ClassValues;
     TPredictionFieldTypeResolverWriter m_PredictionFieldTypeResolverWriter =
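The accumulator changes above implement the doc comment: total feature importance for a feature is the mean of the magnitudes of its per-row SHAP values, tracked component-wise per class (hence a vector of scalar mean accumulators instead of a single mean-variance accumulator over vectors). A minimal standalone sketch of that accumulation, with a plain running mean standing in for the maths::CBasicStatistics accumulators (all names here are illustrative, not part of the API):

#include <cmath>
#include <cstddef>
#include <vector>

// A running mean of |shap value| for one class component of one feature.
struct RunningMean {
    double sum{0.0};
    std::size_t count{0};
    void add(double x) { sum += x; ++count; }
    double mean() const { return count == 0 ? 0.0 : sum / count; }
};

// Fold one row's SHAP vector for a single feature into the per-class totals.
void addToFeatureImportance(std::vector<RunningMean>& perClassMeans,
                            const std::vector<double>& shapRow) {
    perClassMeans.resize(shapRow.size());
    for (std::size_t j = 0; j < shapRow.size(); ++j) {
        perClassMeans[j].add(std::fabs(shapRow[j]));
    }
}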

include/maths/CTreeShapFeatureImportance.h

+1 -1
@@ -77,7 +77,7 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
     const TStrVec& columnNames() const;
 
     //! Get the baseline.
-    double baseline(std::size_t classIdx = 0) const;
+    TVector baseline() const;
 
 private:
     //! Collects the elements of the path through decision tree that are updated together
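With this change the caller receives every per-class baseline in one vector (a single component for regression) instead of querying class indices one at a time. A hypothetical call site, using the Eigen-style operator() indexing seen elsewhere in this diff (variable names are illustrative):

// Before: double b = featureImportance->baseline(classIdx);
auto baseline = featureImportance->baseline();
double classBaseline{baseline(classIdx)};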

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

+52 -61
@@ -169,74 +169,61 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
             [this](const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) {
                 this->writePredictedCategoryValue(categoryValue, writer);
             });
-        featureImportance->shap(row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
-                                         const TStrVec& featureNames,
-                                         const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
-            writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
-            writer.StartArray();
-            TDoubleVec baseline;
-            baseline.reserve(numberClasses);
-            for (std::size_t j = 0; j < shap[0].size() && j < numberClasses; ++j) {
-                baseline.push_back(featureImportance->baseline(j));
-            }
-            for (auto i : indices) {
-                if (shap[i].norm() != 0.0) {
-                    writer.StartObject();
-                    writer.Key(FEATURE_NAME_FIELD_NAME);
-                    writer.String(featureNames[i]);
-                    if (shap[i].size() == 1) {
-                        // output feature importance for individual classes in binary case
-                        writer.Key(CLASSES_FIELD_NAME);
-                        writer.StartArray();
-                        for (std::size_t j = 0; j < numberClasses; ++j) {
-                            writer.StartObject();
-                            writer.Key(CLASS_NAME_FIELD_NAME);
-                            writePredictedCategoryValue(classValues[j], writer);
-                            writer.Key(IMPORTANCE_FIELD_NAME);
-                            if (j == 1) {
-                                writer.Double(shap[i](0));
-                            } else {
-                                writer.Double(-shap[i](0));
+        featureImportance->shap(
+            row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
+                     const TStrVec& featureNames,
+                     const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
+                writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
+                writer.StartArray();
+                for (auto i : indices) {
+                    if (shap[i].norm() != 0.0) {
+                        writer.StartObject();
+                        writer.Key(FEATURE_NAME_FIELD_NAME);
+                        writer.String(featureNames[i]);
+                        if (shap[i].size() == 1) {
+                            // output feature importance for individual classes in binary case
+                            writer.Key(CLASSES_FIELD_NAME);
+                            writer.StartArray();
+                            for (std::size_t j = 0; j < numberClasses; ++j) {
+                                writer.StartObject();
+                                writer.Key(CLASS_NAME_FIELD_NAME);
+                                writePredictedCategoryValue(classValues[j], writer);
+                                writer.Key(IMPORTANCE_FIELD_NAME);
+                                if (j == 1) {
+                                    writer.Double(shap[i](0));
+                                } else {
+                                    writer.Double(-shap[i](0));
+                                }
+                                writer.EndObject();
                             }
-                            writer.EndObject();
-                        }
-                        writer.EndArray();
-                    } else {
-                        // output feature importance for individual classes in multiclass case
-                        writer.Key(CLASSES_FIELD_NAME);
-                        writer.StartArray();
-                        TDoubleVec featureImportanceSum(numberClasses, 0.0);
-                        for (std::size_t j = 0;
-                             j < shap[i].size() && j < numberClasses; ++j) {
-                            for (auto k : indices) {
-                                featureImportanceSum[j] += shap[k](j);
+                            writer.EndArray();
+                        } else {
+                            // output feature importance for individual classes in multiclass case
+                            writer.Key(CLASSES_FIELD_NAME);
+                            writer.StartArray();
+                            for (std::size_t j = 0;
+                                 j < shap[i].size() && j < numberClasses; ++j) {
+                                writer.StartObject();
+                                writer.Key(CLASS_NAME_FIELD_NAME);
+                                writePredictedCategoryValue(classValues[j], writer);
+                                writer.Key(IMPORTANCE_FIELD_NAME);
+                                writer.Double(shap[i](j));
+                                writer.EndObject();
                             }
+                            writer.EndArray();
                         }
-                        for (std::size_t j = 0;
-                             j < shap[i].size() && j < numberClasses; ++j) {
-                            writer.StartObject();
-                            writer.Key(CLASS_NAME_FIELD_NAME);
-                            writePredictedCategoryValue(classValues[j], writer);
-                            writer.Key(IMPORTANCE_FIELD_NAME);
-                            double correctedShap{
-                                shap[i](j) * (baseline[j] / featureImportanceSum[j] + 1.0)};
-                            writer.Double(correctedShap);
-                            writer.EndObject();
-                        }
-                        writer.EndArray();
+                        writer.EndObject();
                     }
-                    writer.EndObject();
                 }
-            }
-            writer.EndArray();
+                writer.EndArray();
 
-            for (std::size_t i = 0; i < shap.size(); ++i) {
-                if (shap[i].lpNorm<1>() != 0) {
-                    const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
-                        ->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
+                for (std::size_t i = 0; i < shap.size(); ++i) {
+                    if (shap[i].lpNorm<1>() != 0) {
+                        const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
+                            ->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
+                    }
                 }
-            }
-        });
+            });
     }
     writer.EndObject();
 }
@@ -306,6 +293,10 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
 
 CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
 CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
+    const auto& featureImportance = this->boostedTree().shap();
+    if (featureImportance) {
+        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
+    }
     return m_InferenceModelMetadata;
 }
 
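The net effect of these two hunks: the per-row writer no longer rescales SHAP values by baseline[j] / featureImportanceSum[j]; it emits them unmodified, and the baseline is exported once through the model metadata returned by inferenceModelMetadata(). Since the header documents the importances as additive corrections to the baseline, a consumer could recombine the two roughly as below (a sketch only; all names are illustrative, not part of the output format):

#include <cstddef>
#include <vector>

// score_j ~= baseline_j + sum over features i of shap_ij for one document.
std::vector<double> reconstructScores(const std::vector<double>& baseline,
                                      const std::vector<std::vector<double>>& shapPerFeature) {
    std::vector<double> scores = baseline;
    for (const auto& shap : shapPerFeature) {
        for (std::size_t j = 0; j < scores.size() && j < shap.size(); ++j) {
            scores[j] += shap[j];
        }
    }
    return scores;
}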

lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc

+5 -1
@@ -155,7 +155,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(
 
 CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
 CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
-    return TOptionalInferenceModelMetadata(m_InferenceModelMetadata);
+    const auto& featureImportance = this->boostedTree().shap();
+    if (featureImportance) {
+        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
+    }
+    return m_InferenceModelMetadata;
 }
 
 // clang-format off

lib/api/CInferenceModelMetadata.cc

+61 -4
@@ -5,17 +5,20 @@
  */
 #include <api/CInferenceModelMetadata.h>
 
+#include <cmath>
+
 namespace ml {
 namespace api {
 
 void CInferenceModelMetadata::write(TRapidJsonWriter& writer) const {
     this->writeTotalFeatureImportance(writer);
+    this->writeFeatureImportanceBaseline(writer);
 }
 
 void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writer) const {
     writer.Key(JSON_TOTAL_FEATURE_IMPORTANCE_TAG);
     writer.StartArray();
-    for (const auto& item : m_TotalShapValuesMeanVar) {
+    for (const auto& item : m_TotalShapValuesMean) {
         writer.StartObject();
         writer.Key(JSON_FEATURE_NAME_TAG);
         writer.String(m_ColumnNames[item.first]);
@@ -86,6 +89,53 @@ void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writ
     writer.EndArray();
 }
 
+void CInferenceModelMetadata::writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const {
+    if (m_ShapBaseline) {
+        writer.Key(JSON_FEATURE_IMPORTANCE_BASELINE_TAG);
+        writer.StartObject();
+        if (m_ShapBaseline->size() == 1 && m_ClassValues.empty()) {
+            // Regression
+            writer.Key(JSON_BASELINE_TAG);
+            writer.Double(m_ShapBaseline.get()(0));
+        } else if (m_ShapBaseline->size() == 1 && m_ClassValues.empty() == false) {
+            // Binary classification
+            writer.Key(JSON_CLASSES_TAG);
+            writer.StartArray();
+            for (std::size_t j = 0; j < m_ClassValues.size(); ++j) {
+                writer.StartObject();
+                writer.Key(JSON_CLASS_NAME_TAG);
+                m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
+                writer.Key(JSON_BASELINE_TAG);
+                if (j == 1) {
+                    writer.Double(m_ShapBaseline.get()(0));
+                } else {
+                    writer.Double(-m_ShapBaseline.get()(0));
+                }
+                writer.EndObject();
+            }
+
+            writer.EndArray();
+
+        } else {
+            // Multiclass classification
+            writer.Key(JSON_CLASSES_TAG);
+            writer.StartArray();
+            for (std::size_t j = 0; j < static_cast<std::size_t>(m_ShapBaseline->size()) &&
+                                    j < m_ClassValues.size();
+                 ++j) {
+                writer.StartObject();
+                writer.Key(JSON_CLASS_NAME_TAG);
+                m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
+                writer.Key(JSON_BASELINE_TAG);
+                writer.Double(m_ShapBaseline.get()(j));
+                writer.EndObject();
+            }
+            writer.EndArray();
+        }
+        writer.EndObject();
+    }
+}
+
 const std::string& CInferenceModelMetadata::typeString() const {
     return JSON_MODEL_METADATA_TAG;
 }
@@ -104,19 +154,26 @@ void CInferenceModelMetadata::predictionFieldTypeResolverWriter(
 }
 
 void CInferenceModelMetadata::addToFeatureImportance(std::size_t i, const TVector& values) {
-    m_TotalShapValuesMeanVar
-        .emplace(std::make_pair(i, TVector::Zero(values.size())))
-        .first->second.add(values.cwiseAbs());
+    auto& meanVector = m_TotalShapValuesMean
+                           .emplace(std::make_pair(i, TMeanAccumulator(values.size())))
+                           .first->second;
     auto& minMaxVector =
         m_TotalShapValuesMinMax
             .emplace(std::make_pair(i, TMinMaxAccumulator(values.size())))
            .first->second;
     for (std::size_t j = 0; j < minMaxVector.size(); ++j) {
+        meanVector[j].add(std::fabs(values[j]));
         minMaxVector[j].add(values[j]);
     }
 }
 
+void CInferenceModelMetadata::featureImportanceBaseline(TVector&& baseline) {
+    m_ShapBaseline = baseline;
+}
+
 // clang-format off
+const std::string CInferenceModelMetadata::JSON_BASELINE_TAG{"baseline"};
+const std::string CInferenceModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG{"feature_importance_baseline"};
 const std::string CInferenceModelMetadata::JSON_CLASS_NAME_TAG{"class_name"};
 const std::string CInferenceModelMetadata::JSON_CLASSES_TAG{"classes"};
 const std::string CInferenceModelMetadata::JSON_FEATURE_NAME_TAG{"feature_name"};
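Putting the new tags together with writeFeatureImportanceBaseline above, the metadata fragment for a three-class model would come out roughly as follows (class names and numbers are made up for illustration; the enclosing model metadata object and the total_feature_importance array are omitted):

"feature_importance_baseline": {
    "classes": [
        {"class_name": "a", "baseline": -0.47},
        {"class_name": "b", "baseline": 0.11},
        {"class_name": "c", "baseline": 0.36}
    ]
}

Regression collapses to a single {"baseline": ...} field, and the binary case writes the two classes with baselines of equal magnitude and opposite sign, mirroring the per-row sign convention in the classifier runner.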
