Skip to content

[ML][Inference] adding more options to inference processor #48545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,32 +59,33 @@ public class InferenceProcessor extends AbstractProcessor {
public static final String TARGET_FIELD = "target_field";
public static final String FIELD_MAPPINGS = "field_mappings";
public static final String MODEL_INFO_FIELD = "model_info_field";
public static final String INCLUDE_MODEL_METADATA = "include_model_metadata";

private final Client client;
private final String modelId;

private final String targetField;
private final String modelInfoField;
private final Map<String, Object> modelInfo;
private final InferenceConfig inferenceConfig;
private final Map<String, String> fieldMapping;
private final boolean includeModelMetadata;

public InferenceProcessor(Client client,
String tag,
String targetField,
String modelId,
InferenceConfig inferenceConfig,
Map<String, String> fieldMapping,
String modelInfoField) {
String modelInfoField,
boolean includeModelMetadata) {
super(tag);
this.client = client;
this.targetField = targetField;
this.modelInfoField = modelInfoField;
this.modelId = modelId;
this.inferenceConfig = inferenceConfig;
this.fieldMapping = fieldMapping;
this.modelInfo = new HashMap<>();
this.modelInfo.put("model_id", modelId);
this.client = ExceptionsHelper.requireNonNull(client, "client");
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD);
this.includeModelMetadata = includeModelMetadata;
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
this.fieldMapping = ExceptionsHelper.requireNonNull(fieldMapping, FIELD_MAPPINGS);
}

public String getModelId() {
Expand Down Expand Up @@ -128,8 +129,8 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
}
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
if (modelInfoField != null) {
ingestDocument.setFieldValue(modelInfoField, modelInfo);
if (includeModelMetadata) {
ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that modelInfoField should never be null?
Does it make sense to add requireNonNull in the constructor?

}
}

Expand Down Expand Up @@ -207,15 +208,25 @@ public InferenceProcessor create(Map<String, Processor.Factory> processorFactori
maxIngestProcessors);
}

boolean includeModelMetadata = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, INCLUDE_MODEL_METADATA, true);
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD);
Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info");
if (modelInfoField != null && tag != null) {
String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "ml");
// If multiple inference processors are in the same pipeline, it is wise to tag them
// The tag will keep metadata entries from stepping on each other
if (tag != null) {
modelInfoField += "." + tag;
}
return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField);
return new InferenceProcessor(client,
tag,
targetField,
modelId,
inferenceConfig,
fieldMapping,
modelInfoField,
includeModelMetadata);
}

// Package private for testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public void testMutateDocumentWithClassification() {
"classification_model",
new ClassificationConfig(0),
Collections.emptyMap(),
"_ml_model.my_processor");
"ml.my_processor",
true);

Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
Expand All @@ -54,7 +55,7 @@ public void testMutateDocumentWithClassification() {
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
assertThat(document.getFieldValue("_ml_model", Map.class),
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
}

Expand All @@ -67,7 +68,8 @@ public void testMutateDocumentClassificationTopNClasses() {
"classification_model",
new ClassificationConfig(2),
Collections.emptyMap(),
"_ml_model.my_processor");
"ml.my_processor",
true);

Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
Expand All @@ -83,7 +85,7 @@ public void testMutateDocumentClassificationTopNClasses() {

assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
assertThat(document.getFieldValue("_ml_model", Map.class),
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
}

Expand All @@ -95,7 +97,8 @@ public void testMutateDocumentRegression() {
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
"_ml_model.my_processor");
"ml.my_processor",
true);

Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
Expand All @@ -106,7 +109,7 @@ public void testMutateDocumentRegression() {
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
assertThat(document.getFieldValue("_ml_model", Map.class),
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model"))));
}

Expand All @@ -118,7 +121,8 @@ public void testMutateDocumentNoModelMetaData() {
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
null);
"ml.my_processor",
false);

Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
Expand All @@ -129,20 +133,54 @@ public void testMutateDocumentNoModelMetaData() {
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
assertThat(document.hasField("_ml_model"), is(false));
assertThat(document.hasField("ml"), is(false));
}

public void testMutateDocumentModelMetaDataExistingField() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
"my_processor",
targetField,
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
"ml.my_processor",
true);

//cannot use singleton map as attempting to mutate later
Map<String, Object> ml = new HashMap<>(){{
put("regression_prediction", 0.55);
}};
Map<String, Object> source = new HashMap<>(){{
put("ml", ml);
}};
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
assertThat(document.getFieldValue("ml", Map.class),
equalTo(new HashMap<>(){{
put("my_processor", Collections.singletonMap("model_id", "regression_model"));
put("regression_prediction", 0.55);
}}));
}

public void testGenerateRequestWithEmptyMapping() {
String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);

InferenceProcessor processor = new InferenceProcessor(client,
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses),
Collections.emptyMap(),
null);
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses),
Collections.emptyMap(),
"ml.my_processor",
false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
false);
false);


Map<String, Object> source = new HashMap<>(){{
put("value1", 1);
Expand Down Expand Up @@ -171,7 +209,8 @@ public void testGenerateWithMapping() {
modelId,
new ClassificationConfig(topNClasses),
fieldMapping,
null);
"ml.my_processor",
false);

Map<String, Object> source = new HashMap<>(3){{
put("value1", 1);
Expand Down