Skip to content

Commit 4d2be9b

Browse files
[ML] Add num_top_feature_importance_values param to regression and classification (#50914)
Adds a new parameter to the regression and classification analyses that enables computing feature importance for the most important features. The importance computation is based on the SHAP (SHapley Additive exPlanations) method.
1 parent 360f954 commit 4d2be9b

File tree

19 files changed

+266
-80
lines changed

19 files changed

+266
-80
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

+29-9
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public static Builder builder(String dependentVariable) {
4646
static final ParseField ETA = new ParseField("eta");
4747
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
4848
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
49+
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
4950
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5051
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
5152
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
@@ -62,10 +63,11 @@ public static Builder builder(String dependentVariable) {
6263
(Double) a[3],
6364
(Integer) a[4],
6465
(Double) a[5],
65-
(String) a[6],
66-
(Double) a[7],
67-
(Integer) a[8],
68-
(Long) a[9]));
66+
(Integer) a[6],
67+
(String) a[7],
68+
(Double) a[8],
69+
(Integer) a[9],
70+
(Long) a[10]));
6971

7072
static {
7173
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -74,6 +76,7 @@ public static Builder builder(String dependentVariable) {
7476
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
7577
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
7678
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
79+
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
7780
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
7881
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
7982
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
@@ -86,20 +89,23 @@ public static Builder builder(String dependentVariable) {
8689
private final Double eta;
8790
private final Integer maximumNumberTrees;
8891
private final Double featureBagFraction;
92+
private final Integer numTopFeatureImportanceValues;
8993
private final String predictionFieldName;
9094
private final Double trainingPercent;
9195
private final Integer numTopClasses;
9296
private final Long randomizeSeed;
9397

9498
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
95-
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
99+
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
100+
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
96101
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
97102
this.dependentVariable = Objects.requireNonNull(dependentVariable);
98103
this.lambda = lambda;
99104
this.gamma = gamma;
100105
this.eta = eta;
101106
this.maximumNumberTrees = maximumNumberTrees;
102107
this.featureBagFraction = featureBagFraction;
108+
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
103109
this.predictionFieldName = predictionFieldName;
104110
this.trainingPercent = trainingPercent;
105111
this.numTopClasses = numTopClasses;
@@ -135,6 +141,10 @@ public Double getFeatureBagFraction() {
135141
return featureBagFraction;
136142
}
137143

144+
public Integer getNumTopFeatureImportanceValues() {
145+
return numTopFeatureImportanceValues;
146+
}
147+
138148
public String getPredictionFieldName() {
139149
return predictionFieldName;
140150
}
@@ -170,6 +180,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
170180
if (featureBagFraction != null) {
171181
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
172182
}
183+
if (numTopFeatureImportanceValues != null) {
184+
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
185+
}
173186
if (predictionFieldName != null) {
174187
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
175188
}
@@ -188,8 +201,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
188201

189202
@Override
190203
public int hashCode() {
191-
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
192-
trainingPercent, randomizeSeed, numTopClasses);
204+
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues,
205+
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
193206
}
194207

195208
@Override
@@ -203,6 +216,7 @@ public boolean equals(Object o) {
203216
&& Objects.equals(eta, that.eta)
204217
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
205218
&& Objects.equals(featureBagFraction, that.featureBagFraction)
219+
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
206220
&& Objects.equals(predictionFieldName, that.predictionFieldName)
207221
&& Objects.equals(trainingPercent, that.trainingPercent)
208222
&& Objects.equals(randomizeSeed, that.randomizeSeed)
@@ -221,6 +235,7 @@ public static class Builder {
221235
private Double eta;
222236
private Integer maximumNumberTrees;
223237
private Double featureBagFraction;
238+
private Integer numTopFeatureImportanceValues;
224239
private String predictionFieldName;
225240
private Double trainingPercent;
226241
private Integer numTopClasses;
@@ -255,6 +270,11 @@ public Builder setFeatureBagFraction(Double featureBagFraction) {
255270
return this;
256271
}
257272

273+
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
274+
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
275+
return this;
276+
}
277+
258278
public Builder setPredictionFieldName(String predictionFieldName) {
259279
this.predictionFieldName = predictionFieldName;
260280
return this;
@@ -276,8 +296,8 @@ public Builder setNumTopClasses(Integer numTopClasses) {
276296
}
277297

278298
public Classification build() {
279-
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
280-
trainingPercent, numTopClasses, randomizeSeed);
299+
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
300+
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
281301
}
282302
}
283303
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java

+29-9
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public static Builder builder(String dependentVariable) {
4646
static final ParseField ETA = new ParseField("eta");
4747
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
4848
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
49+
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
4950
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5051
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
5152
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
@@ -61,9 +62,10 @@ public static Builder builder(String dependentVariable) {
6162
(Double) a[3],
6263
(Integer) a[4],
6364
(Double) a[5],
64-
(String) a[6],
65-
(Double) a[7],
66-
(Long) a[8]));
65+
(Integer) a[6],
66+
(String) a[7],
67+
(Double) a[8],
68+
(Long) a[9]));
6769

