Skip to content

Commit 0ac03ac

Browse files
authored
[ML] Add parsers for inference configuration classes (#51300)
1 parent 4590d41 commit 0ac03ac

File tree

8 files changed

+84
-25
lines changed

8 files changed

+84
-25
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,10 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
526526
RegressionInferenceResults.NAME,
527527
RegressionInferenceResults::new),
528528
// ML - Inference Configuration
529-
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
530-
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),
529+
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(),
530+
ClassificationConfig::new),
531+
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME.getPreferredName(),
532+
RegressionConfig::new),
531533

532534
// monitoring
533535
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
@@ -591,7 +593,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
591593
new NamedWriteableRegistry.Entry(LifecycleAction.class, SetPriorityAction.NAME, SetPriorityAction::new),
592594
new NamedWriteableRegistry.Entry(LifecycleAction.class, UnfollowAction.NAME, UnfollowAction::new),
593595
new NamedWriteableRegistry.Entry(LifecycleAction.class, WaitForSnapshotAction.NAME, WaitForSnapshotAction::new),
594-
// Data Frame
596+
// Transforms
595597
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.TRANSFORM, TransformFeatureSetUsage::new),
596598
new NamedWriteableRegistry.Entry(PersistentTaskParams.class, TransformField.TASK_NAME, TransformTaskParams::new),
597599
new NamedWriteableRegistry.Entry(Task.Status.class, TransformField.TASK_NAME, TransformState::new),
@@ -647,7 +649,7 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
647649
RollupJobStatus::fromXContent),
648650
new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(RollupJobStatus.NAME),
649651
RollupJobStatus::fromXContent),
650-
// Data Frame
652+
// Transforms
651653
new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(TransformField.TASK_NAME),
652654
TransformTaskParams::fromXContent),
653655
new NamedXContentRegistry.Entry(Task.Status.class, new ParseField(TransformField.TASK_NAME),

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

+10-2
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
9999
LogisticRegression.NAME,
100100
LogisticRegression::fromXContentStrict));
101101

102+
// Inference Configs
103+
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME,
104+
ClassificationConfig::fromXContent));
105+
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME,
106+
RegressionConfig::fromXContent));
107+
102108
return namedXContent;
103109
}
104110

@@ -142,8 +148,10 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
142148
RegressionInferenceResults::new));
143149

144150
// Inference Configs
145-
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new));
146-
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new));
151+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
152+
ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new));
153+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
154+
RegressionConfig.NAME.getPreferredName(), RegressionConfig::new));
147155

148156
return namedWriteables;
149157
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java

+22-3
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,25 @@
99
import org.elasticsearch.common.ParseField;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1213
import org.elasticsearch.common.xcontent.XContentBuilder;
14+
import org.elasticsearch.common.xcontent.XContentParser;
1315
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1416

1517
import java.io.IOException;
1618
import java.util.HashMap;
1719
import java.util.Map;
1820
import java.util.Objects;
1921

22+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
23+
2024
public class ClassificationConfig implements InferenceConfig {
2125

22-
public static final String NAME = "classification";
26+
public static final ParseField NAME = new ParseField("classification");
2327

2428
public static final String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
2529
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
30+
2631
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
2732
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
2833
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
@@ -45,6 +50,20 @@ public static ClassificationConfig fromMap(Map<String, Object> map) {
4550
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField);
4651
}
4752

53+
private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
54+
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
55+
(Integer) args[0], (String) args[1], (String) args[2]));
56+
57+
static {
58+
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
59+
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
60+
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
61+
}
62+
63+
public static ClassificationConfig fromXContent(XContentParser parser) {
64+
return PARSER.apply(parser, null);
65+
}
66+
4867
public ClassificationConfig(Integer numTopClasses) {
4968
this(numTopClasses, null, null);
5069
}
@@ -109,12 +128,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
109128

110129
@Override
111130
public String getWriteableName() {
112-
return NAME;
131+
return NAME.getPreferredName();
113132
}
114133

115134
@Override
116135
public String getName() {
117-
return NAME;
136+
return NAME.getPreferredName();
118137
}
119138

120139
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java

+18-3
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,21 @@
99
import org.elasticsearch.common.ParseField;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1213
import org.elasticsearch.common.xcontent.XContentBuilder;
14+
import org.elasticsearch.common.xcontent.XContentParser;
1315
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1416

1517
import java.io.IOException;
1618
import java.util.HashMap;
1719
import java.util.Map;
1820
import java.util.Objects;
1921

22+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
23+
2024
public class RegressionConfig implements InferenceConfig {
2125

22-
public static final String NAME = "regression";
26+
public static final ParseField NAME = new ParseField("regression");
2327
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
2428
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
2529
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
@@ -35,6 +39,17 @@ public static RegressionConfig fromMap(Map<String, Object> map) {
3539
return new RegressionConfig(resultsField);
3640
}
3741

42+
private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
43+
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0]));
44+
45+
static {
46+
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
47+
}
48+
49+
public static RegressionConfig fromXContent(XContentParser parser) {
50+
return PARSER.apply(parser, null);
51+
}
52+
3853
private final String resultsField;
3954

