Skip to content

Commit 6565f4f

Browse files
authored
[7.6][ML] Compute SHAP values for supervised learning (elastic#857) (elastic#888)
This PR introduces the computation of SHAP (SHapley Additive exPlanation) values for feature importance. Refer to "Consistent Individualized Feature Attribution for Tree Ensembles" by Lundberg et al. for details of the original algorithm.
1 parent 2b81bce commit 6565f4f

24 files changed

+1360
-14
lines changed

include/api/CDataFrameTrainBoostedTreeRunner.h

+13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#ifndef INCLUDED_ml_api_CDataFrameTrainBoostedTreeRunner_h
88
#define INCLUDED_ml_api_CDataFrameTrainBoostedTreeRunner_h
99

10+
#include <maths/CBasicStatistics.h>
11+
1012
#include <api/CDataFrameAnalysisRunner.h>
1113
#include <api/CDataFrameAnalysisSpecification.h>
1214
#include <api/ImportExport.h>
@@ -44,6 +46,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRun
4446
static const std::string NUMBER_FOLDS;
4547
static const std::string NUMBER_ROUNDS_PER_HYPERPARAMETER;
4648
static const std::string BAYESIAN_OPTIMISATION_RESTARTS;
49+
static const std::string TOP_SHAP_VALUES;
4750

4851
public:
4952
~CDataFrameTrainBoostedTreeRunner() override;
@@ -57,6 +60,8 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRun
5760
//! The boosted tree factory.
5861
const maths::CBoostedTreeFactory& boostedTreeFactory() const;
5962

63+
std::size_t topShapValues() const;
64+
6065
protected:
6166
using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
6267
using TLossFunctionUPtr = std::unique_ptr<maths::boosted_tree::CLoss>;
@@ -76,6 +81,14 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRun
7681
//! The boosted tree factory.
7782
maths::CBoostedTreeFactory& boostedTreeFactory();
7883

84+
//! Factory for the largest SHAP value accumulator.
85+
template<typename LESS>
86+
maths::CBasicStatistics::COrderStatisticsHeap<std::size_t, LESS>
87+
makeLargestShapAccumulator(std::size_t n, LESS less) const {
88+
return maths::CBasicStatistics::COrderStatisticsHeap<std::size_t, LESS>{
89+
n, std::size_t{}, less};
90+
};
91+
7992
private:
8093
using TBoostedTreeFactoryUPtr = std::unique_ptr<maths::CBoostedTreeFactory>;
8194
using TDataSearcherUPtr = CDataFrameAnalysisSpecification::TDataSearcherUPtr;

include/maths/CBoostedTree.h

+14
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ class MATHS_EXPORT CBoostedTreeNode final {
300300
double curvature,
301301
TNodeVec& tree);
302302

303+
//! Get the feature index of the split.
304+
std::size_t splitFeature() const { return m_SplitFeature; };
305+
303306
//! Persist by passing information to \p inserter.
304307
void acceptPersistInserter(core::CStatePersistInserter& inserter) const;
305308

@@ -382,6 +385,11 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
382385
//! \warning This can only be called after train.
383386
void predict() const override;
384387

388+
//! Write SHAP values to the data frame supplied to the constructor.
389+
//!
390+
//! \warning This can only be called after train.
391+
void computeShapValues() override;
392+
385393
//! Get the feature weights the model has chosen.
386394
const TDoubleVec& featureWeights() const override;
387395

@@ -391,6 +399,12 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
391399
//! Get the column containing the model's prediction for the dependent variable.
392400
std::size_t columnHoldingPrediction(std::size_t numberColumns) const override;
393401

402+
//! Get the optional vector of column indices with SHAP values
403+
TSizeVec columnsHoldingShapValues() const override;
404+
405+
//! Get the number of largest SHAP values that will be returned for every row.
406+
std::size_t topShapValues() const override;
407+
394408
//! Get the model produced by training if it has been run.
395409
const TNodeVecVec& trainedModel() const;
396410

include/maths/CBoostedTreeFactory.h

+5
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ class MATHS_EXPORT CBoostedTreeFactory final {
8787
CBoostedTreeFactory& bayesianOptimisationRestarts(std::size_t restarts);
8888
//! Set the number of training examples we need per feature we'll include.
8989
CBoostedTreeFactory& rowsPerFeature(std::size_t rowsPerFeature);
90+
91+
//! Set the number of largest SHAP values to compute for each row.
92+
CBoostedTreeFactory& topShapValues(std::size_t topShapValues);
93+
9094
//! Set whether to try and balance within class accuracy. For classification
9195
//! this reweights examples so approximately the same total loss is assigned
9296
//! to every class.
@@ -205,6 +209,7 @@ class MATHS_EXPORT CBoostedTreeFactory final {
205209
TProgressCallback m_RecordProgress = noopRecordProgress;
206210
TMemoryUsageCallback m_RecordMemoryUsage = noopRecordMemoryUsage;
207211
TTrainingStateCallback m_RecordTrainingState = noopRecordTrainingState;
212+
std::size_t m_TopShapValues = 0;
208213
};
209214
}
210215
}

