@@ -169,74 +169,61 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
169
169
[this ](const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) {
170
170
this ->writePredictedCategoryValue (categoryValue, writer);
171
171
});
172
- featureImportance->shap (row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
173
- const TStrVec& featureNames,
174
- const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
175
- writer.Key (FEATURE_IMPORTANCE_FIELD_NAME);
176
- writer.StartArray ();
177
- TDoubleVec baseline;
178
- baseline.reserve (numberClasses);
179
- for (std::size_t j = 0 ; j < shap[0 ].size () && j < numberClasses; ++j) {
180
- baseline.push_back (featureImportance->baseline (j));
181
- }
182
- for (auto i : indices) {
183
- if (shap[i].norm () != 0.0 ) {
184
- writer.StartObject ();
185
- writer.Key (FEATURE_NAME_FIELD_NAME);
186
- writer.String (featureNames[i]);
187
- if (shap[i].size () == 1 ) {
188
- // output feature importance for individual classes in binary case
189
- writer.Key (CLASSES_FIELD_NAME);
190
- writer.StartArray ();
191
- for (std::size_t j = 0 ; j < numberClasses; ++j) {
192
- writer.StartObject ();
193
- writer.Key (CLASS_NAME_FIELD_NAME);
194
- writePredictedCategoryValue (classValues[j], writer);
195
- writer.Key (IMPORTANCE_FIELD_NAME);
196
- if (j == 1 ) {
197
- writer.Double (shap[i](0 ));
198
- } else {
199
- writer.Double (-shap[i](0 ));
172
+ featureImportance->shap (
173
+ row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
174
+ const TStrVec& featureNames,
175
+ const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
176
+ writer.Key (FEATURE_IMPORTANCE_FIELD_NAME);
177
+ writer.StartArray ();
178
+ for (auto i : indices) {
179
+ if (shap[i].norm () != 0.0 ) {
180
+ writer.StartObject ();
181
+ writer.Key (FEATURE_NAME_FIELD_NAME);
182
+ writer.String (featureNames[i]);
183
+ if (shap[i].size () == 1 ) {
184
+ // output feature importance for individual classes in binary case
185
+ writer.Key (CLASSES_FIELD_NAME);
186
+ writer.StartArray ();
187
+ for (std::size_t j = 0 ; j < numberClasses; ++j) {
188
+ writer.StartObject ();
189
+ writer.Key (CLASS_NAME_FIELD_NAME);
190
+ writePredictedCategoryValue (classValues[j], writer);
191
+ writer.Key (IMPORTANCE_FIELD_NAME);
192
+ if (j == 1 ) {
193
+ writer.Double (shap[i](0 ));
194
+ } else {
195
+ writer.Double (-shap[i](0 ));
196
+ }
197
+ writer.EndObject ();
200
198
}
201
- writer.EndObject ();
202
- }
203
- writer.EndArray ();
204
- } else {
205
- // output feature importance for individual classes in multiclass case
206
- writer.Key (CLASSES_FIELD_NAME);
207
- writer.StartArray ();
208
- TDoubleVec featureImportanceSum (numberClasses, 0.0 );
209
- for (std::size_t j = 0 ;
210
- j < shap[i].size () && j < numberClasses; ++j) {
211
- for (auto k : indices) {
212
- featureImportanceSum[j] += shap[k](j);
199
+ writer.EndArray ();
200
+ } else {
201
+ // output feature importance for individual classes in multiclass case
202
+ writer.Key (CLASSES_FIELD_NAME);
203
+ writer.StartArray ();
204
+ for (std::size_t j = 0 ;
205
+ j < shap[i].size () && j < numberClasses; ++j) {
206
+ writer.StartObject ();
207
+ writer.Key (CLASS_NAME_FIELD_NAME);
208
+ writePredictedCategoryValue (classValues[j], writer);
209
+ writer.Key (IMPORTANCE_FIELD_NAME);
210
+ writer.Double (shap[i](j));
211
+ writer.EndObject ();
213
212
}
213
+ writer.EndArray ();
214
214
}
215
- for (std::size_t j = 0 ;
216
- j < shap[i].size () && j < numberClasses; ++j) {
217
- writer.StartObject ();
218
- writer.Key (CLASS_NAME_FIELD_NAME);
219
- writePredictedCategoryValue (classValues[j], writer);
220
- writer.Key (IMPORTANCE_FIELD_NAME);
221
- double correctedShap{
222
- shap[i](j) * (baseline[j] / featureImportanceSum[j] + 1.0 )};
223
- writer.Double (correctedShap);
224
- writer.EndObject ();
225
- }
226
- writer.EndArray ();
215
+ writer.EndObject ();
227
216
}
228
- writer.EndObject ();
229
217
}
230
- }
231
- writer.EndArray ();
218
+ writer.EndArray ();
232
219
233
- for (std::size_t i = 0 ; i < shap.size (); ++i) {
234
- if (shap[i].lpNorm <1 >() != 0 ) {
235
- const_cast <CDataFrameTrainBoostedTreeClassifierRunner*>(this )
236
- ->m_InferenceModelMetadata .addToFeatureImportance (i, shap[i]);
220
+ for (std::size_t i = 0 ; i < shap.size (); ++i) {
221
+ if (shap[i].lpNorm <1 >() != 0 ) {
222
+ const_cast <CDataFrameTrainBoostedTreeClassifierRunner*>(this )
223
+ ->m_InferenceModelMetadata .addToFeatureImportance (i, shap[i]);
224
+ }
237
225
}
238
- }
239
- });
226
+ });
240
227
}
241
228
writer.EndObject ();
242
229
}
@@ -306,6 +293,10 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
306
293
307
294
// Returns this runner's inference model metadata.
//
// Before returning, the metadata's feature importance baseline is refreshed
// from the object obtained via boostedTree().shap(), provided that object
// tests true (i.e. is available). When it is not available the metadata is
// returned exactly as it currently stands.
//
// NOTE(review): this method is const yet writes to m_InferenceModelMetadata,
// which presumably means the member is declared mutable - confirm in the header.
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
    // Only touch the baseline when a feature importance object exists.
    if (const auto& shapCalculator = this->boostedTree().shap()) {
        m_InferenceModelMetadata.featureImportanceBaseline(shapCalculator->baseline());
    }
    return m_InferenceModelMetadata;
}
311
302
0 commit comments