6870
static {
6971
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) {
7274
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
7375
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
7476
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
77+
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
7578
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
7679
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
7780
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
@@ -83,19 +86,22 @@ public static Builder builder(String dependentVariable) {
8386
private final Double eta;
8487
private final Integer maximumNumberTrees;
8588
private final Double featureBagFraction;
89+
private final Integer numTopFeatureImportanceValues;
8690
private final String predictionFieldName;
8791
private final Double trainingPercent;
8892
private final Long randomizeSeed;
8993

90-
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
91-
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
94+
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
95+
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
96+
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
9297
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
9398
this.dependentVariable = Objects.requireNonNull(dependentVariable);
9499
this.lambda = lambda;
95100
this.gamma = gamma;
96101
this.eta = eta;
97102
this.maximumNumberTrees = maximumNumberTrees;
98103
this.featureBagFraction = featureBagFraction;
104+
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
99105
this.predictionFieldName = predictionFieldName;
100106
this.trainingPercent = trainingPercent;
101107
this.randomizeSeed = randomizeSeed;
@@ -130,6 +136,10 @@ public Double getFeatureBagFraction() {
130136
return featureBagFraction;
131137
}
132138

139+
public Integer getNumTopFeatureImportanceValues() {
140+
return numTopFeatureImportanceValues;
141+
}
142+
133143
public String getPredictionFieldName() {
134144
return predictionFieldName;
135145
}
@@ -161,6 +171,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
161171
if (featureBagFraction != null) {
162172
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
163173
}
174+
if (numTopFeatureImportanceValues != null) {
175+
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
176+
}
164177
if (predictionFieldName != null) {
165178
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
166179
}
@@ -176,8 +189,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
176189

177190
@Override
178191
public int hashCode() {
179-
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
180-
trainingPercent, randomizeSeed);
192+
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues,
193+
predictionFieldName, trainingPercent, randomizeSeed);
181194
}
182195

