diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 199158cceaa34..7dbfdd4180272 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -20,6 +20,9 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -58,6 +61,11 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); } + private static final Set ALLOWED_DEPENDENT_VARIABLE_TYPES = + Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool()) + .flatMap(Set::stream) + .collect(Collectors.toUnmodifiableSet()); + private final String dependentVariable; private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; @@ -147,9 +155,17 @@ public boolean supportsCategoricalFields() { return true; } + @Override + public Set getAllowedCategoricalTypes(String fieldName) { + if (dependentVariable.equals(fieldName)) { + return ALLOWED_DEPENDENT_VARIABLE_TYPES; + } + return Types.categorical(); + } + @Override public List getRequiredFields() { - return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical())); + return Collections.singletonList(new RequiredField(dependentVariable, ALLOWED_DEPENDENT_VARIABLE_TYPES)); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index d23097f5816b2..0ca32cde4021c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { @@ -23,6 +24,12 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { */ boolean supportsCategoricalFields(); + /** + * @param fieldName field for which the allowed categorical types should be returned + * @return The types treated as categorical for the given field + */ + Set getAllowedCategoricalTypes(String fieldName); + /** * @return The names and types of the fields that analyzed documents must have for the analysis to operate */ diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 055c97a511a8f..d4cefe884b53b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -22,6 +22,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; public class OutlierDetection implements DataFrameAnalysis { @@ -213,6 +214,11 @@ public boolean supportsCategoricalFields() { return false; } + @Override + public Set getAllowedCategoricalTypes(String fieldName) { + return Collections.emptySet(); + } + @Override public List getRequiredFields() { return Collections.emptyList(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 34a93713385d0..5412156a1b4cd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -134,6 +135,11 @@ public boolean supportsCategoricalFields() { return true; } + @Override + public Set getAllowedCategoricalTypes(String fieldName) { + return Types.categorical(); + } + @Override public List getRequiredFields() { return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical())); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java index fc991c86f5663..708db1a913f86 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java @@ -5,7 +5,11 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; -import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.index.mapper.BooleanFieldMapper; +import org.elasticsearch.index.mapper.IpFieldMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.NumberFieldMapper.NumberType; +import org.elasticsearch.index.mapper.TextFieldMapper; import java.util.Collections; import java.util.Set; @@ -20,16 +24,19 @@ public final class Types { private Types() {} private static final Set CATEGORICAL_TYPES = - Collections.unmodifiableSet( - Stream.of("text", "keyword", "ip") - .collect(Collectors.toSet())); + Stream.of(TextFieldMapper.CONTENT_TYPE, KeywordFieldMapper.CONTENT_TYPE, IpFieldMapper.CONTENT_TYPE) + .collect(Collectors.toUnmodifiableSet()); private static final Set NUMERICAL_TYPES = - Collections.unmodifiableSet( - Stream.concat( - Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName), - Stream.of("scaled_float")) - .collect(Collectors.toSet())); + Stream.concat(Stream.of(NumberType.values()).map(NumberType::typeName), Stream.of("scaled_float")) + .collect(Collectors.toUnmodifiableSet()); + + private static final Set DISCRETE_NUMERICAL_TYPES = + Stream.of(NumberType.BYTE, NumberType.SHORT, NumberType.INTEGER, NumberType.LONG) + .map(NumberType::typeName) + .collect(Collectors.toUnmodifiableSet()); + + private static final Set BOOL_TYPES = Collections.singleton(BooleanFieldMapper.CONTENT_TYPE); public static Set categorical() { return CATEGORICAL_TYPES; @@ -38,4 +45,12 @@ public static Set categorical() { public static Set numerical() { return NUMERICAL_TYPES; } + + public static Set discreteNumerical() { + return DISCRETE_NUMERICAL_TYPES; + } + + public static Set bool() { + return BOOL_TYPES; + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index b8fb0a5f91afa..776a27ff7349a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; @@ -24,6 +25,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Function; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; @@ -34,12 +36,17 @@ import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.startsWith; public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { + private static final String BOOLEAN_FIELD = "boolean-field"; private static final String NUMERICAL_FIELD = "numerical-field"; + private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field"; private static final String KEYWORD_FIELD = "keyword-field"; - private static final List NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + private static final List BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true)); + private static final List NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0)); + private static final List DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20)); private static final List KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat")); private String jobId; @@ -53,7 +60,7 @@ public void cleanup() { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("classification_single_numeric_feature_and_mixed_data_set"); - indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); + indexData(sourceIndex, 300, 50, KEYWORD_FIELD); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); registerAnalytics(config); @@ -91,7 +98,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { initialize("classification_only_training_data_and_training_percent_is_100"); - indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); + indexData(sourceIndex, 300, 0, KEYWORD_FIELD); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); registerAnalytics(config); @@ -126,17 +133,19 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti "Finished analysis"); } - public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { - initialize("classification_only_training_data_and_training_percent_is_50"); - indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( + String jobId, String dependentVariable, List dependentVariableValues, Function parser) throws Exception { + initialize(jobId); + indexData(sourceIndex, 300, 0, dependentVariable); + int numTopClasses = 2; DataFrameAnalyticsConfig config = buildAnalytics( jobId, sourceIndex, destIndex, null, - new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0)); + new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0)); registerAnalytics(config); putAnalytics(config); @@ -151,8 +160,11 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); - assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); + String predictedClassField = dependentVariable + "_prediction"; + assertThat(resultsObject.containsKey(predictedClassField), is(true)); + T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField)); + assertThat(predictedClassValue, is(in(dependentVariableValues))); + assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser); assertThat(resultsObject.containsKey("is_training"), is(true)); // Let's just assert there's both training and non-training results @@ -161,11 +173,9 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception } else { nonTrainingRowsCount++; } - assertThat(resultsObject.containsKey("top_classes"), is(false)); } assertThat(trainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0)); - assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThatAuditMessagesMatch(jobId, @@ -178,9 +188,32 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception "Finished analysis"); } + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception { + testWithOnlyTrainingRowsAndTrainingPercentIsFifty( + "classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); + } + + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception { + testWithOnlyTrainingRowsAndTrainingPercentIsFifty( + "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf); + } + + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception { + ElasticsearchStatusException e = expectThrows( + ElasticsearchStatusException.class, + () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( + "classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf)); + assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];")); + } + + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception { + testWithOnlyTrainingRowsAndTrainingPercentIsFifty( + "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf); + } + public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception { initialize("classification_top_classes_requested"); - indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); + indexData(sourceIndex, 300, 50, KEYWORD_FIELD); int numTopClasses = 2; DataFrameAnalyticsConfig config = @@ -206,7 +239,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); - assertTopClasses(resultsObject, numTopClasses); + assertTopClasses(resultsObject, numTopClasses, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); } assertProgress(jobId, 100, 100, 100, 100); @@ -223,7 +256,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse public void testDependentVariableCardinalityTooHighError() { initialize("cardinality_too_high"); - indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox")); + indexData(sourceIndex, 6, 5, KEYWORD_FIELD); + // Index one more document with a class different than the two already used. + client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex).source(KEYWORD_FIELD, "fox")); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); registerAnalytics(config); @@ -240,28 +275,42 @@ private void initialize(String jobId) { this.destIndex = sourceIndex + "_results"; } - private static void indexData(String sourceIndex, - int numTrainingRows, int numNonTrainingRows, - List numericalFieldValues, List keywordFieldValues) { + private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) { client().admin().indices().prepareCreate(sourceIndex) - .addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword") + .addMapping("_doc", + BOOLEAN_FIELD, "type=boolean", + NUMERICAL_FIELD, "type=double", + DISCRETE_NUMERICAL_FIELD, "type=integer", + KEYWORD_FIELD, "type=keyword") .get(); BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < numTrainingRows; i++) { - Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size()); - String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size()); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue); + List source = List.of( + BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()), + NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()), + DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()), + KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())); + IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); bulkRequestBuilder.add(indexRequest); } for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) { - Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size()); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FIELD, numericalValue); + List source = new ArrayList<>(); + if (BOOLEAN_FIELD.equals(dependentVariable) == false) { + source.addAll(List.of(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()))); + } + if (NUMERICAL_FIELD.equals(dependentVariable) == false) { + source.addAll(List.of(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()))); + } + if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) { + source.addAll( + List.of(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()))); + } + if (KEYWORD_FIELD.equals(dependentVariable) == false) { + source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()))); + } + IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); bulkRequestBuilder.add(indexRequest); } BulkResponse bulkResponse = bulkRequestBuilder.get(); @@ -289,7 +338,12 @@ private static Map getMlResultsObjectFromDestDoc(Map resultsObject, int numTopClasses) { + private static void assertTopClasses( + Map resultsObject, + int numTopClasses, + String dependentVariable, + List dependentVariableValues, + Function parser) { assertThat(resultsObject.containsKey("top_classes"), is(true)); @SuppressWarnings("unchecked") List> topClasses = (List>) resultsObject.get("top_classes"); @@ -302,9 +356,9 @@ private static void assertTopClasses(Map resultsObject, int numT classProbabilities.add((Double) topClass.get("class_probability")); } // Assert that all the predicted class names come from the set of keyword field values. - classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES)))); + classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues)))); // Assert that the first class listed in top classes is the same as the predicted class. - assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction"))); + assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction"))); // Assert that all the class probabilities lie within [0, 1] interval. classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)))); // Assert that the top classes are listed in the order of decreasing probabilities. diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 657608d08bb9b..d62a0156c5b79 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -23,7 +23,7 @@ import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsIndex; @@ -31,7 +31,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; @@ -265,15 +264,11 @@ private SearchRequestBuilder buildDataSummarySearchRequestBuilder() { .setTrackTotalHits(true); } - public Set getCategoricalFields() { - Set categoricalFields = new HashSet<>(); - for (ExtractedField extractedField : context.extractedFields.getAllFields()) { - String fieldName = extractedField.getName(); - if (Types.categorical().containsAll(extractedField.getTypes())) { - categoricalFields.add(fieldName); - } - } - return categoricalFields; + public Set getCategoricalFields(DataFrameAnalysis analysis) { + return context.extractedFields.getAllFields().stream() + .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()).containsAll(extractedField.getTypes())) + .map(ExtractedField::getName) + .collect(Collectors.toUnmodifiableSet()); } public static class DataSummary { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index dc173f4d8ffcb..000ca23a9a8f2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -264,7 +264,14 @@ private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFi List adjusted = new ArrayList<>(extractedFields.getAllFields().size()); for (ExtractedField field : extractedFields.getAllFields()) { if (isBoolean(field.getTypes())) { - adjusted.add(new BooleanAsInteger(field)); + if (config.getAnalysis().getAllowedCategoricalTypes(field.getAlias()).contains(BooleanFieldMapper.CONTENT_TYPE)) { + // We convert boolean field to string if it is a categorical dependent variable + adjusted.add(new BooleanMapper<>(field, Boolean.TRUE.toString(), Boolean.FALSE.toString())); + } else { + // We convert boolean fields to integers with values 0, 1 as this is the preferred + // way to consume such features in the analytics process. + adjusted.add(new BooleanMapper<>(field, 1, 0)); + } } else { adjusted.add(field); } @@ -277,21 +284,24 @@ private static boolean isBoolean(Set types) { } /** - * We convert boolean fields to integers with values 0, 1 as this is the preferred - * way to consume such features in the analytics process. + * {@link BooleanMapper} makes boolean field behave as a field of different type. */ - private static class BooleanAsInteger extends ExtractedField { + private static final class BooleanMapper extends ExtractedField { - protected BooleanAsInteger(ExtractedField field) { + private final T trueValue; + private final T falseValue; + + BooleanMapper(ExtractedField field, T trueValue, T falseValue) { super(field.getAlias(), field.getName(), Collections.singleton(BooleanFieldMapper.CONTENT_TYPE), ExtractionMethod.DOC_VALUE); + this.trueValue = trueValue; + this.falseValue = falseValue; } @Override public Object[] value(SearchHit hit) { DocumentField keyValue = hit.field(name); if (keyValue != null) { - List values = keyValue.getValues().stream().map(v -> Boolean.TRUE.equals(v) ? 1 : 0).collect(Collectors.toList()); - return values.toArray(new Object[0]); + return keyValue.getValues().stream().map(v -> Boolean.TRUE.equals(v) ? trueValue : falseValue).toArray(); } return new Object[0]; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index f361238cf4cb7..485b9d9d60501 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -367,7 +367,7 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); - Set categoricalFields = dataExtractor.getCategoricalFields(); + Set categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis()); AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(config.getId(), dataSummary.rows, dataSummary.cols, config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), categoricalFields, config.getAnalysis()); return processConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java index a4c0cb2025ef7..c41e3038725f8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java @@ -57,7 +57,7 @@ private MemoryUsageEstimationResult runJob(String jobId, DataFrameDataExtractorFactory dataExtractorFactory) { DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); - Set categoricalFields = dataExtractor.getCategoricalFields(); + Set categoricalFields = dataExtractor.getCategoricalFields(config.getAnalysis()); if (dataSummary.rows == 0) { return new MemoryUsageEstimationResult(ByteSizeValue.ZERO, ByteSizeValue.ZERO); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index ed00512a81c5d..9bbd3f207cc11 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -24,6 +24,9 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; @@ -41,7 +44,9 @@ import java.util.Queue; import java.util.stream.Collectors; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; @@ -384,6 +389,36 @@ public void testMissingValues_GivenShouldInclude() throws IOException { assertThat(dataExtractor.hasNext(), is(false)); } + public void testGetCategoricalFields() { + extractedFields = new ExtractedFields(Arrays.asList( + ExtractedField.newField("field_boolean", Collections.singleton("boolean"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_float", Collections.singleton("float"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_double", Collections.singleton("double"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_byte", Collections.singleton("byte"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_short", Collections.singleton("short"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_integer", Collections.singleton("integer"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_long", Collections.singleton("long"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_keyword", Collections.singleton("keyword"), ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_text", Collections.singleton("text"), ExtractedField.ExtractionMethod.SOURCE))); + TestExtractor dataExtractor = createExtractor(true, true); + + assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty()); + + assertThat(dataExtractor.getCategoricalFields(new Regression("field_double")), containsInAnyOrder("field_keyword", "field_text")); + assertThat(dataExtractor.getCategoricalFields(new Regression("field_long")), containsInAnyOrder("field_keyword", "field_text")); + assertThat(dataExtractor.getCategoricalFields(new Regression("field_boolean")), containsInAnyOrder("field_keyword", "field_text")); + + assertThat( + dataExtractor.getCategoricalFields(new Classification("field_keyword")), + containsInAnyOrder("field_keyword", "field_text")); + assertThat( + dataExtractor.getCategoricalFields(new Classification("field_long")), + containsInAnyOrder("field_keyword", "field_text", "field_long")); + assertThat( + dataExtractor.getCategoricalFields(new Classification("field_boolean")), + containsInAnyOrder("field_keyword", "field_text", "field_boolean")); + } + private TestExtractor createExtractor(boolean includeSource, boolean includeRowsWithMissingValues) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, includeRowsWithMissingValues); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index c31b82009f73d..c6bd263031e7d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; @@ -213,6 +214,22 @@ public void testDetect_GivenRegressionAndRequiredFieldHasInvalidType() { "expected types are [byte, double, float, half_float, integer, long, scaled_float, short]")); } + public void testDetect_GivenClassificationAndRequiredFieldHasInvalidType() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .addAggregatableField("foo", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildClassificationConfig("some_float"), RESULTS_FIELD, false, 100, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("invalid types [float] for required field [some_float]; " + + "expected types are [boolean, byte, integer, ip, keyword, long, short, text]")); + } + public void testDetect_GivenIgnoredField() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("_id", "float").build(); @@ -467,7 +484,7 @@ public void testDetect_GivenMoreFieldsThanDocValuesLimit() { contains(equalTo(ExtractedField.ExtractionMethod.SOURCE))); } - public void testDetect_GivenBooleanField() { + public void testDetect_GivenBooleanField_BooleanMappedAsInteger() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField("some_boolean", "boolean") .build(); @@ -483,19 +500,38 @@ public void testDetect_GivenBooleanField() { assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)); SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build(); - Object[] values = booleanField.value(hit); - assertThat(values.length, equalTo(1)); - assertThat(values[0], equalTo(1)); + assertThat(booleanField.value(hit), arrayContaining(1)); hit = new SearchHitBuilder(42).addField("some_boolean", false).build(); - values = booleanField.value(hit); - assertThat(values.length, equalTo(1)); - assertThat(values[0], equalTo(0)); + assertThat(booleanField.value(hit), arrayContaining(0)); hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build(); - values = booleanField.value(hit); - assertThat(values.length, equalTo(3)); - assertThat(values, arrayContaining(0, 1, 0)); + assertThat(booleanField.value(hit), arrayContaining(0, 1, 0)); + } + + public void testDetect_GivenBooleanField_BooleanMappedAsString() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("some_boolean", "boolean") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildClassificationConfig("some_boolean"), RESULTS_FIELD, false, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + List allFields = extractedFields.getAllFields(); + assertThat(allFields.size(), equalTo(1)); + ExtractedField booleanField = allFields.get(0); + assertThat(booleanField.getTypes(), contains("boolean")); + assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)); + + SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build(); + assertThat(booleanField.value(hit), arrayContaining("true")); + + hit = new SearchHitBuilder(42).addField("some_boolean", false).build(); + assertThat(booleanField.value(hit), arrayContaining("false")); + + hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build(); + assertThat(booleanField.value(hit), arrayContaining("false", "true", "false")); } private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() { @@ -526,6 +562,15 @@ private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVa .build(); } + private static DataFrameAnalyticsConfig buildClassificationConfig(String dependentVariable) { + return new DataFrameAnalyticsConfig.Builder() + .setId("foo") + .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) + .setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD)) + .setAnalysis(new Classification(dependentVariable)) + .build(); + } + private static class MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>();