Skip to content

Commit 2a99f4e

Browse files
authored
[ML] adjusting feature importance mapping for multi-class support (#53821)
Feature importance storage format is changing to encompass multi-class. Feature importance objects are now mapped as follows (logistic) Regression: ``` { "feature_name": "feature_0", "importance": -1.3 } ``` Multi-class [class names are `foo`, `bar`, `baz`] ``` { “feature_name”: “feature_0”, “importance”: 2.0, // sum(abs()) of class importances “foo”: 1.0, “bar”: 0.5, “baz”: -0.5 }, ``` This change adjusts the mapping creation for analytics so that the field is mapped as a `nested` type. Native side change: elastic/ml-cpp#1071
1 parent 569dffc commit 2a99f4e

File tree

7 files changed

+88
-33
lines changed

7 files changed

+88
-33
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,11 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
286286
@SuppressWarnings("unchecked")
287287
@Override
288288
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
289+
Map<String, Object> additionalProperties = new HashMap<>();
290+
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
289291
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
290292
if ((dependentVariableMapping instanceof Map) == false) {
291-
return Collections.emptyMap();
293+
return additionalProperties;
292294
}
293295
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
294296
// If the source field is an alias, fetch the concrete field that the alias points to.
@@ -299,9 +301,8 @@ public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mapping
299301
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
300302
// Hence, we need to check the "instanceof" condition again.
301303
if ((dependentVariableMapping instanceof Map) == false) {
302-
return Collections.emptyMap();
304+
return additionalProperties;
303305
}
304-
Map<String, Object> additionalProperties = new HashMap<>();
305306
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
306307
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
307308
return additionalProperties;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*//*
6+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
7+
* or more contributor license agreements. Licensed under the Elastic License;
8+
* you may not use this file except in compliance with the Elastic License.
9+
*/
10+
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
11+
12+
import org.elasticsearch.index.mapper.KeywordFieldMapper;
13+
import org.elasticsearch.index.mapper.NumberFieldMapper;
14+
15+
import java.util.Collections;
16+
import java.util.HashMap;
17+
import java.util.Map;
18+
19+
final class MapUtils {
20+
21+
private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
22+
static {
23+
Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
24+
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
25+
featureImportanceMappingProperties.put("importance",
26+
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
27+
Map<String, Object> featureImportanceMapping = new HashMap<>();
28+
// TODO sorted indices don't support nested types
29+
//featureImportanceMapping.put("dynamic", true);
30+
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
31+
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
32+
FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
33+
}
34+
35+
static Map<String, Object> featureImportanceMapping() {
36+
return FEATURE_IMPORTANCE_MAPPING;
37+
}
38+
39+
private MapUtils() {}
40+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
17+
import org.elasticsearch.index.mapper.NumberFieldMapper;
1718
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1819

1920
import java.io.IOException;
@@ -187,9 +188,13 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
187188

188189
@Override
189190
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
191+
Map<String, Object> additionalProperties = new HashMap<>();
192+
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
190193
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
191194
// high (over 10M) values of dependent variable.
192-
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
195+
additionalProperties.put(resultsFieldName + "." + predictionFieldName,
196+
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
197+
return additionalProperties;
193198
}
194199

195200
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

+24-19
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import java.util.Set;
2727

2828
import static org.hamcrest.Matchers.allOf;
29-
import static org.hamcrest.Matchers.anEmptyMap;
3029
import static org.hamcrest.Matchers.containsString;
3130
import static org.hamcrest.Matchers.empty;
3231
import static org.hamcrest.Matchers.equalTo;
@@ -244,39 +243,45 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
244243
}
245244

246245
public void testGetExplicitlyMappedFields() {
247-
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
248-
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
246+
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
247+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
248+
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
249+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
249250
assertThat(
250251
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
251-
is(anEmptyMap()));
252-
assertThat(
253-
new Classification("foo").getExplicitlyMappedFields(
254-
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
255-
"results"),
252+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
253+
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
254+
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
255+
"results");
256+
assertThat(explicitlyMappedFields,
256257
allOf(
257258
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
258259
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
259-
assertThat(
260-
new Classification("foo").getExplicitlyMappedFields(
261-
new HashMap<>() {{
262-
put("foo", new HashMap<>() {{
263-
put("type", "alias");
264-
put("path", "bar");
265-
}});
266-
put("bar", Collections.singletonMap("type", "long"));
267-
}},
268-
"results"),
260+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
261+
262+
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
263+
new HashMap<>() {{
264+
put("foo", new HashMap<>() {{
265+
put("type", "alias");
266+
put("path", "bar");
267+
}});
268+
put("bar", Collections.singletonMap("type", "long"));
269+
}},
270+
"results");
271+
assertThat(explicitlyMappedFields,
269272
allOf(
270273
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
271274
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
275+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
276+
272277
assertThat(
273278
new Classification("foo").getExplicitlyMappedFields(
274279
Collections.singletonMap("foo", new HashMap<>() {{
275280
put("type", "alias");
276281
put("path", "missing");
277282
}}),
278283
"results"),
279-
is(anEmptyMap()));
284+
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
280285
}
281286

282287
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ public void testFieldCardinalityLimitsIsEmpty() {
143143
}
144144

145145
public void testGetExplicitlyMappedFields() {
146-
assertThat(
147-
new Regression("foo").getExplicitlyMappedFields(null, "results"),
148-
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
146+
Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
147+
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
148+
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
149149
}
150150

151151
public void testGetStateDocId() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ public void cleanup() {
7676
cleanUp();
7777
}
7878

79-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
8079
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
8180
initialize("classification_single_numeric_feature_and_mixed_data_set");
8281
String predictedClassField = KEYWORD_FIELD + "_prediction";
@@ -108,7 +107,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
108107
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
109108
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
110109
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
111-
assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true));
110+
@SuppressWarnings("unchecked")
111+
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
112+
assertThat(importanceArray, hasSize(greaterThan(0)));
112113
}
113114

114115
assertProgress(jobId, 100, 100, 100, 100);

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

+8-5
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
import java.util.Map;
2626
import java.util.Set;
2727

28+
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
2829
import static org.hamcrest.Matchers.anyOf;
2930
import static org.hamcrest.Matchers.equalTo;
3031
import static org.hamcrest.Matchers.greaterThan;
32+
import static org.hamcrest.Matchers.hasSize;
3133
import static org.hamcrest.Matchers.is;
3234

3335
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@@ -48,7 +50,6 @@ public void cleanup() {
4850
cleanUp();
4951
}
5052

51-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
5253
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
5354
initialize("regression_single_numeric_feature_and_mixed_data_set");
5455
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
@@ -86,11 +87,13 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
8687
assertThat(resultsObject.containsKey(predictedClassField), is(true));
8788
assertThat(resultsObject.containsKey("is_training"), is(true));
8889
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
90+
@SuppressWarnings("unchecked")
91+
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
92+
assertThat(importanceArray, hasSize(greaterThan(0)));
8993
assertThat(
90-
resultsObject.toString(),
91-
resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
92-
|| resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
93-
is(true));
94+
importanceArray.stream().filter(m -> NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))
95+
|| DISCRETE_NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))).findAny(),
96+
isPresent());
9497
}
9598

9699
assertProgress(jobId, 100, 100, 100, 100);

0 commit comments

Comments
 (0)