Skip to content

Allow integer types for classification's dependent variable #47902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
Expand Down Expand Up @@ -58,6 +61,11 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}

/**
 * Types accepted for the dependent variable: the union of the categorical, discrete numerical
 * and boolean type sets exposed by {@link Types}, collected into an unmodifiable set.
 */
private static final Set<String> ALLOWED_DEPENDENT_VARIABLE_TYPES =
    Stream.concat(
            Stream.concat(Types.categorical().stream(), Types.discreteNumerical().stream()),
            Types.bool().stream())
        .collect(Collectors.toUnmodifiableSet());

private final String dependentVariable;
private final BoostedTreeParams boostedTreeParams;
private final String predictionFieldName;
Expand Down Expand Up @@ -147,9 +155,17 @@ public boolean supportsCategoricalFields() {
return true;
}

/**
 * Returns the types treated as categorical for {@code fieldName}. The dependent variable accepts
 * the wider {@code ALLOWED_DEPENDENT_VARIABLE_TYPES} set; every other field is limited to
 * {@link Types#categorical()}.
 */
@Override
public Set<String> getAllowedCategoricalTypes(String fieldName) {
    return dependentVariable.equals(fieldName) ? ALLOWED_DEPENDENT_VARIABLE_TYPES : Types.categorical();
}

@Override
public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
return Collections.singletonList(new RequiredField(dependentVariable, ALLOWED_DEPENDENT_VARIABLE_TYPES));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import java.util.List;
import java.util.Map;
import java.util.Set;

public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {

Expand All @@ -23,6 +24,12 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
*/
boolean supportsCategoricalFields();

/**
 * Returns the field types that this analysis treats as categorical for a particular field.
 * The answer may depend on the field, so callers pass the name of the field being checked.
 *
 * @param fieldName field for which the allowed categorical types should be returned
 * @return The types treated as categorical for the given field
 */
Set<String> getAllowedCategoricalTypes(String fieldName);

/**
* @return The names and types of the fields that analyzed documents must have for the analysis to operate
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class OutlierDetection implements DataFrameAnalysis {

Expand Down Expand Up @@ -213,6 +214,11 @@ public boolean supportsCategoricalFields() {
return false;
}

/**
 * Always returns an empty set; the field name is not consulted. This matches
 * {@code supportsCategoricalFields()} returning {@code false} for this analysis.
 */
@Override
public Set<String> getAllowedCategoricalTypes(String fieldName) {
return Collections.emptySet();
}

@Override
public List<RequiredField> getRequiredFields() {
return Collections.emptyList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
Expand Down Expand Up @@ -134,6 +135,11 @@ public boolean supportsCategoricalFields() {
return true;
}

/**
 * Returns {@link Types#categorical()} for every field; the field name is not consulted.
 */
@Override
public Set<String> getAllowedCategoricalTypes(String fieldName) {
return Types.categorical();
}

@Override
public List<RequiredField> getRequiredFields() {
return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;

import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.IpFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper.NumberType;
import org.elasticsearch.index.mapper.TextFieldMapper;

import java.util.Collections;
import java.util.Set;
Expand All @@ -20,16 +24,19 @@ public final class Types {
private Types() {}

private static final Set<String> CATEGORICAL_TYPES =
Collections.unmodifiableSet(
Stream.of("text", "keyword", "ip")
.collect(Collectors.toSet()));
Stream.of(TextFieldMapper.CONTENT_TYPE, KeywordFieldMapper.CONTENT_TYPE, IpFieldMapper.CONTENT_TYPE)
.collect(Collectors.toUnmodifiableSet());

private static final Set<String> NUMERICAL_TYPES =
Collections.unmodifiableSet(
Stream.concat(
Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName),
Stream.of("scaled_float"))
.collect(Collectors.toSet()));
Stream.concat(Stream.of(NumberType.values()).map(NumberType::typeName), Stream.of("scaled_float"))
.collect(Collectors.toUnmodifiableSet());

/**
 * Type names of {@code NumberType.BYTE}, {@code SHORT}, {@code INTEGER} and {@code LONG},
 * collected into an unmodifiable set.
 */
private static final Set<String> DISCRETE_NUMERICAL_TYPES =
    Stream.of(
            NumberType.BYTE.typeName(),
            NumberType.SHORT.typeName(),
            NumberType.INTEGER.typeName(),
            NumberType.LONG.typeName())
        .collect(Collectors.toUnmodifiableSet());

// Immutable single-element set holding the content type constant of BooleanFieldMapper.
private static final Set<String> BOOL_TYPES = Collections.singleton(BooleanFieldMapper.CONTENT_TYPE);

public static Set<String> categorical() {
return CATEGORICAL_TYPES;
Expand All @@ -38,4 +45,12 @@ public static Set<String> categorical() {
/**
 * @return the unmodifiable set of numerical type names: the type name of every
 *         {@code NumberType} value plus {@code "scaled_float"}
 */
public static Set<String> numerical() {
return NUMERICAL_TYPES;
}

/**
 * @return the unmodifiable set of discrete numerical type names, i.e. the type names of
 *         {@code NumberType.BYTE}, {@code SHORT}, {@code INTEGER} and {@code LONG}
 */
public static Set<String> discreteNumerical() {
return DISCRETE_NUMERICAL_TYPES;
}

/**
 * @return the immutable single-element set containing {@code BooleanFieldMapper.CONTENT_TYPE}
 */
public static Set<String> bool() {
return BOOL_TYPES;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
Expand All @@ -24,6 +25,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
Expand All @@ -34,12 +36,17 @@
import static org.hamcrest.Matchers.in;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.startsWith;

public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

private static final String BOOLEAN_FIELD = "boolean-field";
private static final String NUMERICAL_FIELD = "numerical-field";
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
private static final String KEYWORD_FIELD = "keyword-field";
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));

private String jobId;
Expand All @@ -53,7 +60,7 @@ public void cleanup() {

public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("classification_single_numeric_feature_and_mixed_data_set");
indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
Expand Down Expand Up @@ -91,7 +98,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws

public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_100");
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
indexData(sourceIndex, 300, 0, KEYWORD_FIELD);

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
Expand Down Expand Up @@ -126,17 +133,19 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
"Finished analysis");
}

public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_50");
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
String jobId, String dependentVariable, List<T> dependentVariableValues, Function<String, T> parser) throws Exception {
initialize(jobId);
indexData(sourceIndex, 300, 0, dependentVariable);

int numTopClasses = 2;
DataFrameAnalyticsConfig config =
buildAnalytics(
jobId,
sourceIndex,
destIndex,
null,
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0));
registerAnalytics(config);
putAnalytics(config);

Expand All @@ -151,8 +160,11 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
String predictedClassField = dependentVariable + "_prediction";
assertThat(resultsObject.containsKey(predictedClassField), is(true));
T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField));
assertThat(predictedClassValue, is(in(dependentVariableValues)));
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues, parser);

assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results
Expand All @@ -161,11 +173,9 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
} else {
nonTrainingRowsCount++;
}
assertThat(resultsObject.containsKey("top_classes"), is(false));
}
assertThat(trainingRowsCount, greaterThan(0));
assertThat(nonTrainingRowsCount, greaterThan(0));

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(jobId,
Expand All @@ -178,9 +188,32 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
"Finished analysis");
}

/**
 * Runs the shared training-percent-is-fifty scenario with the keyword field as the dependent
 * variable. Predicted class names are compared as strings against {@code KEYWORD_FIELD_VALUES}.
 */
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
}

/**
 * Runs the shared training-percent-is-fifty scenario with the integer-mapped field as the
 * dependent variable. Predicted class names are parsed with {@code Integer::valueOf} before being
 * compared against {@code DISCRETE_NUMERICAL_FIELD_VALUES}.
 */
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, Integer::valueOf);
}

/**
 * Runs the shared training-percent-is-fifty scenario with the double-mapped field as the
 * dependent variable and expects it to be rejected with an error naming the offending type
 * and field.
 */
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception {
    String expectedMessagePrefix = "invalid types [double] for required field [numerical-field];";
    ElasticsearchStatusException exception = expectThrows(
        ElasticsearchStatusException.class,
        () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
            "classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, Double::valueOf));
    assertThat(exception.getMessage(), startsWith(expectedMessagePrefix));
}

/**
 * Runs the shared training-percent-is-fifty scenario with the boolean field as the dependent
 * variable. Predicted class names are parsed with {@code Boolean::valueOf} before being compared
 * against {@code BOOLEAN_FIELD_VALUES}.
 */
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception {
testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf);
}

public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
initialize("classification_top_classes_requested");
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);

int numTopClasses = 2;
DataFrameAnalyticsConfig config =
Expand All @@ -206,7 +239,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse

assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertTopClasses(resultsObject, numTopClasses);
assertTopClasses(resultsObject, numTopClasses, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
}

assertProgress(jobId, 100, 100, 100, 100);
Expand All @@ -223,7 +256,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse

public void testDependentVariableCardinalityTooHighError() {
initialize("cardinality_too_high");
indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
// Index one more document with a class different from the two already used.
// The request is awaited and refreshed immediately so the third class is guaranteed to be
// searchable before the rest of the test runs; a fire-and-forget execute() would race with it.
client().execute(
        IndexAction.INSTANCE,
        new IndexRequest(sourceIndex)
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
            .source(KEYWORD_FIELD, "fox"))
    .actionGet();

DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
Expand All @@ -240,28 +275,42 @@ private void initialize(String jobId) {
this.destIndex = sourceIndex + "_results";
}

private static void indexData(String sourceIndex,
int numTrainingRows, int numNonTrainingRows,
List<Double> numericalFieldValues, List<String> keywordFieldValues) {
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
.addMapping("_doc",
BOOLEAN_FIELD, "type=boolean",
NUMERICAL_FIELD, "type=double",
DISCRETE_NUMERICAL_FIELD, "type=integer",
KEYWORD_FIELD, "type=keyword")
.get();

BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < numTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());

IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
List<Object> source = List.of(
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());

IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FIELD, numericalValue);
List<Object> source = new ArrayList<>();
if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
}
if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
}
if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(
List.of(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size())));
}
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
Expand Down Expand Up @@ -289,7 +338,12 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
return resultsObject;
}

private static void assertTopClasses(Map<String, Object> resultsObject, int numTopClasses) {
private static <T> void assertTopClasses(
Map<String, Object> resultsObject,
int numTopClasses,
String dependentVariable,
List<T> dependentVariableValues,
Function<String, T> parser) {
assertThat(resultsObject.containsKey("top_classes"), is(true));
@SuppressWarnings("unchecked")
List<Map<String, Object>> topClasses = (List<Map<String, Object>>) resultsObject.get("top_classes");
Expand All @@ -302,9 +356,9 @@ private static void assertTopClasses(Map<String, Object> resultsObject, int numT
classProbabilities.add((Double) topClass.get("class_probability"));
}
// Assert that all the predicted class names come from the set of dependent variable values.
classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
// Assert that the first class listed in top classes is the same as the predicted class.
assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
// Assert that all the class probabilities lie within [0, 1] interval.
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
// Assert that the top classes are listed in the order of decreasing probabilities.
Expand Down
Loading