4055
public RegressionConfig(String resultsField) {
@@ -51,7 +66,7 @@ public String getResultsField() {
5166

5267
@Override
5368
public String getWriteableName() {
54-
return NAME;
69+
return NAME.getPreferredName();
5570
}
5671

5772
@Override
@@ -61,7 +76,7 @@ public void writeTo(StreamOutput out) throws IOException {
6176

6277
@Override
6378
public String getName() {
64-
return NAME;
79+
return NAME.getPreferredName();
6580
}
6681

6782
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
import org.elasticsearch.ElasticsearchException;
99
import org.elasticsearch.common.io.stream.Writeable;
10-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
10+
import org.elasticsearch.common.xcontent.XContentParser;
11+
import org.elasticsearch.test.AbstractSerializingTestCase;
1112

13+
import java.io.IOException;
1214
import java.util.Collections;
1315
import java.util.HashMap;
1416
import java.util.Map;
1517

1618
import static org.hamcrest.Matchers.equalTo;
1719

18-
public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {
20+
public class ClassificationConfigTests extends AbstractSerializingTestCase<ClassificationConfig> {
1921

2022
public static ClassificationConfig randomClassificationConfig() {
2123
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
@@ -52,4 +54,8 @@ protected Writeable.Reader<ClassificationConfig> instanceReader() {
5254
return ClassificationConfig::new;
5355
}
5456

57+
@Override
58+
protected ClassificationConfig doParseInstance(XContentParser parser) throws IOException {
59+
return ClassificationConfig.fromXContent(parser);
60+
}
5561
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
import org.elasticsearch.ElasticsearchException;
99
import org.elasticsearch.common.io.stream.Writeable;
10-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
10+
import org.elasticsearch.common.xcontent.XContentParser;
11+
import org.elasticsearch.test.AbstractSerializingTestCase;
1112

13+
import java.io.IOException;
1214
import java.util.Collections;
1315
import java.util.HashMap;
1416
import java.util.Map;
1517

1618
import static org.hamcrest.Matchers.equalTo;
1719

18-
public class RegressionConfigTests extends AbstractWireSerializingTestCase<RegressionConfig> {
20+
public class RegressionConfigTests extends AbstractSerializingTestCase<RegressionConfig> {
1921

2022
public static RegressionConfig randomRegressionConfig() {
2123
return new RegressionConfig(randomBoolean() ? null : randomAlphaOfLength(10));
@@ -45,4 +47,8 @@ protected Writeable.Reader<RegressionConfig> instanceReader() {
4547
return RegressionConfig::new;
4648
}
4749

50+
@Override
51+
protected RegressionConfig doParseInstance(XContentParser parser) throws IOException {
52+
return RegressionConfig.fromXContent(parser);
53+
}
4854
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -275,20 +275,20 @@ InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
275275
@SuppressWarnings("unchecked")
276276
Map<String, Object> valueMap = (Map<String, Object>)value;
277277

278-
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
278+
if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) {
279279
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
280280
ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
281281
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
282282
return config;
283-
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
283+
} else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) {
284284
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
285285
RegressionConfig config = RegressionConfig.fromMap(valueMap);
286286
checkFieldUniqueness(config.getResultsField());
287287
return config;
288288
} else {
289289
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
290290
inferenceConfig.keySet(),
291-
Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
291+
Arrays.asList(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));
292292
}
293293
}
294294

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java

+9-6
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException {
178178
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
179179
put(InferenceProcessor.MODEL_ID, "my_model");
180180
put(InferenceProcessor.TARGET_FIELD, "result");
181-
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
181+
put(InferenceProcessor.INFERENCE_CONFIG,
182+
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
182183
}};
183184

184185
try {
@@ -195,7 +196,7 @@ public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException {
195196
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
196197
put(InferenceProcessor.MODEL_ID, "my_model");
197198
put(InferenceProcessor.TARGET_FIELD, "result");
198-
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
199+
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
199200
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
200201
}};
201202

@@ -220,7 +221,8 @@ public void testCreateProcessor() {
220221
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
221222
put(InferenceProcessor.MODEL_ID, "my_model");
222223
put(InferenceProcessor.TARGET_FIELD, "result");
223-
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
224+
put(InferenceProcessor.INFERENCE_CONFIG,
225+
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
224226
}};
225227

226228
try {
@@ -233,7 +235,7 @@ public void testCreateProcessor() {
233235
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
234236
put(InferenceProcessor.MODEL_ID, "my_model");
235237
put(InferenceProcessor.TARGET_FIELD, "result");
236-
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
238+
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
237239
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
238240
}};
239241

@@ -254,7 +256,7 @@ public void testCreateProcessorWithDuplicateFields() {
254256
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
255257
put(InferenceProcessor.MODEL_ID, "my_model");
256258
put(InferenceProcessor.TARGET_FIELD, "ml");
257-
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME,
259+
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(),
258260
Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
259261
}};
260262

@@ -302,7 +304,8 @@ private static PipelineConfiguration newConfigurationWithInferenceProcessor(Stri
302304
Collections.singletonMap(InferenceProcessor.TYPE,
303305
new HashMap<String, Object>() {{
304306
put(InferenceProcessor.MODEL_ID, modelId);
305-
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
307+
put(InferenceProcessor.INFERENCE_CONFIG,
308+
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
306309
put(InferenceProcessor.TARGET_FIELD, "new_field");
307310
put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest"));
308311
}}))))) {

0 commit comments

Comments
 (0)