Skip to content

Commit 7e0a3c4

Browse files
[ML] Extend classification to support multiple classes (#53539)
* [ML] Extend classification to support multiple classes Prepares classification analysis to support more than just two classes. It introduces a new parameter to the process config which dictates the `num_classes` to the process. It also changes the max classes limit to `30` provisionally. * We can't test cardinality is too high in the YML tests anymore * Extract max number of classes in a constant
1 parent eaa8ead commit 7e0a3c4

File tree

15 files changed

+305
-63
lines changed

15 files changed

+305
-63
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

+11-3
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,16 @@ public class Classification implements DataFrameAnalysis {
4646

4747
private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
4848

49+
private static final String NUM_CLASSES = "num_classes";
50+
4951
private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
5052
private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);
5153

54+
/**
55+
* The max number of classes classification supports
56+
*/
57+
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
58+
5259
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
5360
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
5461
NAME.getPreferredName(),
@@ -218,7 +225,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
218225
}
219226

220227
@Override
221-
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
228+
public Map<String, Object> getParams(FieldInfo fieldInfo) {
222229
Map<String, Object> params = new HashMap<>();
223230
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
224231
params.putAll(boostedTreeParams.getParams());
@@ -227,10 +234,11 @@ public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
227234
if (predictionFieldName != null) {
228235
params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
229236
}
230-
String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable));
237+
String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
231238
if (predictionFieldType != null) {
232239
params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
233240
}
241+
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
234242
return params;
235243
}
236244

@@ -272,7 +280,7 @@ public List<RequiredField> getRequiredFields() {
272280
@Override
273281
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
274282
// This restriction is due to the fact that currently the C++ backend only supports binomial classification.
275-
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
283+
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, MAX_DEPENDENT_VARIABLE_CARDINALITY));
276284
}
277285

278286
@SuppressWarnings("unchecked")

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

+26-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
77

8+
import org.elasticsearch.common.Nullable;
89
import org.elasticsearch.common.io.stream.NamedWriteable;
910
import org.elasticsearch.common.xcontent.ToXContentObject;
1011

@@ -16,9 +17,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
1617

1718
/**
1819
* @return The analysis parameters as a map
19-
* @param extractedFields map of (name, types) for all the extracted fields
20+
* @param fieldInfo Information about the fields like types and cardinalities
2021
*/
21-
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);
22+
Map<String, Object> getParams(FieldInfo fieldInfo);
2223

2324
/**
2425
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
@@ -64,4 +65,27 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
6465
* Returns the document id for the analysis state
6566
*/
6667
String getStateDocId(String jobId);
68+
69+
/**
70+
* Summarizes information about the fields that is necessary for analysis to generate
71+
* the parameters needed for the process configuration.
72+
*/
73+
interface FieldInfo {
74+
75+
/**
76+
* Returns the types for the given field or {@code null} if the field is unknown
77+
* @param field the field whose types to return
78+
* @return the types for the given field or {@code null} if the field is unknown
79+
*/
80+
@Nullable
81+
Set<String> getTypes(String field);
82+
83+
/**
84+
* Returns the cardinality of the given field or {@code null} if there is no cardinality for that field
85+
* @param field the field whose cardinality to get
86+
* @return the cardinality of the given field or {@code null} if there is no cardinality for that field
87+
*/
88+
@Nullable
89+
Long getCardinality(String field);
90+
}
6791
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ public int hashCode() {
192192
}
193193

