Skip to content

Commit 303b7e3

Browse files
authored
Allow integer types for classification's dependent variable (#47902)
1 parent 17610e7 commit 303b7e3

File tree

12 files changed

+259
-70
lines changed

12 files changed

+259
-70
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import java.util.List;
2121
import java.util.Map;
2222
import java.util.Objects;
23+
import java.util.Set;
24+
import java.util.stream.Collectors;
25+
import java.util.stream.Stream;
2326

2427
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
2528
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -58,6 +61,11 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
5861
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
5962
}
6063

64+
private static final Set<String> ALLOWED_DEPENDENT_VARIABLE_TYPES =
65+
Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool())
66+
.flatMap(Set::stream)
67+
.collect(Collectors.toUnmodifiableSet());
68+
6169
private final String dependentVariable;
6270
private final BoostedTreeParams boostedTreeParams;
6371
private final String predictionFieldName;
@@ -147,9 +155,17 @@ public boolean supportsCategoricalFields() {
147155
return true;
148156
}
149157

158+
@Override
159+
public Set<String> getAllowedCategoricalTypes(String fieldName) {
160+
if (dependentVariable.equals(fieldName)) {
161+
return ALLOWED_DEPENDENT_VARIABLE_TYPES;
162+
}
163+
return Types.categorical();
164+
}
165+
150166
@Override
151167
public List<RequiredField> getRequiredFields() {
152-
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
168+
return Collections.singletonList(new RequiredField(dependentVariable, ALLOWED_DEPENDENT_VARIABLE_TYPES));
153169
}
154170

155171
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import java.util.List;
1212
import java.util.Map;
13+
import java.util.Set;
1314

1415
public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
1516

@@ -23,6 +24,12 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
2324
*/
2425
boolean supportsCategoricalFields();
2526

27+
/**
28+
* @param fieldName field for which the allowed categorical types should be returned
29+
* @return The types treated as categorical for the given field
30+
*/
31+
Set<String> getAllowedCategoricalTypes(String fieldName);
32+
2633
/**
2734
* @return The names and types of the fields that analyzed documents must have for the analysis to operate
2835
*/

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Locale;
2323
import java.util.Map;
2424
import java.util.Objects;
25+
import java.util.Set;
2526

