Skip to content

Commit 484886a

Browse files
authored
[ML][Inference] adding more options to inference processor (#48545)
* [ML][Inference] adding more options to inference processor, fixing minor bug * addressing PR comments
1 parent c6d977c commit 484886a

File tree

2 files changed

+80
-30
lines changed

2 files changed

+80
-30
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,32 +59,33 @@ public class InferenceProcessor extends AbstractProcessor {
5959
public static final String TARGET_FIELD = "target_field";
6060
public static final String FIELD_MAPPINGS = "field_mappings";
6161
public static final String MODEL_INFO_FIELD = "model_info_field";
62+
public static final String INCLUDE_MODEL_METADATA = "include_model_metadata";
6263

6364
private final Client client;
6465
private final String modelId;
6566

6667
private final String targetField;
6768
private final String modelInfoField;
68-
private final Map<String, Object> modelInfo;
6969
private final InferenceConfig inferenceConfig;
7070
private final Map<String, String> fieldMapping;
71+
private final boolean includeModelMetadata;
7172

7273
public InferenceProcessor(Client client,
7374
String tag,
7475
String targetField,
7576
String modelId,
7677
InferenceConfig inferenceConfig,
7778
Map<String, String> fieldMapping,
78-
String modelInfoField) {
79+
String modelInfoField,
80+
boolean includeModelMetadata) {
7981
super(tag);
80-
this.client = client;
81-
this.targetField = targetField;
82-
this.modelInfoField = modelInfoField;
83-
this.modelId = modelId;
84-
this.inferenceConfig = inferenceConfig;
85-
this.fieldMapping = fieldMapping;
86-
this.modelInfo = new HashMap<>();
87-
this.modelInfo.put("model_id", modelId);
82+
this.client = ExceptionsHelper.requireNonNull(client, "client");
83+
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
84+
this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD);
85+
this.includeModelMetadata = includeModelMetadata;
86+
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
87+
this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
88+
this.fieldMapping = ExceptionsHelper.requireNonNull(fieldMapping, FIELD_MAPPINGS);
8889
}
8990

9091
public String getModelId() {
@@ -128,8 +129,8 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc
128129
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
129130
}
130131
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
131-
if (modelInfoField != null) {
132-
ingestDocument.setFieldValue(modelInfoField, modelInfo);
132+
if (includeModelMetadata) {
133+
ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
133134
}
134135
}
135136

@@ -207,15 +208,25 @@ public InferenceProcessor create(Map<String, Processor.Factory> processorFactori
207208
maxIngestProcessors);
208209
}
209210

211+
boolean includeModelMetadata = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, INCLUDE_MODEL_METADATA, true);
210212
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
211213
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD);
212214
Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
213215
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
214-
String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info");
215-
if (modelInfoField != null && tag != null) {
216+
String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "ml");
217+
// If multiple inference processors are in the same pipeline, it is wise to tag them
218+
// The tag will keep metadata entries from stepping on each other
219+
if (tag != null) {
216220
modelInfoField += "." + tag;
217221
}
218-
return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField);
222+
return new InferenceProcessor(client,
223+
tag,
224+
targetField,
225+
modelId,
226+
inferenceConfig,
227+
fieldMapping,
228+
modelInfoField,
229+
includeModelMetadata);
219230
}
220231

221232
// Package private for testing

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ public void testMutateDocumentWithClassification() {
4343
"classification_model",
4444
new ClassificationConfig(0),
4545
Collections.emptyMap(),
46-
"_ml_model.my_processor");
46+
"ml.my_processor",
47+
true);
4748

