Skip to content

Commit f0ca4b6

Browse files
committed
Pass prediction_field_type to C++ analytics process
1 parent 121876f commit f0ca4b6

File tree

20 files changed

+291
-127
lines changed

20 files changed

+291
-127
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

+27-1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
6565
Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool())
6666
.flatMap(Set::stream)
6767
.collect(Collectors.toUnmodifiableSet());
68+
/**
69+
* Name of the parameter passed down to the C++ analytics process.
70+
* This parameter is used to decide which JSON data type from {string, int, bool} to use when writing the prediction.
71+
*/
72+
private static final String PREDICTION_FIELD_TYPE = "prediction_field_type";
73+
6874
/**
6975
* As long as we only support binary classification it makes sense to always report both classes with their probabilities.
7076
* This way the user can see whether the prediction was made with the confidence they need.
@@ -152,17 +158,37 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
152158
}
153159

154160
@Override
155-
public Map<String, Object> getParams() {
161+
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
156162
Map<String, Object> params = new HashMap<>();
157163
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
158164
params.putAll(boostedTreeParams.getParams());
159165
params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
160166
if (predictionFieldName != null) {
161167
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
162168
}
169+
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
170+
if (predictionFieldType != null) {
171+
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
172+
}
163173
return params;
164174
}
165175

176+
private static String getPredictionFieldType(Set<String> dependentVariableTypes) {
177+
if (dependentVariableTypes == null) {
178+
return null;
179+
}
180+
if (Types.bool().containsAll(dependentVariableTypes)) {
181+
return "bool";
182+
}
183+
if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
184+
return "int";
185+
}
186+
if (Types.categorical().containsAll(dependentVariableTypes)) {
187+
return "string";
188+
}
189+
return null;
190+
}
191+
166192
@Override
167193
public boolean supportsCategoricalFields() {
168194
return true;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
1616

1717
/**
1818
* @return The analysis parameters as a map
19+
* @param extractedFields map of (name, types) for all the extracted fields
1920
*/
20-
Map<String, Object> getParams();
21+
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
2122

