Skip to content

Commit 749123b

Browse files
[7.6][ML] Validate classification dependent_variable cardinality is at lea… (#51232) (#51310)
Data frame analytics classification currently only supports 2 classes for the dependent variable. We were checking that the field's cardinality is not higher than 2 but we should also check it is not less than that as otherwise the process fails. Backport of #51232
1 parent 7a1ed2b commit 749123b

File tree

16 files changed

+197
-70
lines changed

16 files changed

+197
-70
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ public List<RequiredField> getRequiredFields() {
245245
}
246246

247247
@Override
248-
public Map<String, Long> getFieldCardinalityLimits() {
248+
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
249249
// This restriction is due to the fact that currently the C++ backend only supports binomial classification.
250-
return Collections.singletonMap(dependentVariable, 2L);
250+
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
251251
}
252252

253253
@SuppressWarnings("unchecked")

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
3737
List<RequiredField> getRequiredFields();
3838

3939
/**
40-
* @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
40+
* @return {@link List} containing cardinality constraints for the selected (analysis-specific) fields
4141
*/
42-
Map<String, Long> getFieldCardinalityLimits();
42+
List<FieldCardinalityConstraint> getFieldCardinalityConstraints();
4343

4444
/**
4545
* Returns fields for which the mappings should be either predefined or copied from source index to destination index.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
7+
8+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
9+
10+
import java.util.Objects;
11+
12+
/**
13+
* Allows checking a field's cardinality against given lower and upper bounds
14+
*/
15+
public class FieldCardinalityConstraint {
16+
17+
private final String field;
18+
private final long lowerBound;
19+
private final long upperBound;
20+
21+
public static FieldCardinalityConstraint between(String field, long lowerBound, long upperBound) {
22+
return new FieldCardinalityConstraint(field, lowerBound, upperBound);
23+
}
24+
25+
private FieldCardinalityConstraint(String field, long lowerBound, long upperBound) {
26+
this.field = Objects.requireNonNull(field);
27+
this.lowerBound = lowerBound;
28+
this.upperBound = upperBound;
29+
}
30+
31+
public String getField() {
32+
return field;
33+
}
34+
35+
public long getLowerBound() {
36+
return lowerBound;
37+
}
38+
39+
public long getUpperBound() {
40+
return upperBound;
41+
}
42+
43+
public void check(long fieldCardinality) {
44+
if (fieldCardinality < lowerBound) {
45+
throw ExceptionsHelper.badRequestException(
46+
"Field [{}] must have at least [{}] distinct values but there were [{}]",
47+
field, lowerBound, fieldCardinality);
48+
}
49+
if (fieldCardinality > upperBound) {
50+
throw ExceptionsHelper.badRequestException(
51+
"Field [{}] must have at most [{}] distinct values but there were at least [{}]",
52+
field, upperBound, fieldCardinality);
53+
}
54+
}
55+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ public List<RequiredField> getRequiredFields() {
225225
}
226226

227227
@Override
228-
public Map<String, Long> getFieldCardinalityLimits() {
229-
return Collections.emptyMap();
228+
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
229+
return Collections.emptyList();
230230
}
231231

232232
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ public List<RequiredField> getRequiredFields() {
182182
}
183183

184184
@Override
185-
public Map<String, Long> getFieldCardinalityLimits() {
186-
return Collections.emptyMap();
185+
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
186+
return Collections.emptyList();
187187
}
188188

189189
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.io.IOException;
2323
import java.util.Collections;
2424
import java.util.HashMap;
25+
import java.util.List;
2526
import java.util.Map;
2627
import java.util.Set;
2728

@@ -169,7 +170,13 @@ public void testRequiredFieldsIsNonEmpty() {
169170
}
170171

171172
public void testFieldCardinalityLimitsIsNonEmpty() {
172-
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
173+
Classification classification = createTestInstance();
174+
List<FieldCardinalityConstraint> constraints = classification.getFieldCardinalityConstraints();
175+
176+
assertThat(constraints.size(), equalTo(1));
177+
assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
178+
assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
179+
assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
173180
}
174181

175182
public void testGetExplicitlyMappedFields() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
7+
8+
import org.elasticsearch.ElasticsearchStatusException;
9+
import org.elasticsearch.rest.RestStatus;
10+
import org.elasticsearch.test.ESTestCase;
11+
12+
import static org.hamcrest.Matchers.equalTo;
13+
14+
public class FieldCardinalityConstraintTests extends ESTestCase {
15+
16+
public void testBetween_GivenWithinLimits() {
17+
FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
18+
19+
constraint.check(3);
20+
constraint.check(4);
21+
constraint.check(5);
22+
constraint.check(6);
23+
}
24+
25+
public void testBetween_GivenLessThanLowerBound() {
26+
FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
27+
28+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> constraint.check(2L));
29+
assertThat(e.getMessage(), equalTo("Field [foo] must have at least [3] distinct values but there were [2]"));
30+
assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
31+
}
32+
33+
public void testBetween_GivenGreaterThanUpperBound() {
34+
FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
35+
36+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> constraint.check(7L));
37+
assertThat(e.getMessage(), equalTo("Field [foo] must have at most [6] distinct values but there were at least [7]"));
38+
assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
39+
}
40+
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public void testRequiredFieldsIsEmpty() {
8989
}
9090

9191
public void testFieldCardinalityLimitsIsEmpty() {
92-
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
92+
assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
9393
}
9494

9595
public void testGetExplicitlyMappedFields() {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.util.Collections;
2020

2121
import static org.hamcrest.Matchers.allOf;
22-
import static org.hamcrest.Matchers.anEmptyMap;
2322
import static org.hamcrest.Matchers.containsString;
2423
import static org.hamcrest.Matchers.empty;
2524
import static org.hamcrest.Matchers.equalTo;
@@ -107,7 +106,7 @@ public void testRequiredFieldsIsNonEmpty() {
107106
}
108107

109108
public void testFieldCardinalityLimitsIsEmpty() {
110-
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
109+
assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
111110
}
112111

113112
public void testGetExplicitlyMappedFields() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,23 @@ public void testSourceQueryIsApplied() throws IOException {
4343
String sourceIndex = "test-source-query-is-applied";
4444

4545
client().admin().indices().prepareCreate(sourceIndex)
46-
.addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical", "type=keyword")
46+
.addMapping("_doc",
47+
"numeric_1", "type=double",
48+
"numeric_2", "type=float",
49+
"categorical", "type=keyword",
50+
"filtered_field", "type=keyword")
4751
.get();
4852

4953
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
5054
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
5155

5256
for (int i = 0; i < 30; i++) {
5357
IndexRequest indexRequest = new IndexRequest(sourceIndex);
54-
55-
// We insert one odd value out of 5 for one feature
56-
indexRequest.source("numeric_1", 1.0, "numeric_2", 2.0, "categorical", i == 0 ? "only-one" : "normal");
58+
indexRequest.source(
59+
"numeric_1", 1.0,
60+
"numeric_2", 2.0,
61+
"categorical", i % 2 == 0 ? "class_1" : "class_2",
62+
"filtered_field", i < 2 ? "bingo" : "rest"); // We tag bingo on the first two docs to ensure we have 2 classes
5763
bulkRequestBuilder.add(indexRequest);
5864
}
5965
BulkResponse bulkResponse = bulkRequestBuilder.get();
@@ -66,7 +72,7 @@ public void testSourceQueryIsApplied() throws IOException {
6672
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
6773
.setId(id)
6874
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
69-
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("categorical", "only-one")),
75+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("filtered_field", "bingo")),
7076
null))
7177
.setAnalysis(new Classification("categorical"))
7278
.buildForExplain();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
1616
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
1717
import org.elasticsearch.action.admin.indices.get.GetIndexResponse;
18+
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
19+
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
20+
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
1821
import org.elasticsearch.action.support.ContextPreservingActionListener;
1922
import org.elasticsearch.client.node.NodeClient;
2023
import org.elasticsearch.cluster.ClusterState;
@@ -42,6 +45,7 @@
4245
import java.util.function.Supplier;
4346