4849
Map<String, Object> source = new HashMap<>();
4950
Map<String, Object> ingestMetadata = new HashMap<>();
@@ -54,7 +55,7 @@ public void testMutateDocumentWithClassification() {
5455
inferenceProcessor.mutateDocument(response, document);
5556

5657
assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
57-
assertThat(document.getFieldValue("_ml_model", Map.class),
58+
assertThat(document.getFieldValue("ml", Map.class),
5859
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
5960
}
6061

@@ -67,7 +68,8 @@ public void testMutateDocumentClassificationTopNClasses() {
6768
"classification_model",
6869
new ClassificationConfig(2),
6970
Collections.emptyMap(),
70-
"_ml_model.my_processor");
71+
"ml.my_processor",
72+
true);
7173

7274
Map<String, Object> source = new HashMap<>();
7375
Map<String, Object> ingestMetadata = new HashMap<>();
@@ -83,7 +85,7 @@ public void testMutateDocumentClassificationTopNClasses() {
8385

8486
assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
8587
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
86-
assertThat(document.getFieldValue("_ml_model", Map.class),
88+
assertThat(document.getFieldValue("ml", Map.class),
8789
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
8890
}
8991

@@ -95,7 +97,8 @@ public void testMutateDocumentRegression() {
9597
"regression_model",
9698
new RegressionConfig(),
9799
Collections.emptyMap(),
98-
"_ml_model.my_processor");
100+
"ml.my_processor",
101+
true);
99102

100103
Map<String, Object> source = new HashMap<>();
101104
Map<String, Object> ingestMetadata = new HashMap<>();
@@ -106,7 +109,7 @@ public void testMutateDocumentRegression() {
106109
inferenceProcessor.mutateDocument(response, document);
107110

108111
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
109-
assertThat(document.getFieldValue("_ml_model", Map.class),
112+
assertThat(document.getFieldValue("ml", Map.class),
110113
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model"))));
111114
}
112115

@@ -118,7 +121,8 @@ public void testMutateDocumentNoModelMetaData() {
118121
"regression_model",
119122
new RegressionConfig(),
120123
Collections.emptyMap(),
121-
null);
124+
"ml.my_processor",
125+
false);
122126

123127
Map<String, Object> source = new HashMap<>();
124128
Map<String, Object> ingestMetadata = new HashMap<>();
@@ -129,20 +133,54 @@ public void testMutateDocumentNoModelMetaData() {
129133
inferenceProcessor.mutateDocument(response, document);
130134

131135
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
132-
assertThat(document.hasField("_ml_model"), is(false));
136+
assertThat(document.hasField("ml"), is(false));
137+
}
138+
139+
public void testMutateDocumentModelMetaDataExistingField() {
140+
String targetField = "regression_value";
141+
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
142+
"my_processor",
143+
targetField,
144+
"regression_model",
145+
new RegressionConfig(),
146+
Collections.emptyMap(),
147+
"ml.my_processor",
148+
true);
149+
150+
//cannot use singleton map as attempting to mutate later
151+
Map<String, Object> ml = new HashMap<>(){{
152+
put("regression_prediction", 0.55);
153+
}};
154+
Map<String, Object> source = new HashMap<>(){{
155+
put("ml", ml);
156+
}};
157+
Map<String, Object> ingestMetadata = new HashMap<>();
158+
IngestDocument document = new IngestDocument(source, ingestMetadata);
159+
160+
InferModelAction.Response response = new InferModelAction.Response(
161+
Collections.singletonList(new RegressionInferenceResults(0.7)));
162+
inferenceProcessor.mutateDocument(response, document);
163+
164+
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
165+
assertThat(document.getFieldValue("ml", Map.class),
166+
equalTo(new HashMap<>(){{
167+
put("my_processor", Collections.singletonMap("model_id", "regression_model"));
168+
put("regression_prediction", 0.55);
169+
}}));
133170
}
134171

135172
public void testGenerateRequestWithEmptyMapping() {
136173
String modelId = "model";
137174
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
138175

139176
InferenceProcessor processor = new InferenceProcessor(client,
140-
"my_processor",
141-
"my_field",
142-
modelId,
143-
new ClassificationConfig(topNClasses),
144-
Collections.emptyMap(),
145-
null);
177+
"my_processor",
178+
"my_field",
179+
modelId,
180+
new ClassificationConfig(topNClasses),
181+
Collections.emptyMap(),
182+
"ml.my_processor",
183+
false);
146184

147185
Map<String, Object> source = new HashMap<>(){{
148186
put("value1", 1);
@@ -171,7 +209,8 @@ public void testGenerateWithMapping() {
171209
modelId,
172210
new ClassificationConfig(topNClasses),
173211
fieldMapping,
174-
null);
212+
"ml.my_processor",
213+
false);
175214

176215
Map<String, Object> source = new HashMap<>(3){{
177216
put("value1", 1);

0 commit comments

Comments
 (0)