Skip to content

Commit 4275a71

Browse files
authored
[ML] adjusting inference processor to support foreach usage (#60915) (#61022)
`foreach` processors store information within the `_ingest` metadata object. This commit adds the contents of the `_ingest` metadata (if it is not empty). And will append new inference results if the result field already exists. This allows a `foreach` to execute and multiple inference results being written to the same result field. closes #60867
1 parent c81dc2b commit 4275a71

File tree

16 files changed

+143
-67
lines changed

16 files changed

+143
-67
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
import org.elasticsearch.common.io.stream.StreamInput;
1010
import org.elasticsearch.common.io.stream.StreamOutput;
1111
import org.elasticsearch.common.xcontent.XContentBuilder;
12-
import org.elasticsearch.ingest.IngestDocument;
1312
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
1413
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
1514
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
16-
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1715

1816
import java.io.IOException;
1917
import java.util.Collections;
@@ -160,13 +158,6 @@ public Object predictedValue() {
160158
return predictionFieldType.transformPredictedValue(value(), valueAsString());
161159
}
162160

163-
@Override
164-
public void writeResult(IngestDocument document, String parentResultField) {
165-
ExceptionsHelper.requireNonNull(document, "document");
166-
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
167-
document.setFieldValue(parentResultField, asMap());
168-
}
169-
170161
public Double getPredictionProbability() {
171162
return predictionProbability;
172163
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,25 @@
88
import org.elasticsearch.common.io.stream.NamedWriteable;
99
import org.elasticsearch.common.xcontent.ToXContentFragment;
1010
import org.elasticsearch.ingest.IngestDocument;
11+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1112

1213
import java.util.Map;
1314

1415
public interface InferenceResults extends NamedWriteable, ToXContentFragment {
16+
String MODEL_ID_RESULTS_FIELD = "model_id";
1517

16-
void writeResult(IngestDocument document, String parentResultField);
18+
static void writeResult(InferenceResults results, IngestDocument ingestDocument, String resultField, String modelId) {
19+
ExceptionsHelper.requireNonNull(results, "results");
20+
ExceptionsHelper.requireNonNull(ingestDocument, "ingestDocument");
21+
ExceptionsHelper.requireNonNull(resultField, "resultField");
22+
Map<String, Object> resultMap = results.asMap();
23+
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
24+
if (ingestDocument.hasField(resultField)) {
25+
ingestDocument.appendFieldValue(resultField, resultMap);
26+
} else {
27+
ingestDocument.setFieldValue(resultField, resultMap);
28+
}
29+
}
1730

1831
Map<String, Object> asMap();
1932

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import org.elasticsearch.common.io.stream.StreamOutput;
99
import org.elasticsearch.common.xcontent.XContentBuilder;
10-
import org.elasticsearch.ingest.IngestDocument;
1110

1211
import java.io.IOException;
1312
import java.util.Arrays;
@@ -53,15 +52,11 @@ public int hashCode() {
5352
return Objects.hash(Arrays.hashCode(value), featureImportance);
5453
}
5554

56-
@Override
57-
public void writeResult(IngestDocument document, String parentResultField) {
58-
throw new UnsupportedOperationException("[raw] does not support writing inference results");
59-
}
60-
6155
@Override
6256
public Map<String, Object> asMap() {
6357
throw new UnsupportedOperationException("[raw] does not support map conversion");
6458
}
59+
6560
@Override
6661
public Object predictedValue() {
6762
return null;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@
88
import org.elasticsearch.common.io.stream.StreamInput;
99
import org.elasticsearch.common.io.stream.StreamOutput;
1010
import org.elasticsearch.common.xcontent.XContentBuilder;
11-
import org.elasticsearch.ingest.IngestDocument;
1211
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
1312
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
14-
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1513

1614
import java.io.IOException;
1715
import java.util.Collections;
@@ -83,13 +81,6 @@ public Object predictedValue() {
8381
return super.value();
8482
}
8583

86-
@Override
87-
public void writeResult(IngestDocument document, String parentResultField) {
88-
ExceptionsHelper.requireNonNull(document, "document");
89-
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
90-
document.setFieldValue(parentResultField, asMap());
91-
}
92-
9384
@Override
9485
public Map<String, Object> asMap() {
9586
Map<String, Object> map = new LinkedHashMap<>();

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import org.elasticsearch.common.io.stream.StreamInput;
1010
import org.elasticsearch.common.io.stream.StreamOutput;
1111
import org.elasticsearch.common.xcontent.XContentBuilder;
12-
import org.elasticsearch.ingest.IngestDocument;
13-
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1412

1513
import java.io.IOException;
1614
import java.util.LinkedHashMap;
@@ -54,13 +52,6 @@ public int hashCode() {
5452
return Objects.hash(warning);
5553
}
5654

57-
@Override
58-
public void writeResult(IngestDocument document, String parentResultField) {
59-
ExceptionsHelper.requireNonNull(document, "document");
60-
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
61-
document.setFieldValue(parentResultField, asMap());
62-
}
63-
6455
@Override
6556
public Map<String, Object> asMap() {
6657
Map<String, Object> asMap = new LinkedHashMap<>();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.stream.Collectors;
2323
import java.util.stream.Stream;
2424

25+
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
2526
import static org.hamcrest.Matchers.equalTo;
2627
import static org.hamcrest.Matchers.hasSize;
2728

@@ -64,7 +65,7 @@ public void testWriteResultsWithClassificationLabel() {
6465
1.0,
6566
1.0);
6667
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
67-
result.writeResult(document, "result_field");
68+
writeResult(result, document, "result_field", "test");
6869

6970
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
7071
}
@@ -78,9 +79,20 @@ public void testWriteResultsWithoutClassificationLabel() {
7879
1.0,
7980
1.0);
8081
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
81-
result.writeResult(document, "result_field");
82+
writeResult(result, document, "result_field", "test");
8283

8384
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("1.0"));
85+
86+
result = new ClassificationInferenceResults(2.0,
87+
null,
88+
Collections.emptyList(),
89+
Collections.emptyList(),
90+
ClassificationConfig.EMPTY_PARAMS,
91+
1.0,
92+
1.0);
93+
writeResult(result, document, "result_field", "test");
94+
assertThat(document.getFieldValue("result_field.0.predicted_value", String.class), equalTo("1.0"));
95+
assertThat(document.getFieldValue("result_field.1.predicted_value", String.class), equalTo("2.0"));
8496
}
8597

8698
@SuppressWarnings("unchecked")
@@ -97,7 +109,7 @@ public void testWriteResultsWithTopClasses() {
97109
0.7,
98110
0.7);
99111
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
100-
result.writeResult(document, "result_field");
112+
writeResult(result, document, "result_field", "test");
101113

102114
List<?> list = document.getFieldValue("result_field.bar", List.class);
103115
assertThat(list.size(), equalTo(3));
@@ -126,7 +138,7 @@ public void testWriteResultsWithImportance() {
126138
1.0,
127139
1.0);
128140
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
129-
result.writeResult(document, "result_field");
141+
writeResult(result, document, "result_field", "test");
130142

131143
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
132144
@SuppressWarnings("unchecked")

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.stream.Collectors;
2020
import java.util.stream.Stream;
2121

22+
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
2223
import static org.hamcrest.Matchers.equalTo;
2324
import static org.hamcrest.Matchers.hasSize;
2425

@@ -37,9 +38,15 @@ public static RegressionInferenceResults createRandomResults() {
3738
public void testWriteResults() {
3839
RegressionInferenceResults result = new RegressionInferenceResults(0.3, RegressionConfig.EMPTY_PARAMS);
3940
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
40-
result.writeResult(document, "result_field");
41+
writeResult(result, document, "result_field", "test");
4142

4243
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
44+
45+
result = new RegressionInferenceResults(0.5, RegressionConfig.EMPTY_PARAMS);
46+
writeResult(result, document, "result_field", "test");
47+
48+
assertThat(document.getFieldValue("result_field.0.predicted_value", Double.class), equalTo(0.3));
49+
assertThat(document.getFieldValue("result_field.1.predicted_value", Double.class), equalTo(0.5));
4350
}
4451

4552
public void testWriteResultsWithImportance() {
@@ -50,7 +57,7 @@ public void testWriteResultsWithImportance() {
5057
new RegressionConfig("predicted_value", 3),
5158
importanceList);
5259
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
53-
result.writeResult(document, "result_field");
60+
writeResult(result, document, "result_field", "test");
5461

5562
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
5663
@SuppressWarnings("unchecked")

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.util.HashMap;
1717

1818
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
19+
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
1920
import static org.hamcrest.Matchers.equalTo;
2021

2122
public class WarningInferenceResultsTests extends AbstractSerializingTestCase<WarningInferenceResults> {
@@ -36,9 +37,15 @@ public static WarningInferenceResults createRandomResults() {
3637
public void testWriteResults() {
3738
WarningInferenceResults result = new WarningInferenceResults("foo");
3839
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
39-
result.writeResult(document, "result_field");
40+
writeResult(result, document, "result_field", "test");
4041

4142
assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
43+
44+
result = new WarningInferenceResults("bar");
45+
writeResult(result, document, "result_field", "test");
46+
47+
assertThat(document.getFieldValue("result_field.0.warning", String.class), equalTo("foo"));
48+
assertThat(document.getFieldValue("result_field.1.warning", String.class), equalTo("bar"));
4249
}
4350

4451
@Override

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,52 @@ public void testSimulateLangIdent() throws IOException {
347347
assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\""));
348348
}
349349

350+
public void testSimulateLangIdentForeach() throws IOException {
351+
String source = "{" +
352+
" \"pipeline\": {\n" +
353+
" \"description\": \"detect text lang\",\n" +
354+
" \"processors\": [\n" +
355+
" {\n" +
356+
" \"foreach\": {\n" +
357+
" \"field\": \"greetings\",\n" +
358+
" \"processor\": {\n" +
359+
" \"inference\": {\n" +
360+
" \"model_id\": \"lang_ident_model_1\",\n" +
361+
" \"inference_config\": {\n" +
362+
" \"classification\": {\n" +
363+
" \"num_top_classes\": 5\n" +
364+
" }\n" +
365+
" },\n" +
366+
" \"field_map\": {\n" +
367+
" \"_ingest._value.text\": \"text\"\n" +
368+
" }\n" +
369+
" }\n" +
370+
" }\n" +
371+
" }\n" +
372+
" }\n" +
373+
" ]\n" +
374+
" },\n" +
375+
" \"docs\": [\n" +
376+
" {\n" +
377+
" \"_source\": {\n" +
378+
" \"greetings\": [\n" +
379+
" {\n" +
380+
" \"text\": \" a backup credit card by visiting your billing preferences page or visit the adwords help\"\n" +
381+
" },\n" +
382+
" {\n" +
383+
" \"text\": \" 개별적으로 리포트 액세스 권한을 부여할 수 있습니다 액세스 권한 부여사용자에게 프로필 리포트에 \"\n" +
384+
" }\n" +
385+
" ]\n" +
386+
" }\n" +
387+
" }\n" +
388+
" ]\n" +
389+
"}";
390+
Response response = client().performRequest(simulateRequest(source));
391+
String stringResponse = EntityUtils.toString(response.getEntity());
392+
assertThat(stringResponse, containsString("\"predicted_value\":\"en\""));
393+
assertThat(stringResponse, containsString("\"predicted_value\":\"ko\""));
394+
}
395+
350396
private static Request simulateRequest(String jsonEntity) {
351397
Request request = new Request("POST", "_ingest/pipeline/_simulate");
352398
request.setJsonEntity(jsonEntity);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.ingest.Processor;
2929
import org.elasticsearch.rest.RestStatus;
3030
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
31+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
3132
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
3233
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
3334
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
@@ -49,9 +50,11 @@
4950
import java.util.function.BiConsumer;
5051
import java.util.function.Consumer;
5152

53+
import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY;
5254
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;
5355
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
5456
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
57+
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD;
5558

5659
public class InferenceProcessor extends AbstractProcessor {
5760

@@ -63,7 +66,6 @@ public class InferenceProcessor extends AbstractProcessor {
6366
Setting.Property.NodeScope);
6467

6568
public static final String TYPE = "inference";
66-
public static final String MODEL_ID = "model_id";
6769
public static final String INFERENCE_CONFIG = "inference_config";
6870
public static final String TARGET_FIELD = "target_field";
6971
public static final String FIELD_MAPPINGS = "field_mappings";
@@ -92,7 +94,7 @@ public InferenceProcessor(Client client,
9294
this.client = ExceptionsHelper.requireNonNull(client, "client");
9395
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
9496
this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor");
95-
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
97+
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID_RESULTS_FIELD);
9698
this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
9799
this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP);
98100
}
@@ -132,6 +134,10 @@ void handleResponse(InternalInferModelAction.Response response,
132134

133135
InternalInferModelAction.Request buildRequest(IngestDocument ingestDocument) {
134136
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
137+
// Add ingestMetadata as previous processors might have added metadata from which we are predicting (see: foreach processor)
138+
if (ingestDocument.getIngestMetadata().isEmpty() == false) {
139+
fields.put(INGEST_KEY, ingestDocument.getIngestMetadata());
140+
}
135141
LocalModel.mapFieldsIfNecessary(fields, fieldMap);
136142
return new InternalInferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed);
137143
}
@@ -150,8 +156,7 @@ void mutateDocument(InternalInferModelAction.Response response, IngestDocument i
150156
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
151157
}
152158
assert response.getInferenceResults().size() == 1;
153-
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
154-
ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId);
159+
InferenceResults.writeResult(response.getInferenceResults().get(0), ingestDocument, targetField, modelId);
155160
}
156161

157162
@Override
@@ -278,7 +283,7 @@ public InferenceProcessor create(Map<String, Processor.Factory> processorFactori
278283
maxIngestProcessors);
279284
}
280285

281-
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
286+
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID_RESULTS_FIELD);
282287
String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag;
283288
// If multiple inference processors are in the same pipeline, it is wise to tag them
284289
// The tag will keep default value entries from stepping on each other
@@ -341,12 +346,10 @@ InferenceConfigUpdate inferenceConfigUpdateFromMap(Map<String, Object> configMap
341346

342347
if (configMap.containsKey(ClassificationConfig.NAME.getPreferredName())) {
343348
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
344-
ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap);
345-
return config;
349+
return ClassificationConfigUpdate.fromMap(valueMap);
346350
} else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) {
347351
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
348-
RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap);
349-
return config;
352+
return RegressionConfigUpdate.fromMap(valueMap);
350353
} else {
351354
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
352355
configMap.keySet(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.ingest.IngestMetadata;
3030
import org.elasticsearch.threadpool.ThreadPool;
3131
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
32+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
3233
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
3334
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
3435
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@@ -561,7 +562,7 @@ private static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata)
561562
if (processor instanceof Map<?, ?>) {
562563
Object processorConfig = ((Map<?, ?>) processor).get(InferenceProcessor.TYPE);
563564
if (processorConfig instanceof Map<?, ?>) {
564-
Object modelId = ((Map<?, ?>) processorConfig).get(InferenceProcessor.MODEL_ID);
565+
Object modelId = ((Map<?, ?>) processorConfig).get(InferenceResults.MODEL_ID_RESULTS_FIELD);
565566
if (modelId != null) {
566567
assert modelId instanceof String;
567568
allReferencedModelKeys.add(modelId.toString());

0 commit comments

Comments
 (0)