Skip to content

Commit 3788a70

Browse files
authored
[ML][Inference] handle string values better in feature extraction (#48584)
* [ML][Inference] handle string values better in feature extraction * adding tests for InferenceHelpers
1 parent 484886a commit 3788a70

File tree

4 files changed

+78
-5
lines changed

4 files changed

+78
-5
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List
5656
return topClassEntries;
5757
}
5858

59-
6059
public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
6160
assert inferenceValue == Math.rint(inferenceValue);
6261
if (classificationLabels == null) {
@@ -72,4 +71,19 @@ public static String classificationLabel(double inferenceValue, @Nullable List<S
7271
}
7372
return classificationLabels.get(label);
7473
}
74+
75+
public static Double toDouble(Object value) {
76+
if (value instanceof Number) {
77+
return ((Number)value).doubleValue();
78+
}
79+
if (value instanceof String) {
80+
try {
81+
return Double.valueOf((String)value);
82+
} catch (NumberFormatException nfe) {
83+
assert false : "value is not properly formatted double [" + value + "]";
84+
return null;
85+
}
86+
}
87+
return null;
88+
}
7589
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,7 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
126126
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
127127
}
128128

129-
List<Double> features = featureNames.stream().map(f ->
130-
fields.get(f) instanceof Number ? ((Number) fields.get(f)).doubleValue() : null
131-
).collect(Collectors.toList());
129+
List<Double> features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList());
132130
return infer(features, config);
133131
}
134132

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
7+
8+
import org.elasticsearch.test.ESTestCase;
9+
10+
import java.util.HashMap;
11+
12+
import static org.hamcrest.Matchers.equalTo;
13+
import static org.hamcrest.Matchers.is;
14+
import static org.hamcrest.Matchers.nullValue;
15+
16+
17+
public class InferenceHelpersTests extends ESTestCase {
18+
19+
public void testToDoubleFromNumbers() {
20+
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5)));
21+
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5)));
22+
assertThat(5.0, equalTo(InferenceHelpers.toDouble(5L)));
23+
assertThat(5.0, equalTo(InferenceHelpers.toDouble(5)));
24+
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5f)));
25+
}
26+
27+
public void testToDoubleFromString() {
28+
assertThat(0.5, equalTo(InferenceHelpers.toDouble("0.5")));
29+
assertThat(-0.5, equalTo(InferenceHelpers.toDouble("-0.5")));
30+
assertThat(5.0, equalTo(InferenceHelpers.toDouble("5")));
31+
assertThat(-5.0, equalTo(InferenceHelpers.toDouble("-5")));
32+
33+
// if ae are turned off, then we should get a null value
34+
// otherwise, we should expect an assertion failure telling us that the string is improperly formatted
35+
try {
36+
assertThat(InferenceHelpers.toDouble(""), is(nullValue()));
37+
} catch (AssertionError ae) {
38+
assertThat(ae.getMessage(), equalTo("value is not properly formatted double []"));
39+
}
40+
try {
41+
assertThat(InferenceHelpers.toDouble("notADouble"), is(nullValue()));
42+
} catch (AssertionError ae) {
43+
assertThat(ae.getMessage(), equalTo("value is not properly formatted double [notADouble]"));
44+
}
45+
}
46+
47+
public void testToDoubleFromNull() {
48+
assertThat(InferenceHelpers.toDouble(null), is(nullValue()));
49+
}
50+
51+
public void testDoubleFromUnknownObj() {
52+
assertThat(InferenceHelpers.toDouble(new HashMap<>()), is(nullValue()));
53+
}
54+
55+
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ public void testInfer() {
154154
assertThat(0.2,
155155
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
156156

157+
// This should still work if the internal values are strings
158+
List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
159+
featureMap = zipObjMap(featureNames, featureVectorStrings);
160+
assertThat(0.2,
161+
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
162+
157163
// This should handle missing values and take the default_left path
158164
featureMap = new HashMap<>(2) {{
159165
put("foo", 0.3);
@@ -294,7 +300,7 @@ public void testTreeWithTargetTypeAndLabelsMismatch() {
294300
assertThat(ex.getMessage(), equalTo(msg));
295301
}
296302

297-
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
303+
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
298304
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
299305
}
300306
}

0 commit comments

Comments
 (0)