183196
@Override
@@ -191,6 +204,7 @@ public boolean equals(Object o) {
191204
&& Objects.equals(eta, that.eta)
192205
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
193206
&& Objects.equals(featureBagFraction, that.featureBagFraction)
207+
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
194208
&& Objects.equals(predictionFieldName, that.predictionFieldName)
195209
&& Objects.equals(trainingPercent, that.trainingPercent)
196210
&& Objects.equals(randomizeSeed, that.randomizeSeed);
@@ -208,6 +222,7 @@ public static class Builder {
208222
private Double eta;
209223
private Integer maximumNumberTrees;
210224
private Double featureBagFraction;
225+
private Integer numTopFeatureImportanceValues;
211226
private String predictionFieldName;
212227
private Double trainingPercent;
213228
private Long randomizeSeed;
@@ -241,6 +256,11 @@ public Builder setFeatureBagFraction(Double featureBagFraction) {
241256
return this;
242257
}
243258

259+
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
260+
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
261+
return this;
262+
}
263+
244264
public Builder setPredictionFieldName(String predictionFieldName) {
245265
this.predictionFieldName = predictionFieldName;
246266
return this;
@@ -257,8 +277,8 @@ public Builder setRandomizeSeed(Long randomizeSeed) {
257277
}
258278

259279
public Regression build() {
260-
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
261-
trainingPercent, randomizeSeed);
280+
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
281+
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed);
262282
}
263283
}
264284
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+12
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,12 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
12941294
.setPredictionFieldName("my_dependent_variable_prediction")
12951295
.setTrainingPercent(80.0)
12961296
.setRandomizeSeed(42L)
1297+
.setLambda(1.0)
1298+
.setGamma(1.0)
1299+
.setEta(1.0)
1300+
.setMaximumNumberTrees(10)
1301+
.setFeatureBagFraction(0.5)
1302+
.setNumTopFeatureImportanceValues(3)
12971303
.build())
12981304
.setDescription("this is a regression")
12991305
.build();
@@ -1331,6 +1337,12 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Exception {
13311337
.setTrainingPercent(80.0)
13321338
.setRandomizeSeed(42L)
13331339
.setNumTopClasses(1)
1340+
.setLambda(1.0)
1341+
.setGamma(1.0)
1342+
.setEta(1.0)
1343+
.setMaximumNumberTrees(10)
1344+
.setFeatureBagFraction(0.5)
1345+
.setNumTopFeatureImportanceValues(3)
13341346
.build())
13351347
.setDescription("this is a classification")
13361348
.build();

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

+10-8
Original file line numberDiff line numberDiff line change
@@ -2975,10 +2975,11 @@ public void testPutDataFrameAnalytics() throws Exception {
29752975
.setEta(5.5) // <4>
29762976
.setMaximumNumberTrees(50) // <5>
29772977
.setFeatureBagFraction(0.4) // <6>
2978-
.setPredictionFieldName("my_prediction_field_name") // <7>
2979-
.setTrainingPercent(50.0) // <8>
2980-
.setRandomizeSeed(1234L) // <9>
2981-
.setNumTopClasses(1) // <10>
2978+
.setNumTopFeatureImportanceValues(3) // <7>
2979+
.setPredictionFieldName("my_prediction_field_name") // <8>
2980+
.setTrainingPercent(50.0) // <9>
2981+
.setRandomizeSeed(1234L) // <10>
2982+
.setNumTopClasses(1) // <11>
29822983
.build();
29832984
// end::put-data-frame-analytics-classification
29842985

@@ -2989,9 +2990,10 @@ public void testPutDataFrameAnalytics() throws Exception {
29892990
.setEta(5.5) // <4>
29902991
.setMaximumNumberTrees(50) // <5>
29912992
.setFeatureBagFraction(0.4) // <6>
2992-
.setPredictionFieldName("my_prediction_field_name") // <7>
2993-
.setTrainingPercent(50.0) // <8>
2994-
.setRandomizeSeed(1234L) // <9>
2993+
.setNumTopFeatureImportanceValues(3) // <7>
2994+
.setPredictionFieldName("my_prediction_field_name") // <8>
2995+
.setTrainingPercent(50.0) // <9>
2996+
.setRandomizeSeed(1234L) // <10>
29952997
.build();
29962998
// end::put-data-frame-analytics-regression
29972999

@@ -3670,7 +3672,7 @@ public void testPutTrainedModel() throws Exception {
36703672
}
36713673
{
36723674
PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
3673-
3675+
36743676
// tag::put-trained-model-execute-listener
36753677
ActionListener<PutTrainedModelResponse> listener = new ActionListener<>() {
36763678
@Override

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static Classification randomClassification() {
3232
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
3333
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
3434
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
35+
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
3536
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
3637
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
3738
.setRandomizeSeed(randomBoolean() ? null : randomLong())

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static Regression randomRegression() {
3232
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
3333
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
3434
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
35+
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
3536
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
3637
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
3738
.build();

0 commit comments

Comments
 (0)