4447
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
48+
import static org.elasticsearch.xpack.core.ClientHelper.executeWithHeadersAsync;
4549

4650
public class DataFrameAnalyticsManager {
4751

@@ -158,7 +162,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
158162

159163
// Reindexing is complete; start analytics
160164
ActionListener<BulkByScrollResponse> reindexCompletedListener = ActionListener.wrap(
161-
refreshResponse -> {
165+
reindexResponse -> {
162166
if (task.isStopping()) {
163167
LOGGER.debug("[{}] Stopping before starting analytics process", config.getId());
164168
return;
@@ -177,6 +181,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
177181
ActionListener<CreateIndexResponse> copyIndexCreatedListener = ActionListener.wrap(
178182
createIndexResponse -> {
179183
ReindexRequest reindexRequest = new ReindexRequest();
184+
reindexRequest.setRefresh(true);
180185
reindexRequest.setSourceIndices(config.getSource().getIndex());
181186
reindexRequest.setSourceQuery(config.getSource().getParsedQuery());
182187
reindexRequest.getSearchRequest().source().fetchSource(config.getSource().getSourceFiltering());
@@ -224,9 +229,6 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
224229
}
225230

226231
private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) {
227-
// Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
228-
task.setReindexingFinished();
229-
230232
// Update state to ANALYZING and start process
231233
ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
232234
dataExtractorFactory -> {
@@ -246,10 +248,23 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi
246248
error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
247249
);
248250

249-
// TODO This could fail with errors. In that case we get stuck with the copied index.
250-
// We could delete the index in case of failure or we could try building the factory before reindexing
251-
// to catch the error early on.
252-
DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
251+
ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
252+
refreshResponse -> {
253+
// Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
254+
task.setReindexingFinished();
255+
256+
// TODO This could fail with errors. In that case we get stuck with the copied index.
257+
// We could delete the index in case of failure or we could try building the factory before reindexing
258+
// to catch the error early on.
259+
DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
260+
},
261+
dataExtractorFactoryListener::onFailure
262+
);
263+
264+
// First we need to refresh the dest index to ensure data is searchable in case the job
265+
// was stopped after reindexing was complete but before the index was refreshed.
266+
executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, RefreshAction.INSTANCE,
267+
new RefreshRequest(config.getDest().getIndex()), refreshListener);
253268
}
254269

255270
public void stop(DataFrameAnalyticsTask task) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
2020
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
2121
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
22+
import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
2223
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
2324
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
2425
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
@@ -284,15 +285,8 @@ private void checkRequiredFields(Set<String> fields) {
284285
}
285286

286287
private void checkFieldsWithCardinalityLimit() {
287-
for (Map.Entry<String, Long> entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) {
288-
String fieldName = entry.getKey();
289-
long limit = entry.getValue();
290-
long cardinality = fieldCardinalities.get(fieldName);
291-
if (cardinality > limit) {
292-
throw ExceptionsHelper.badRequestException(
293-
"Field [{}] must have at most [{}] distinct values but there were at least [{}]",
294-
fieldName, limit, cardinality);
295-
}
288+
for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
289+
constraint.check(fieldCardinalities.get(constraint.getField()));
296290
}
297291
}
298292

0 commit comments

Comments
 (0)