2627
public class OutlierDetection implements DataFrameAnalysis {
2728

@@ -213,6 +214,11 @@ public boolean supportsCategoricalFields() {
213214
return false;
214215
}
215216

217+
@Override
218+
public Set<String> getAllowedCategoricalTypes(String fieldName) {
219+
return Collections.emptySet();
220+
}
221+
216222
@Override
217223
public List<RequiredField> getRequiredFields() {
218224
return Collections.emptyList();

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121
import java.util.Map;
2222
import java.util.Objects;
23+
import java.util.Set;
2324

2425
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
2526
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -134,6 +135,11 @@ public boolean supportsCategoricalFields() {
134135
return true;
135136
}
136137

138+
@Override
139+
public Set<String> getAllowedCategoricalTypes(String fieldName) {
140+
return Types.categorical();
141+
}
142+
137143
@Override
138144
public List<RequiredField> getRequiredFields() {
139145
return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
77

8-
import org.elasticsearch.index.mapper.NumberFieldMapper;
8+
import org.elasticsearch.index.mapper.BooleanFieldMapper;
9+
import org.elasticsearch.index.mapper.IpFieldMapper;
10+
import org.elasticsearch.index.mapper.KeywordFieldMapper;
11+
import org.elasticsearch.index.mapper.NumberFieldMapper.NumberType;
12+
import org.elasticsearch.index.mapper.TextFieldMapper;
913

1014
import java.util.Collections;
1115
import java.util.Set;
@@ -20,16 +24,19 @@ public final class Types {
2024
private Types() {}
2125

2226
private static final Set<String> CATEGORICAL_TYPES =
23-
Collections.unmodifiableSet(
24-
Stream.of("text", "keyword", "ip")
25-
.collect(Collectors.toSet()));
27+
Stream.of(TextFieldMapper.CONTENT_TYPE, KeywordFieldMapper.CONTENT_TYPE, IpFieldMapper.CONTENT_TYPE)
28+
.collect(Collectors.toUnmodifiableSet());
2629

2730
private static final Set<String> NUMERICAL_TYPES =
28-
Collections.unmodifiableSet(
29-
Stream.concat(
30-
Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
31-
Stream.of("scaled_float"))
32-
.collect(Collectors.toSet()));
31+
Stream.concat(Stream.of(NumberType.values()).map(NumberType::typeName), Stream.of("scaled_float"))
32+
.collect(Collectors.toUnmodifiableSet());
33+
34+
private static final Set<String> DISCRETE_NUMERICAL_TYPES =
35+
Stream.of(NumberType.BYTE, NumberType.SHORT, NumberType.INTEGER, NumberType.LONG)
36+
.map(NumberType::typeName)
37+
.collect(Collectors.toUnmodifiableSet());
38+
39+
private static final Set<String> BOOL_TYPES = Collections.singleton(BooleanFieldMapper.CONTENT_TYPE);
3340

3441
public static Set<String> categorical() {
3542
return CATEGORICAL_TYPES;
@@ -38,4 +45,12 @@ public static Set<String> categorical() {
3845
public static Set<String> numerical() {
3946
return NUMERICAL_TYPES;
4047
}
48+
49+
public static Set<String> discreteNumerical() {
50+
return DISCRETE_NUMERICAL_TYPES;
51+
}
52+
53+
public static Set<String> bool() {
54+
return BOOL_TYPES;
55+
}
4156
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.action.bulk.BulkRequestBuilder;
1111
import org.elasticsearch.action.bulk.BulkResponse;
1212
import org.elasticsearch.action.get.GetResponse;
13+
import org.elasticsearch.action.index.IndexAction;
1314
import org.elasticsearch.action.index.IndexRequest;
1415
import org.elasticsearch.action.search.SearchResponse;
1516
import org.elasticsearch.action.support.WriteRequest;
@@ -24,6 +25,7 @@
2425
import java.util.Collections;
2526
import java.util.List;
2627
import java.util.Map;
28+
import java.util.function.Function;
2729

2830
import static org.hamcrest.Matchers.allOf;
2931
import static org.hamcrest.Matchers.equalTo;
@@ -34,12 +36,17 @@
3436
import static org.hamcrest.Matchers.in;
3537
import static org.hamcrest.Matchers.is;
3638
import static org.hamcrest.Matchers.lessThanOrEqualTo;
39+
import static org.hamcrest.Matchers.startsWith;
3740

3841
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
3942

43+
private static final String BOOLEAN_FIELD = "boolean-field";
4044
private static final String NUMERICAL_FIELD = "numerical-field";
45+
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
4146
private static final String KEYWORD_FIELD = "keyword-field";
42-
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
47+
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
48+
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
49+
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
4350
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
4451

4552
private String jobId;
@@ -53,7 +60,7 @@ public void cleanup() {
5360

5461
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
5562
initialize("classification_single_numeric_feature_and_mixed_data_set");
56-
indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
63+
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
5764

5865
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
5966
registerAnalytics(config);
@@ -91,7 +98,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
9198

9299
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
93100
initialize("classification_only_training_data_and_training_percent_is_100");
94-
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
101+
indexData(sourceIndex, 300, 0, KEYWORD_FIELD);
95102

96103
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
97104
registerAnalytics(config);
@@ -126,17 +133,19 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
126133
"Finished analysis");
127134
}
128135

129-
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
130-
initialize("classification_only_training_data_and_training_percent_is_50");
131-
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
136+
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
137+
String jobId, String dependentVariable, List<T> dependentVariableValues, Function<String, T> parser) throws Exception {
138+
initialize(jobId);
139+
indexData(sourceIndex, 300, 0, dependentVariable);
132140

141+
int numTopClasses = 2;
133142
DataFrameAnalyticsConfig config =
134143
buildAnalytics(
135144
jobId,
136145
sourceIndex,
137146
destIndex,
138147
null,
139-
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
148+
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0));
140149
registerAnalytics(config);
141150
putAnalytics(config);
142151

@@ -151,8 +160,11 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
151160
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
152161
for (SearchHit hit : sourceData.getHits()) {
153162
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
154-
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
155-
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
163+
String predictedClassField = dependentVariable + "_prediction";
164+
assertThat(resultsObject.containsKey(predictedClassField), is(true));
165+
T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField));
166+
assertThat(predictedClassValue, is(in(dependentVariableValues)));
167+
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser);
156168

157169
assertThat(resultsObject.containsKey("is_training"), is(true));
158170
// Let's just assert there's both training and non-training results
@@ -161,11 +173,9 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
161173
} else {
162174
nonTrainingRowsCount++;
163175
}
164-
assertThat(resultsObject.containsKey("top_classes"), is(false));
165176
}
166177
assertThat(trainingRowsCount, greaterThan(0));
167178
assertThat(nonTrainingRowsCount, greaterThan(0));
168-
169179
assertProgress(jobId, 100, 100, 100, 100);
170180
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
171181
assertThatAuditMessagesMatch(jobId,
@@ -178,9 +188,32 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
178188
"Finished analysis");
179189
}
180190

