[7.6][ML] Validate classification dependent_variable cardinality is a… #51310

Merged

@@ -245,9 +245,9 @@ public List<RequiredField> getRequiredFields() {
     }
 
     @Override
-    public Map<String, Long> getFieldCardinalityLimits() {
+    public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
         // This restriction is due to the fact that currently the C++ backend only supports binomial classification.
-        return Collections.singletonMap(dependentVariable, 2L);
+        return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
     }
 
     @SuppressWarnings("unchecked")

@@ -37,9 +37,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
     List<RequiredField> getRequiredFields();
 
     /**
-     * @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
+     * @return {@link List} containing cardinality constraints for the selected (analysis-specific) fields
      */
-    Map<String, Long> getFieldCardinalityLimits();
+    List<FieldCardinalityConstraint> getFieldCardinalityConstraints();
 
     /**
      * Returns fields for which the mappings should be either predefined or copied from source index to destination index.

@@ -0,0 +1,55 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.analyses;
+
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.util.Objects;
+
+/**
+ * Allows checking a field's cardinality against given lower and upper bounds
+ */
+public class FieldCardinalityConstraint {
+
+    private final String field;
+    private final long lowerBound;
+    private final long upperBound;
+
+    public static FieldCardinalityConstraint between(String field, long lowerBound, long upperBound) {
+        return new FieldCardinalityConstraint(field, lowerBound, upperBound);
+    }
+
+    private FieldCardinalityConstraint(String field, long lowerBound, long upperBound) {
+        this.field = Objects.requireNonNull(field);
+        this.lowerBound = lowerBound;
+        this.upperBound = upperBound;
+    }
+
+    public String getField() {
+        return field;
+    }
+
+    public long getLowerBound() {
+        return lowerBound;
+    }
+
+    public long getUpperBound() {
+        return upperBound;
+    }
+
+    public void check(long fieldCardinality) {
+        if (fieldCardinality < lowerBound) {
+            throw ExceptionsHelper.badRequestException(
+                "Field [{}] must have at least [{}] distinct values but there were [{}]",
+                field, lowerBound, fieldCardinality);
+        }
+        if (fieldCardinality > upperBound) {
+            throw ExceptionsHelper.badRequestException(
+                "Field [{}] must have at most [{}] distinct values but there were at least [{}]",
+                field, upperBound, fieldCardinality);
+        }
+    }
+}
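
As a usage note, here is a minimal sketch of the new constraint in action; the wrapper class `ConstraintSketch` and the field name `my_label` are illustrative, not part of this change:

```java
import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;

class ConstraintSketch {
    public static void main(String[] args) {
        // Classification expresses its binomial restriction as a [2, 2] constraint
        // on the dependent variable, i.e. exactly two distinct values.
        FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("my_label", 2, 2);

        constraint.check(2); // within bounds: returns normally
        constraint.check(3); // throws a 400-status ElasticsearchStatusException
    }
}
```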

@@ -225,8 +225,8 @@ public List<RequiredField> getRequiredFields() {
     }
 
     @Override
-    public Map<String, Long> getFieldCardinalityLimits() {
-        return Collections.emptyMap();
+    public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
+        return Collections.emptyList();
     }
 
     @Override

@@ -182,8 +182,8 @@ public List<RequiredField> getRequiredFields() {
     }
 
     @Override
-    public Map<String, Long> getFieldCardinalityLimits() {
-        return Collections.emptyMap();
+    public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
+        return Collections.emptyList();
    }
 
     @Override

@@ -22,6 +22,7 @@
 import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
@@ -169,7 +170,13 @@ public void testRequiredFieldsIsNonEmpty() {
     }
 
     public void testFieldCardinalityLimitsIsNonEmpty() {
-        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
+        Classification classification = createTestInstance();
+        List<FieldCardinalityConstraint> constraints = classification.getFieldCardinalityConstraints();
+
+        assertThat(constraints.size(), equalTo(1));
+        assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
+        assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
+        assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
     }
 
     public void testGetExplicitlyMappedFields() {

@@ -0,0 +1,40 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.analyses;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class FieldCardinalityConstraintTests extends ESTestCase {
+
+    public void testBetween_GivenWithinLimits() {
+        FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
+
+        constraint.check(3);
+        constraint.check(4);
+        constraint.check(5);
+        constraint.check(6);
+    }
+
+    public void testBetween_GivenLessThanLowerBound() {
+        FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
+
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> constraint.check(2L));
+        assertThat(e.getMessage(), equalTo("Field [foo] must have at least [3] distinct values but there were [2]"));
+        assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
+    }
+
+    public void testBetween_GivenGreaterThanUpperBound() {
+        FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
+
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> constraint.check(7L));
+        assertThat(e.getMessage(), equalTo("Field [foo] must have at most [6] distinct values but there were at least [7]"));
+        assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
+    }
+}

@@ -89,7 +89,7 @@ public void testRequiredFieldsIsEmpty() {
     }
 
     public void testFieldCardinalityLimitsIsEmpty() {
-        assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
+        assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
     }
 
     public void testGetExplicitlyMappedFields() {

@@ -19,7 +19,6 @@
 import java.util.Collections;
 
 import static org.hamcrest.Matchers.allOf;
-import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
@@ -107,7 +106,7 @@ public void testRequiredFieldsIsNonEmpty() {
     }
 
     public void testFieldCardinalityLimitsIsEmpty() {
-        assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
+        assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
     }
 
     public void testGetExplicitlyMappedFields() {

@@ -43,17 +43,23 @@ public void testSourceQueryIsApplied() throws IOException {
         String sourceIndex = "test-source-query-is-applied";
 
         client().admin().indices().prepareCreate(sourceIndex)
-            .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical", "type=keyword")
+            .addMapping("_doc",
+                "numeric_1", "type=double",
+                "numeric_2", "type=float",
+                "categorical", "type=keyword",
+                "filtered_field", "type=keyword")
             .get();
 
         BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
         bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
 
         for (int i = 0; i < 30; i++) {
             IndexRequest indexRequest = new IndexRequest(sourceIndex);
-
-            // We insert one odd value out of 5 for one feature
-            indexRequest.source("numeric_1", 1.0, "numeric_2", 2.0, "categorical", i == 0 ? "only-one" : "normal");
+            indexRequest.source(
+                "numeric_1", 1.0,
+                "numeric_2", 2.0,
+                "categorical", i % 2 == 0 ? "class_1" : "class_2",
+                "filtered_field", i < 2 ? "bingo" : "rest"); // We tag bingo on the first two docs to ensure we have 2 classes
             bulkRequestBuilder.add(indexRequest);
         }
         BulkResponse bulkResponse = bulkRequestBuilder.get();
@@ -66,7 +72,7 @@ public void testSourceQueryIsApplied() throws IOException {
         DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
             .setId(id)
             .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
-                QueryProvider.fromParsedQuery(QueryBuilders.termQuery("categorical", "only-one")),
+                QueryProvider.fromParsedQuery(QueryBuilders.termQuery("filtered_field", "bingo")),
                 null))
             .setAnalysis(new Classification("categorical"))
             .buildForExplain();

@@ -15,6 +15,9 @@
 import org.elasticsearch.action.admin.indices.get.GetIndexAction;
 import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
 import org.elasticsearch.action.admin.indices.get.GetIndexResponse;
+import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
+import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
+import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.cluster.ClusterState;
@@ -42,6 +45,7 @@
 import java.util.function.Supplier;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeWithHeadersAsync;
 
 public class DataFrameAnalyticsManager {
 
@@ -158,7 +162,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
 
         // Reindexing is complete; start analytics
        ActionListener<BulkByScrollResponse> reindexCompletedListener = ActionListener.wrap(
-            refreshResponse -> {
+            reindexResponse -> {
                 if (task.isStopping()) {
                     LOGGER.debug("[{}] Stopping before starting analytics process", config.getId());
                     return;
@@ -177,6 +181,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
         ActionListener<CreateIndexResponse> copyIndexCreatedListener = ActionListener.wrap(
             createIndexResponse -> {
                 ReindexRequest reindexRequest = new ReindexRequest();
+                reindexRequest.setRefresh(true);
                 reindexRequest.setSourceIndices(config.getSource().getIndex());
                 reindexRequest.setSourceQuery(config.getSource().getParsedQuery());
                 reindexRequest.getSearchRequest().source().fetchSource(config.getSource().getSourceFiltering());
@@ -224,9 +229,6 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
     }
 
     private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) {
-        // Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
-        task.setReindexingFinished();
-
         // Update state to ANALYZING and start process
         ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
             dataExtractorFactory -> {
@@ -246,10 +248,23 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi
             error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
         );
 
-        // TODO This could fail with errors. In that case we get stuck with the copied index.
-        // We could delete the index in case of failure or we could try building the factory before reindexing
-        // to catch the error early on.
-        DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
+        ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
+            refreshResponse -> {
+                // Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
+                task.setReindexingFinished();
+
+                // TODO This could fail with errors. In that case we get stuck with the copied index.
+                // We could delete the index in case of failure or we could try building the factory before reindexing
+                // to catch the error early on.
+                DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
+            },
+            dataExtractorFactoryListener::onFailure
+        );
+
+        // First we need to refresh the dest index to ensure data is searchable in case the job
+        // was stopped after reindexing was complete but before the index was refreshed.
+        executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, RefreshAction.INSTANCE,
+            new RefreshRequest(config.getDest().getIndex()), refreshListener);
     }
 
     public void stop(DataFrameAnalyticsTask task) {
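
A side note on why the explicit refresh matters: Elasticsearch search is near-real-time, so documents written by the reindex step only become visible to searches after a refresh. Below is a minimal sketch of the pattern, assuming a plain `Client` and hypothetical `proceed`/`onFailure` callbacks; the actual change instead routes the request through `executeWithHeadersAsync` so it runs with the job's stored security headers:

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
import org.elasticsearch.client.Client;

import java.util.function.Consumer;

class RefreshBeforeSearchSketch {
    // Refresh destIndex first, then run the next step; mirrors the ordering this PR
    // enforces so a recovered task never searches a partially visible dest index.
    static void refreshThenProceed(Client client, String destIndex,
                                   Runnable proceed, Consumer<Exception> onFailure) {
        ActionListener<RefreshResponse> listener = ActionListener.wrap(
            refreshResponse -> proceed.run(),   // data is now searchable
            e -> onFailure.accept(e));
        client.admin().indices().refresh(new RefreshRequest(destIndex), listener);
    }
}
```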

@@ -19,6 +19,7 @@
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
 import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
@@ -284,15 +285,8 @@ private void checkRequiredFields(Set<String> fields) {
     }
 
     private void checkFieldsWithCardinalityLimit() {
-        for (Map.Entry<String, Long> entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) {
-            String fieldName = entry.getKey();
-            long limit = entry.getValue();
-            long cardinality = fieldCardinalities.get(fieldName);
-            if (cardinality > limit) {
-                throw ExceptionsHelper.badRequestException(
-                    "Field [{}] must have at most [{}] distinct values but there were at least [{}]",
-                    fieldName, limit, cardinality);
-            }
+        for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
+            constraint.check(fieldCardinalities.get(constraint.getField()));
         }
     }
 