2223
/**
2324
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ public int hashCode() {
192192
}
193193

194194
@Override
195-
public Map<String, Object> getParams() {
195+
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
196196
Map<String, Object> params = new HashMap<>();
197197
if (nNeighbors != null) {
198198
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
124124
}
125125

126126
@Override
127-
public Map<String, Object> getParams() {
127+
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
128128
Map<String, Object> params = new HashMap<>();
129129
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
130130
params.putAll(boostedTreeParams.getParams());

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

+37
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88
import org.elasticsearch.ElasticsearchStatusException;
99
import org.elasticsearch.common.io.stream.Writeable;
1010
import org.elasticsearch.common.xcontent.XContentParser;
11+
import org.elasticsearch.index.mapper.BooleanFieldMapper;
12+
import org.elasticsearch.index.mapper.KeywordFieldMapper;
13+
import org.elasticsearch.index.mapper.NumberFieldMapper;
1114
import org.elasticsearch.test.AbstractSerializingTestCase;
1215

1316
import java.io.IOException;
17+
import java.util.Map;
18+
import java.util.Set;
1419

1520
import static org.hamcrest.Matchers.equalTo;
1621
import static org.hamcrest.Matchers.is;
@@ -115,6 +120,38 @@ public void testGetTrainingPercent() {
115120
assertThat(classification.getTrainingPercent(), equalTo(100.0));
116121
}
117122

123+
public void testGetParams() {
124+
Map<String, Set<String>> extractedFields =
125+
Map.of(
126+
"foo", Set.of(BooleanFieldMapper.CONTENT_TYPE),
127+
"bar", Set.of(NumberFieldMapper.NumberType.LONG.typeName()),
128+
"baz", Set.of(KeywordFieldMapper.CONTENT_TYPE));
129+
assertThat(
130+
new Classification("foo").getParams(extractedFields),
131+
equalTo(
132+
Map.of(
133+
"dependent_variable", "foo",
134+
"num_top_classes", 2,
135+
"prediction_field_name", "foo_prediction",
136+
"prediction_field_type", "bool")));
137+
assertThat(
138+
new Classification("bar").getParams(extractedFields),
139+
equalTo(
140+
Map.of(
141+
"dependent_variable", "bar",
142+
"num_top_classes", 2,
143+
"prediction_field_name", "bar_prediction",
144+
"prediction_field_type", "int")));
145+
assertThat(
146+
new Classification("baz").getParams(extractedFields),
147+
equalTo(
148+
Map.of(
149+
"dependent_variable", "baz",
150+
"num_top_classes", 2,
151+
"prediction_field_name", "baz_prediction",
152+
"prediction_field_type", "string")));
153+
}
154+
118155
public void testFieldCardinalityLimitsIsNonNull() {
119156
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
120157
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ protected Writeable.Reader<OutlierDetection> instanceReader() {
5151

5252
public void testGetParams_GivenDefaults() {
5353
OutlierDetection outlierDetection = new OutlierDetection.Builder().build();
54-
Map<String, Object> params = outlierDetection.getParams();
54+
Map<String, Object> params = outlierDetection.getParams(null);
5555
assertThat(params.size(), equalTo(3));
5656
assertThat(params.containsKey("compute_feature_influence"), is(true));
5757
assertThat(params.get("compute_feature_influence"), is(true));
@@ -71,7 +71,7 @@ public void testGetParams_GivenExplicitValues() {
7171
.setStandardizationEnabled(false)
7272
.build();
7373

74-
Map<String, Object> params = outlierDetection.getParams();
74+
Map<String, Object> params = outlierDetection.getParams(null);
7575

7676
assertThat(params.size(), equalTo(6));
7777
assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42));

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

+7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.test.AbstractSerializingTestCase;
1212

1313
import java.io.IOException;
14+
import java.util.Map;
1415

1516
import static org.hamcrest.Matchers.equalTo;
1617
import static org.hamcrest.Matchers.is;
@@ -83,6 +84,12 @@ public void testGetTrainingPercent() {
8384
assertThat(regression.getTrainingPercent(), equalTo(100.0));
8485
}
8586

87+
public void testGetParams() {
88+
assertThat(
89+
new Regression("foo").getParams(null),
90+
equalTo(Map.of("dependent_variable", "foo", "prediction_field_name", "foo_prediction")));
91+
}
92+
8693
public void testFieldCardinalityLimitsIsNonNull() {
8794
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
8895
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

+70-16
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
2727

2828
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
2929

30-
private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
31-
private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
30+
private static final String ANIMAL_NAME_FIELD = "animal_name";
31+
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
32+
private static final String NO_LEGS_FIELD = "no_legs";
33+
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
34+
private static final String IS_PREDATOR_FIELD = "predator";
35+
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
3236

3337
@Before
3438
public void setup() {
@@ -40,9 +44,9 @@ public void cleanup() {
4044
cleanUp();
4145
}
4246

43-
public void testEvaluate_MulticlassClassification_DefaultMetrics() {
47+
public void testEvaluate_DefaultMetrics() {
4448
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
45-
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
49+
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
4650

4751
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
4852
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -51,9 +55,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() {
5155
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
5256
}
5357

54-
public void testEvaluate_MulticlassClassification_Accuracy() {
58+
public void testEvaluate_Accuracy_KeywordField() {
5559
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
56-
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new Accuracy())));
60+
evaluateDataFrame(
61+
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy())));
5762

5863
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
5964
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -72,11 +77,50 @@ public void testEvaluate_MulticlassClassification_Accuracy() {
7277
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
7378
}
7479

75-
public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
80+
public void testEvaluate_Accuracy_IntegerField() {
81+
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
82+
evaluateDataFrame(
83+
ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, List.of(new Accuracy())));
84+
85+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
86+
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
87+
88+
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
89+
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
90+
assertThat(
91+
accuracyResult.getActualClasses(),
92+
equalTo(List.of(
93+
new Accuracy.ActualClass("1", 15, 1.0 / 15),
94+
new Accuracy.ActualClass("2", 15, 2.0 / 15),
95+
new Accuracy.ActualClass("3", 15, 3.0 / 15),
96+
new Accuracy.ActualClass("4", 15, 4.0 / 15),
97+
new Accuracy.ActualClass("5", 15, 5.0 / 15))));
98+
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
99+
}
100+
101+
public void testEvaluate_Accuracy_BooleanField() {
102+
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
103+
evaluateDataFrame(
104+
ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, List.of(new Accuracy())));
105+
106+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
107+
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
108+
109+
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
110+
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
111+
assertThat(
112+
accuracyResult.getActualClasses(),
113+
equalTo(List.of(
114+
new Accuracy.ActualClass("true", 45, 27.0 / 45),
115+
new Accuracy.ActualClass("false", 30, 18.0 / 30))));
116+
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
117+
}
118+
119+
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
76120
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
77121
evaluateDataFrame(
78122
ANIMALS_DATA_INDEX,
79-
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix())));
123+
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix())));
80124

81125
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
82126
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -135,11 +179,11 @@ public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetr
135179
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
136180
}
137181

138-
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
182+
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
139183
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
140184
evaluateDataFrame(
141185
ANIMALS_DATA_INDEX,
142-
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix(3))));
186+
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3))));
143187

144188
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
145189
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -166,20 +210,30 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP
166210

167211
private static void indexAnimalsData(String indexName) {
168212
client().admin().indices().prepareCreate(indexName)
169-
.addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
213+
.addMapping("_doc",
214+
ANIMAL_NAME_FIELD, "type=keyword",
215+
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
216+
NO_LEGS_FIELD, "type=integer",
217+
NO_LEGS_PREDICTION_FIELD, "type=integer",
218+
IS_PREDATOR_FIELD, "type=boolean",
219+
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
170220
.get();
171221

172-
List<String> classNames = List.of("dog", "cat", "mouse", "ant", "fox");
222+
List<String> animalNames = List.of("dog", "cat", "mouse", "ant", "fox");
173223
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
174224
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
175-
for (int i = 0; i < classNames.size(); i++) {
176-
for (int j = 0; j < classNames.size(); j++) {
225+
for (int i = 0; i < animalNames.size(); i++) {
226+
for (int j = 0; j < animalNames.size(); j++) {
177227
for (int k = 0; k < j + 1; k++) {
178228
bulkRequestBuilder.add(
179229
new IndexRequest(indexName)
180230
.source(
181-
ACTUAL_CLASS_FIELD, classNames.get(i),
182-
PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
231+
ANIMAL_NAME_FIELD, animalNames.get(i),
232+
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
233+
NO_LEGS_FIELD, i + 1,
234+
NO_LEGS_PREDICTION_FIELD, j + 1,
235+
IS_PREDATOR_FIELD, i % 2 == 0,
236+
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
183237
}
184238
}
185239
}

0 commit comments

Comments
 (0)