194194
@Override
195-
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
195+
public Map<String, Object> getParams(FieldInfo fieldInfo) {
196196
Map<String, Object> params = new HashMap<>();
197197
if (nNeighbors != null) {
198198
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
155155
}
156156

157157
@Override
158-
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
158+
public Map<String, Object> getParams(FieldInfo fieldInfo) {
159159
Map<String, Object> params = new HashMap<>();
160160
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
161161
params.putAll(boostedTreeParams.getParams());

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

+38-9
Original file line numberDiff line numberDiff line change
@@ -187,38 +187,46 @@ public void testGetTrainingPercent() {
187187
}
188188

189189
public void testGetParams() {
190-
Map<String, Set<String>> extractedFields =
190+
DataFrameAnalysis.FieldInfo fieldInfo = new TestFieldInfo(
191191
Map.of(
192192
"foo", Set.of(BooleanFieldMapper.CONTENT_TYPE),
193193
"bar", Set.of(NumberFieldMapper.NumberType.LONG.typeName()),
194-
"baz", Set.of(KeywordFieldMapper.CONTENT_TYPE));
194+
"baz", Set.of(KeywordFieldMapper.CONTENT_TYPE)),
195+
Map.of(
196+
"foo", 10L,
197+
"bar", 20L,
198+
"baz", 30L)
199+
);
195200
assertThat(
196-
new Classification("foo").getParams(extractedFields),
201+
new Classification("foo").getParams(fieldInfo),
197202
equalTo(
198203
Map.of(
199204
"dependent_variable", "foo",
200205
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
201206
"num_top_classes", 2,
202207
"prediction_field_name", "foo_prediction",
203-
"prediction_field_type", "bool")));
208+
"prediction_field_type", "bool",
209+
"num_classes", 10L)));
204210
assertThat(
205-
new Classification("bar").getParams(extractedFields),
211+
new Classification("bar").getParams(fieldInfo),
206212
equalTo(
207213
Map.of(
208214
"dependent_variable", "bar",
209215
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
210216
"num_top_classes", 2,
211217
"prediction_field_name", "bar_prediction",
212-
"prediction_field_type", "int")));
218+
"prediction_field_type", "int",
219+
"num_classes", 20L)));
213220
assertThat(
214-
new Classification("baz").getParams(extractedFields),
221+
new Classification("baz").getParams(fieldInfo),
215222
equalTo(
216223
Map.of(
217224
"dependent_variable", "baz",
218225
"class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
219226
"num_top_classes", 2,
220227
"prediction_field_name", "baz_prediction",
221-
"prediction_field_type", "string")));
228+
"prediction_field_type", "string",
229+
"num_classes", 30L)));
222230
}
223231

224232
public void testRequiredFieldsIsNonEmpty() {
@@ -232,7 +240,7 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
232240
assertThat(constraints.size(), equalTo(1));
233241
assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
234242
assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
235-
assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
243+
assertThat(constraints.get(0).getUpperBound(), equalTo(30L));
236244
}
237245

238246
public void testGetExplicitlyMappedFields() {
@@ -331,4 +339,25 @@ public void testExtractJobIdFromStateDoc() {
331339
protected Classification mutateInstanceForVersion(Classification instance, Version version) {
332340
return mutateForVersion(instance, version);
333341
}
342+
343+
private static class TestFieldInfo implements DataFrameAnalysis.FieldInfo {
344+
345+
private final Map<String, Set<String>> fieldTypes;
346+
private final Map<String, Long> fieldCardinalities;
347+
348+
private TestFieldInfo(Map<String, Set<String>> fieldTypes, Map<String, Long> fieldCardinalities) {
349+
this.fieldTypes = fieldTypes;
350+
this.fieldCardinalities = fieldCardinalities;
351+
}
352+
353+
@Override
354+
public Set<String> getTypes(String field) {
355+
return fieldTypes.get(field);
356+
}
357+
358+
@Override
359+
public Long getCardinality(String field) {
360+
return fieldCardinalities.get(field);
361+
}
362+
}
334363
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

+2
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,11 @@ public void testStopAndRestart() throws Exception {
322322
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
323323
}
324324

325+
@AwaitsFix(bugUrl = "Muted until ml-cpp supports multiple classes")
325326
public void testDependentVariableCardinalityTooHighError() throws Exception {
326327
initialize("cardinality_too_high");
327328
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
329+
328330
// Index one more document with a class different than the two already used.
329331
client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex)
330332
.source(KEYWORD_FIELD, "fox")

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import java.util.ArrayList;
1616
import java.util.Arrays;
17+
import java.util.Collections;
1718
import java.util.List;
1819
import java.util.Objects;
1920
import java.util.Set;
@@ -27,7 +28,7 @@ public class TimeBasedExtractedFields extends ExtractedFields {
2728
private final ExtractedField timeField;
2829

2930
public TimeBasedExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
30-
super(allFields);
31+
super(allFields, Collections.emptyMap());
3132
if (!allFields.contains(timeField)) {
3233
throw new IllegalArgumentException("timeField should also be contained in allFields");
3334
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java

+9-8
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ public class ExtractedFieldsDetector {
5858
private final DataFrameAnalyticsConfig config;
5959
private final int docValueFieldsLimit;
6060
private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
61-
private final Map<String, Long> fieldCardinalities;
61+
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
6262

6363
ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit,
64-
FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> fieldCardinalities) {
64+
FieldCapabilitiesResponse fieldCapabilitiesResponse, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
6565
this.index = Objects.requireNonNull(index);
6666
this.config = Objects.requireNonNull(config);
6767
this.docValueFieldsLimit = docValueFieldsLimit;
6868
this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse);
69-
this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities);
69+
this.cardinalitiesForFieldsWithConstraints = Objects.requireNonNull(cardinalitiesForFieldsWithConstraints);
7070
}
7171

7272
public Tuple<ExtractedFields, List<FieldSelection>> detect() {
@@ -286,12 +286,13 @@ private void checkRequiredFields(Set<String> fields) {
286286

287287
private void checkFieldsWithCardinalityLimit() {
288288
for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
289-
constraint.check(fieldCardinalities.get(constraint.getField()));
289+
constraint.check(cardinalitiesForFieldsWithConstraints.get(constraint.getField()));
290290
}
291291
}
292292

293293
private ExtractedFields detectExtractedFields(Set<String> fields, Set<FieldSelection> fieldSelection) {
294-
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse);
294+
ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse,
295+
cardinalitiesForFieldsWithConstraints);
295296
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
296297
extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection);
297298
if (preferSource) {
@@ -321,7 +322,7 @@ private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields,
321322
chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection));
322323
}
323324
}
324-
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()));
325+
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints);
325326
}
326327