include/maths/CBoostedTreeImpl.h

+21-4
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class MATHS_EXPORT CBoostedTreeImpl final {
6060
using TTrainingStateCallback = CBoostedTree::TTrainingStateCallback;
6161
using TOptionalDouble = boost::optional<double>;
6262
using TRegularization = CBoostedTreeRegularization<double>;
63+
using TSizeVec = std::vector<std::size_t>;
6364

6465
public:
6566
static const double MINIMUM_RELATIVE_GAIN_PER_SPLIT;
@@ -83,6 +84,11 @@ class MATHS_EXPORT CBoostedTreeImpl final {
8384
//! \note Must be called only if a trained model is available.
8485
void predict(core::CDataFrame& frame, const TProgressCallback& /*recordProgress*/) const;
8586

87+
//! Compute SHAP values using the best trained model to \p frame.
88+
//!
89+
//! \note Must be called only if a trained model is available.
90+
void computeShapValues(core::CDataFrame& frame, const TProgressCallback&);
91+
8692
//! Get the feature sample probabilities.
8793
const TDoubleVec& featureWeights() const;
8894

@@ -132,12 +138,20 @@ class MATHS_EXPORT CBoostedTreeImpl final {
132138
//! \return The best hyperparameters for validation error found so far.
133139
const CBoostedTreeHyperparameters& bestHyperparameters() const;
134140

141+
//! Get the indices of the columns containing SHAP values.
142+
TSizeVec columnsHoldingShapValues() const;
143+
144+
//! Get the number of largest SHAP values that will be returned for every row.
145+
std::size_t topShapValues() const;
146+
147+
//! Get the number of input columns.
148+
std::size_t numberInputColumns() const;
149+
135150
private:
136151
using TSizeDoublePr = std::pair<std::size_t, double>;
137152
using TDoubleDoublePr = std::pair<double, double>;
138153
using TOptionalSize = boost::optional<std::size_t>;
139154
using TImmutableRadixSetVec = std::vector<core::CImmutableRadixSet<double>>;
140-
using TSizeVec = std::vector<std::size_t>;
141155
using TVector = CDenseVector<double>;
142156
using TRowItr = core::CDataFrame::TRowItr;
143157
using TPackedBitVectorVec = std::vector<core::CPackedBitVector>;
@@ -383,6 +397,7 @@ class MATHS_EXPORT CBoostedTreeImpl final {
383397
// The maximum number of rows encoded by a single byte in the packed bit
384398
// vector assuming best compression.
385399
static const std::size_t PACKED_BIT_VECTOR_MAXIMUM_ROWS_PER_BYTE;
400+
static const double INF;
386401

387402
private:
388403
CBoostedTreeImpl();
@@ -492,9 +507,6 @@ class MATHS_EXPORT CBoostedTreeImpl final {
492507
//! Record the training state using the \p recordTrainState callback function
493508
void recordState(const TTrainingStateCallback& recordTrainState) const;
494509

495-
private:
496-
static const double INF;
497-
498510
private:
499511
mutable CPRNG::CXorOShiro128Plus m_Rng;
500512
std::size_t m_NumberThreads;
@@ -529,7 +541,12 @@ class MATHS_EXPORT CBoostedTreeImpl final {
529541
std::size_t m_NumberRounds = 1;
530542
std::size_t m_CurrentRound = 0;
531543
core::CLoopProgress m_TrainingProgress;
544+
std::size_t m_TopShapValues = 0;
545+
std::size_t m_FirstShapColumnIndex = 0;
546+
std::size_t m_LastShapColumnIndex = 0;
547+
std::size_t m_NumberInputColumns = 0;
532548

549+
private:
533550
friend class CBoostedTreeFactory;
534551
};
535552

include/maths/CDataFrameCategoryEncoder.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ class MATHS_EXPORT CDataFrameCategoryEncoder final {
211211
};
212212

213213
public:
214-
CDataFrameCategoryEncoder(CMakeDataFrameCategoryEncoder parameters);
214+
CDataFrameCategoryEncoder(CMakeDataFrameCategoryEncoder& builder);
215+
CDataFrameCategoryEncoder(CMakeDataFrameCategoryEncoder&& builder);
215216

216217
//! Initialize from serialized data.
217218
CDataFrameCategoryEncoder(core::CStateRestoreTraverser& traverser);
@@ -288,6 +289,8 @@ class MATHS_EXPORT CMakeDataFrameCategoryEncoder {
288289
const core::CDataFrame& frame,
289290
std::size_t targetColumn);
290291

292+
virtual ~CMakeDataFrameCategoryEncoder() = default;
293+
291294
//! Set the minimum number of training rows needed per feature used.
292295
CMakeDataFrameCategoryEncoder& minimumRowsPerFeature(std::size_t minimumRowsPerFeature);
293296

@@ -313,7 +316,7 @@ class MATHS_EXPORT CMakeDataFrameCategoryEncoder {
313316
CMakeDataFrameCategoryEncoder& columnMask(TSizeVec columnMask);
314317

315318
//! Make the encoding.
316-
TEncodingUPtrVec makeEncodings();
319+
virtual TEncodingUPtrVec makeEncodings();
317320

318321
//! \name Test Methods
319322
//@{

include/maths/CDataFrameRegressionModel.h

+17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
#include <maths/ImportExport.h>
1313

14+
#include <boost/optional.hpp>
15+
1416
#include <functional>
1517
#include <utility>
1618
#include <vector>
@@ -27,6 +29,7 @@ namespace maths {
2729
class MATHS_EXPORT CDataFrameRegressionModel {
2830
public:
2931
using TDoubleVec = std::vector<double>;
32+
using TSizeVec = std::vector<std::size_t>;
3033
using TProgressCallback = std::function<void(double)>;
3134
using TMemoryUsageCallback = std::function<void(std::uint64_t)>;
3235
using TPersistFunc = std::function<void(core::CStatePersistInserter&)>;
@@ -44,6 +47,11 @@ class MATHS_EXPORT CDataFrameRegressionModel {
4447
//! \warning This can only be called after train.
4548
virtual void predict() const = 0;
4649

50+
//! Write SHAP values to the data frame supplied to the constructor.
51+
//!
52+
//! \warning This can only be called after train.
53+
virtual void computeShapValues() = 0;
54+
4755
//! Get the feature weights the model has chosen.
4856
virtual const TDoubleVec& featureWeights() const = 0;
4957

@@ -53,6 +61,15 @@ class MATHS_EXPORT CDataFrameRegressionModel {
5361
//! Get the column containing the model's prediction for the dependent variable.
5462
virtual std::size_t columnHoldingPrediction(std::size_t numberColumns) const = 0;
5563

64+
//! Get the number of largest SHAP values that will be returned for every row.
65+
virtual std::size_t topShapValues() const = 0;
66+
67+
//! Get the optional vector of column indices with SHAP values
68+
virtual TSizeVec columnsHoldingShapValues() const = 0;
69+
70+
public:
71+
static const std::string SHAP_PREFIX;
72+
5673
protected:
5774
CDataFrameRegressionModel(core::CDataFrame& frame,
5875
TProgressCallback recordProgress,
+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
#ifndef INCLUDED_ml_maths_CTreeShapFeatureImportance_h
8+
#define INCLUDED_ml_maths_CTreeShapFeatureImportance_h
9+
10+
#include <maths/CBoostedTree.h>
11+
#include <maths/ImportExport.h>
12+
13+
#include <vector>
14+
15+
namespace ml {
16+
namespace maths {
17+
18+
//! \brief Computes SHAP (SHapley Additive exPlanation) values for feature importance estimation for gradient boosting
19+
//! trees.
20+
//!
21+
//! DESCRIPTION:\n
22+
//! SHAP values is a unique consistent and locally accurate attribution value. This mean that the sum of the SHAP
23+
//! feature importance values approximates the model prediction up to a constant bias. This implementation follows the
24+
//! algorithm "Consistent Individualized Feature Attribution for Tree Ensembles" by Lundberg, Erion, and Lee.
25+
//! The algorithm has the complexity O(TLD^2) where T is the number of trees, L is the maximum number of leaves in the
26+
//! tree, and D is the maximum depth of a tree in the ensemble.
27+
class MATHS_EXPORT CTreeShapFeatureImportance {
28+
public:
29+
using TTree = std::vector<CBoostedTreeNode>;
30+
using TTreeVec = std::vector<TTree>;
31+
using TIntVec = std::vector<int>;
32+
using TDoubleVec = std::vector<double>;
33+
using TDoubleVecVec = std::vector<TDoubleVec>;
34+
35+
public:
36+
explicit CTreeShapFeatureImportance(TTreeVec trees, std::size_t threads = 1);
37+
38+
//! Compute SHAP values for the data in \p frame using the specified \p encoder.
39+
//! The results are written directly back into the \p frame, the index of the first result column is controller
40+
//! by \p offset.
41+
void shap(core::CDataFrame& frame, const CDataFrameCategoryEncoder& encoder, std::size_t offset);
42+
43+
//! Compute number of training samples from \p frame that pass every node in the \p tree.
44+
static TDoubleVec samplesPerNode(const TTree& tree,
45+
const core::CDataFrame& frame,
46+
const CDataFrameCategoryEncoder& encoder,
47+
std::size_t numThreads);
48+
49+
//! Recursively computes inner node values as weighted average of the children (leaf) values
50+
//! \returns The maximum depth the the tree.
51+
static std::size_t updateNodeValues(TTree& tree,
52+
std::size_t nodeIndex,
53+
const TDoubleVec& samplesPerNode,
54+
std::size_t depth);
55+
56+
//! Get the reference to the trees.
57+
TTreeVec& trees() { return m_Trees; };
58+
59+
private:
60+
using TSizeVec = std::vector<std::size_t>;
61+
62+
//! Manages variables for the current path through the tree as the main algorithm proceeds.
63+
struct SPath {
64+
explicit SPath(std::size_t length)
65+
: s_FractionOnes(length), s_FractionZeros(length),
66+
s_FeatureIndex(length, -1), s_Scale(length), s_NextIndex(0),
67+
s_MaxLength(length) {}
68+
69+
void extend(int featureIndex, double fractionZero, double fractionOne) {
70+
if (s_NextIndex < s_MaxLength) {
71+
s_FeatureIndex[s_NextIndex] = featureIndex;
72+
s_FractionZeros[s_NextIndex] = fractionZero;
73+
s_FractionOnes[s_NextIndex] = fractionOne;
74+
if (s_NextIndex == 0) {
75+
s_Scale[s_NextIndex] = 1.0;
76+
} else {
77+
s_Scale[s_NextIndex] = 0.0;
78+
}
79+
++s_NextIndex;
80+
}
81+
}
82+
83+
void reduce(std::size_t pathIndex) {
84+
for (std::size_t i = pathIndex; i < this->depth(); ++i) {
85+
s_FeatureIndex[i] = s_FeatureIndex[i + 1];
86+
s_FractionZeros[i] = s_FractionZeros[i + 1];
87+
s_FractionOnes[i] = s_FractionOnes[i + 1];
88+
}
89+
--s_NextIndex;
90+
}
91+
92+
//! Indicator whether or not the feature \p pathIndex is decicive for the path.
93+
double fractionOnes(std::size_t pathIndex) const {
94+
return s_FractionOnes[pathIndex];
95+
}
96+
97+
//! Fraction of all training data that reached the \pathIndex in the path.
98+
double fractionZeros(std::size_t pathIndex) const {
99+
return s_FractionZeros[pathIndex];
100+
}
101+
102+
int featureIndex(std::size_t pathIndex) const {
103+
return s_FeatureIndex[pathIndex];
104+
}
105+
106+
//! Scaling coefficients (factorials), see. Equation (2) in the paper by Lundberg et al.
107+
double scale(std::size_t pathIndex) const { return s_Scale[pathIndex]; }
108+
109+
//! Current depth in the tree
110+
std::size_t depth() const { return s_NextIndex - 1; };
111+
112+
TDoubleVec s_FractionOnes;
113+
TDoubleVec s_FractionZeros;
114+
TIntVec s_FeatureIndex;
115+
TDoubleVec s_Scale;
116+
std::size_t s_NextIndex;
117+
std::size_t s_MaxLength;
118+
};
119+
120+
private:
121+
//! Recursively traverses all pathes in the \p tree and updated SHAP values once it hits a leaf.
122+
//! Ref. Algorithm 2 in the paper by Lundberg et al.
123+
void shapRecursive(const TTree& tree,
124+
const TDoubleVec& samplesPerNode,
125+
const CDataFrameCategoryEncoder& encoder,
126+
const CEncodedDataFrameRowRef& encodedRow,
127+
SPath splitPath,
128+
std::size_t nodeIndex,
129+
double parentFractionZero,
130+
double parentFractionOne,
131+
int parentFeatureIndex,
132+
std::size_t offset,
133+
core::CDataFrame::TRowItr& row) const;
134+
//! Extend the \p path object, update the variables and factorial scaling coefficients.
135+
static void extendPath(SPath& path, double fractionZero, double fractionOne, int featureIndex);
136+
//! Sum the scaling coefficients for the \p path without the feature defined in \p pathIndex.
137+
static double sumUnwoundPath(const SPath& path, std::size_t pathIndex);
138+
//! Updated the scaling coefficients in the \p path if the feature defined in \p pathIndex was seen again.
139+
static void unwindPath(SPath& path, std::size_t pathIndex);
140+
141+
private:
142+
TTreeVec m_Trees;
143+
std::size_t m_NumberThreads;
144+
TDoubleVecVec m_SamplesPerNode;
145+
};
146+
}
147+
}
148+
149+
#endif // INCLUDED_ml_maths_CTreeShapFeatureImportance_h

include/test/CDataFrameAnalysisSpecificationFactory.h

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class TEST_EXPORT CDataFrameAnalysisSpecificationFactory {
5757
double eta = -1.0,
5858
std::size_t maximumNumberTrees = 0,
5959
double featureBagFraction = -1.0,
60+
size_t topShapValues = 0,
6061
TPersisterSupplier* persisterSupplier = nullptr,
6162
TRestoreSearcherSupplier* restoreSearcherSupplier = nullptr);
6263
};

0 commit comments

Comments
 (0)