191+
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
192+
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
193+
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
194+
}
195+
196+
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception {
197+
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
198+
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf);
199+
}
200+
201+
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
202+
ElasticsearchStatusException e = expectThrows(
203+
ElasticsearchStatusException.class,
204+
() -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
205+
"classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf));
206+
assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];"));
207+
}
208+
209+
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception {
210+
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
211+
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf);
212+
}
213+
181214
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
182215
initialize("classification_top_classes_requested");
183-
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
216+
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
184217

185218
int numTopClasses = 2;
186219
DataFrameAnalyticsConfig config =
@@ -206,7 +239,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse
206239

207240
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
208241
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
209-
assertTopClasses(resultsObject, numTopClasses);
242+
assertTopClasses(resultsObject, numTopClasses, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
210243
}
211244

212245
assertProgress(jobId, 100, 100, 100, 100);
@@ -223,7 +256,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse
223256

224257
public void testDependentVariableCardinalityTooHighError() {
225258
initialize("cardinality_too_high");
226-
indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));
259+
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
260+
// Index one more document with a class different than the two already used.
261+
client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex).source(KEYWORD_FIELD, "fox"));
227262

228263
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
229264
registerAnalytics(config);
@@ -240,28 +275,42 @@ private void initialize(String jobId) {
240275
this.destIndex = sourceIndex + "_results";
241276
}
242277

243-
private static void indexData(String sourceIndex,
244-
int numTrainingRows, int numNonTrainingRows,
245-
List<Double> numericalFieldValues, List<String> keywordFieldValues) {
278+
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
246279
client().admin().indices().prepareCreate(sourceIndex)
247-
.addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
280+
.addMapping("_doc",
281+
BOOLEAN_FIELD, "type=boolean",
282+
NUMERICAL_FIELD, "type=double",
283+
DISCRETE_NUMERICAL_FIELD, "type=integer",
284+
KEYWORD_FIELD, "type=keyword")
248285
.get();
249286

250287
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
251288
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
252289
for (int i = 0; i < numTrainingRows; i++) {
253-
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
254-
String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());
255-
256-
IndexRequest indexRequest = new IndexRequest(sourceIndex)
257-
.source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
290+
List<Object> source = List.of(
291+
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
292+
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
293+
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
294+
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
295+
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
258296
bulkRequestBuilder.add(indexRequest);
259297
}
260298
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
261-
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
262-
263-
IndexRequest indexRequest = new IndexRequest(sourceIndex)
264-
.source(NUMERICAL_FIELD, numericalValue);
299+
List<Object> source = new ArrayList<>();
300+
if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
301+
source.addAll(List.of(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
302+
}
303+
if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
304+
source.addAll(List.of(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
305+
}
306+
if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
307+
source.addAll(
308+
List.of(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size())));
309+
}
310+
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
311+
source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
312+
}
313+
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
265314
bulkRequestBuilder.add(indexRequest);
266315
}
267316
BulkResponse bulkResponse = bulkRequestBuilder.get();
@@ -289,7 +338,12 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
289338
return resultsObject;
290339
}
291340

292-
private static void assertTopClasses(Map<String, Object> resultsObject, int numTopClasses) {
341+
private static <T> void assertTopClasses(
342+
Map<String, Object> resultsObject,
343+
int numTopClasses,
344+
String dependentVariable,
345+
List<T> dependentVariableValues,
346+
Function<String, T> parser) {
293347
assertThat(resultsObject.containsKey("top_classes"), is(true));
294348
@SuppressWarnings("unchecked")
295349
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
@@ -302,9 +356,9 @@ private static void assertTopClasses(Map<String, Object> resultsObject, int numT
302356
classProbabilities.add((Double) topClass.get("class_probability"));
303357
}
304358
// Assert that all the predicted class names come from the set of keyword field values.
305-
classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
359+
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
306360
// Assert that the first class listed in top classes is the same as the predicted class.
307-
assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
361+
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
308362
// Assert that all the class probabilities lie within [0, 1] interval.
309363
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
310364
// Assert that the top classes are listed in the order of decreasing probabilities.

0 commit comments

Comments
 (0)