327328
private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields, ExtractedField parent,
@@ -372,7 +373,7 @@ private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFiel
372373
for (ExtractedField field : extractedFields.getAllFields()) {
373374
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
374375
}
375-
return new ExtractedFields(adjusted);
376+
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
376377
}
377378

378379
private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
@@ -389,7 +390,7 @@ private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFi
389390
adjusted.add(field);
390391
}
391392
}
392-
return new ExtractedFields(adjusted);
393+
return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints);
393394
}
394395

395396
private void addIncludedFields(ExtractedFields extractedFields, Set<FieldSelection> fieldSelection) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java

+24-6
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414

1515
import java.io.IOException;
1616
import java.util.Objects;
17+
import java.util.Optional;
1718
import java.util.Set;
1819

19-
import static java.util.stream.Collectors.toMap;
20-
2120
public class AnalyticsProcessConfig implements ToXContentObject {
2221

2322
private static final String JOB_ID = "job_id";
@@ -93,12 +92,31 @@ private DataFrameAnalysisWrapper(DataFrameAnalysis analysis, ExtractedFields ext
9392
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
9493
builder.startObject();
9594
builder.field("name", analysis.getWriteableName());
96-
builder.field(
97-
"parameters",
98-
analysis.getParams(
99-
extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes))));
95+
builder.field("parameters", analysis.getParams(new AnalysisFieldInfo(extractedFields)));
10096
builder.endObject();
10197
return builder;
10298
}
10399
}
100+
101+
private static class AnalysisFieldInfo implements DataFrameAnalysis.FieldInfo {
102+
103+
private final ExtractedFields extractedFields;
104+
105+
AnalysisFieldInfo(ExtractedFields extractedFields) {
106+
this.extractedFields = Objects.requireNonNull(extractedFields);
107+
}
108+
109+
@Override
110+
public Set<String> getTypes(String field) {
111+
Optional<ExtractedField> extractedField = extractedFields.getAllFields().stream()
112+
.filter(f -> f.getName().equals(field))
113+
.findAny();
114+
return extractedField.isPresent() ? extractedField.get().getTypes() : null;
115+
}
116+
117+
@Override
118+
public Long getCardinality(String field) {
119+
return extractedFields.getCardinalitiesForFieldsWithConstraints().get(field);
120+
}
121+
}
104122
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java

+11-3
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ public class ExtractedFields {
2828
private final List<ExtractedField> allFields;
2929
private final List<ExtractedField> docValueFields;
3030
private final String[] sourceFields;
31+
private final Map<String, Long> cardinalitiesForFieldsWithConstraints;
3132

32-
public ExtractedFields(List<ExtractedField> allFields) {
33+
public ExtractedFields(List<ExtractedField> allFields, Map<String, Long> cardinalitiesForFieldsWithConstraints) {
3334
this.allFields = Collections.unmodifiableList(allFields);
3435
this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields);
3536
this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField)
3637
.toArray(String[]::new);
38+
this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints);
3739
}
3840

3941
public List<ExtractedField> getAllFields() {
@@ -48,14 +50,20 @@ public List<ExtractedField> getDocValueFields() {
4850
return docValueFields;
4951
}
5052

53+
public Map<String, Long> getCardinalitiesForFieldsWithConstraints() {
54+
return cardinalitiesForFieldsWithConstraints;
55+
}
56+
5157
private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
5258
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
5359
}
5460

5561
public static ExtractedFields build(Collection<String> allFields, Set<String> scriptFields,
56-
FieldCapabilitiesResponse fieldsCapabilities) {
62+
FieldCapabilitiesResponse fieldsCapabilities,
63+
Map<String, Long> cardinalitiesForFieldsWithConstraints) {
5764
ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities);
58-
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()));
65+
return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()),
66+
cardinalitiesForFieldsWithConstraints);
5967
}
6068

6169
public static TimeField newTimeField(String name, ExtractedField.Method method) {

0 commit comments

Comments
 (0)