diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 40ca9ba253594..5a0e50cd84bae 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -59,15 +59,16 @@ public class InferenceProcessor extends AbstractProcessor { public static final String TARGET_FIELD = "target_field"; public static final String FIELD_MAPPINGS = "field_mappings"; public static final String MODEL_INFO_FIELD = "model_info_field"; + public static final String INCLUDE_MODEL_METADATA = "include_model_metadata"; private final Client client; private final String modelId; private final String targetField; private final String modelInfoField; - private final Map modelInfo; private final InferenceConfig inferenceConfig; private final Map fieldMapping; + private final boolean includeModelMetadata; public InferenceProcessor(Client client, String tag, @@ -75,16 +76,16 @@ public InferenceProcessor(Client client, String modelId, InferenceConfig inferenceConfig, Map fieldMapping, - String modelInfoField) { + String modelInfoField, + boolean includeModelMetadata) { super(tag); - this.client = client; - this.targetField = targetField; - this.modelInfoField = modelInfoField; - this.modelId = modelId; - this.inferenceConfig = inferenceConfig; - this.fieldMapping = fieldMapping; - this.modelInfo = new HashMap<>(); - this.modelInfo.put("model_id", modelId); + this.client = ExceptionsHelper.requireNonNull(client, "client"); + this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD); + this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD); + this.includeModelMetadata = includeModelMetadata; + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); + this.fieldMapping = ExceptionsHelper.requireNonNull(fieldMapping, FIELD_MAPPINGS); } public String getModelId() { @@ -128,8 +129,8 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); } response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); - if (modelInfoField != null) { - ingestDocument.setFieldValue(modelInfoField, modelInfo); + if (includeModelMetadata) { + ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId); } } @@ -207,15 +208,25 @@ public InferenceProcessor create(Map processorFactori maxIngestProcessors); } + boolean includeModelMetadata = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, INCLUDE_MODEL_METADATA, true); String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD); Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); - String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info"); - if (modelInfoField != null && tag != null) { + String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "ml"); + // If multiple inference processors are in the same pipeline, it is wise to tag them + // The tag will keep metadata entries from stepping on each other + if (tag != null) { modelInfoField += "." + tag; } - return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField); + return new InferenceProcessor(client, + tag, + targetField, + modelId, + inferenceConfig, + fieldMapping, + modelInfoField, + includeModelMetadata); } // Package private for testing diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 4f55768407339..81f3be6aeefbf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -43,7 +43,8 @@ public void testMutateDocumentWithClassification() { "classification_model", new ClassificationConfig(0), Collections.emptyMap(), - "_ml_model.my_processor"); + "ml.my_processor", + true); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -54,7 +55,7 @@ public void testMutateDocumentWithClassification() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, String.class), equalTo("foo")); - assertThat(document.getFieldValue("_ml_model", Map.class), + assertThat(document.getFieldValue("ml", Map.class), equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); } @@ -67,7 +68,8 @@ public void testMutateDocumentClassificationTopNClasses() { "classification_model", new ClassificationConfig(2), Collections.emptyMap(), - "_ml_model.my_processor"); + "ml.my_processor", + true); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -83,7 +85,7 @@ public void testMutateDocumentClassificationTopNClasses() { assertThat((List>)document.getFieldValue(targetField, List.class), contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new))); - assertThat(document.getFieldValue("_ml_model", Map.class), + assertThat(document.getFieldValue("ml", Map.class), equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); } @@ -95,7 +97,8 @@ public void testMutateDocumentRegression() { "regression_model", new RegressionConfig(), Collections.emptyMap(), - "_ml_model.my_processor"); + "ml.my_processor", + true); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -106,7 +109,7 @@ public void testMutateDocumentRegression() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); - assertThat(document.getFieldValue("_ml_model", Map.class), + assertThat(document.getFieldValue("ml", Map.class), equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model")))); } @@ -118,7 +121,8 @@ public void testMutateDocumentNoModelMetaData() { "regression_model", new RegressionConfig(), Collections.emptyMap(), - null); + "ml.my_processor", + false); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -129,7 +133,40 @@ public void testMutateDocumentNoModelMetaData() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); - assertThat(document.hasField("_ml_model"), is(false)); + assertThat(document.hasField("ml"), is(false)); + } + + public void testMutateDocumentModelMetaDataExistingField() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "ml.my_processor", + true); + + //cannot use singleton map as attempting to mutate later + Map ml = new HashMap<>(){{ + put("regression_prediction", 0.55); + }}; + Map source = new HashMap<>(){{ + put("ml", ml); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("ml", Map.class), + equalTo(new HashMap<>(){{ + put("my_processor", Collections.singletonMap("model_id", "regression_model")); + put("regression_prediction", 0.55); + }})); } public void testGenerateRequestWithEmptyMapping() { @@ -137,12 +174,13 @@ public void testGenerateRequestWithEmptyMapping() { Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); InferenceProcessor processor = new InferenceProcessor(client, - "my_processor", - "my_field", - modelId, - new ClassificationConfig(topNClasses), - Collections.emptyMap(), - null); + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + Collections.emptyMap(), + "ml.my_processor", + false); Map source = new HashMap<>(){{ put("value1", 1); @@ -171,7 +209,8 @@ public void testGenerateWithMapping() { modelId, new ClassificationConfig(topNClasses), fieldMapping, - null); + "ml.my_processor", + false); Map source = new HashMap<>(3){{ put("value1", 1);