Skip to content

Commit 9076fbc

Browse files
committed
[ML] Enrich documents with inference results at Fetch (#53230)
Adds a FetchSubPhase which adds a new field to the search hits with the result of the model inference performed on the hit. There isn't a direct way of configuring FetchSubPhases so SearchExtSpec is used for the purpose.
1 parent febe7af commit 9076fbc

File tree

20 files changed

+692
-11
lines changed

20 files changed

+692
-11
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,24 @@ public void writeResult(IngestDocument document, String parentResultField) {
122122
}
123123
}
124124

125+
@Override
126+
public Map<String, Object> writeResultToMap(String parentResultField) {
127+
Map<String, Object> parentField = new HashMap<>();
128+
Map<String, Object> results = new HashMap<>();
129+
parentField.put(parentResultField, results);
130+
131+
results.put(resultsField, valueAsString());
132+
if (topClasses.size() > 0) {
133+
results.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
134+
}
135+
if (getFeatureImportance().size() > 0) {
136+
results.put("feature_importance", getFeatureImportance());
137+
}
138+
139+
return parentField;
140+
}
141+
142+
125143
@Override
126144
public String getWriteableName() {
127145
return NAME;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88
import org.elasticsearch.common.io.stream.NamedWriteable;
99
import org.elasticsearch.ingest.IngestDocument;
1010

11+
import java.util.Map;
12+
1113
public interface InferenceResults extends NamedWriteable {
1214

1315
void writeResult(IngestDocument document, String parentResultField);
1416

17+
Map<String, Object> writeResultToMap(String parentResultField);
1518
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ public void writeResult(IngestDocument document, String parentResultField) {
5757
throw new UnsupportedOperationException("[raw] does not support writing inference results");
5858
}
5959

60+
@Override
61+
public Map<String, Object> writeResultToMap(String parentResultField) {
62+
throw new UnsupportedOperationException("[raw] does not support writing inference results");
63+
}
64+
6065
@Override
6166
public String getWriteableName() {
6267
return NAME;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import java.io.IOException;
1616
import java.util.Collections;
17+
import java.util.HashMap;
1718
import java.util.Map;
1819
import java.util.Objects;
1920

@@ -74,6 +75,20 @@ public void writeResult(IngestDocument document, String parentResultField) {
7475
}
7576
}
7677

78+
@Override
79+
public Map<String, Object> writeResultToMap(String parentResultField) {
80+
Map<String, Object> parentResult = new HashMap<>();
81+
Map<String, Object> result = new HashMap<>();
82+
parentResult.put(parentResultField, result);
83+
84+
result.put(resultsField, value());
85+
if (getFeatureImportance().size() > 0) {
86+
result.put("feature_importance", getFeatureImportance());
87+
}
88+
89+
return parentResult;
90+
}
91+
7792
@Override
7893
public String getWriteableName() {
7994
return NAME;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1313

1414
import java.io.IOException;
15+
import java.util.Collections;
16+
import java.util.Map;
1517
import java.util.Objects;
1618

1719
public class WarningInferenceResults implements InferenceResults {
1820

1921
public static final String NAME = "warning";
20-
public static final ParseField WARNING = new ParseField("warning");
22+
public static final ParseField WARNING = new ParseField(NAME);
2123

2224
private final String warning;
2325

@@ -55,7 +57,12 @@ public int hashCode() {
5557
public void writeResult(IngestDocument document, String parentResultField) {
5658
ExceptionsHelper.requireNonNull(document, "document");
5759
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
58-
document.setFieldValue(parentResultField + "." + "warning", warning);
60+
document.setFieldValue(parentResultField + "." + NAME, warning);
61+
}
62+
63+
@Override
64+
public Map<String, Object> writeResultToMap(String parentResultField) {
65+
return Collections.singletonMap(parentResultField, Collections.singletonMap(NAME, warning));
5966
}
6067

6168
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ public String getTopClassesResultsField() {
108108
return topClassesResultsField;
109109
}
110110

111+
@Override
111112
public String getResultsField() {
112113
return resultsField;
113114
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
2222
default boolean requestingImportance() {
2323
return false;
2424
}
25+
26+
String getResultsField();
2527
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ public NullInferenceConfig(boolean requestingFeatureImportance) {
2323
this.requestingFeatureImportance = requestingFeatureImportance;
2424
}
2525

26+
@Override
27+
public String getResultsField() {
28+
return null;
29+
}
30+
2631
@Override
2732
public boolean isTargetTypeSupported(TargetType targetType) {
2833
return true;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ public int getNumTopFeatureImportanceValues() {
8282
return numTopFeatureImportanceValues;
8383
}
8484

85+
@Override
8586
public String getResultsField() {
8687
return resultsField;
8788
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1111
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
1212
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
13+
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1314

1415
import java.util.Arrays;
1516
import java.util.Collections;
@@ -81,6 +82,30 @@ public void testWriteResultsWithTopClasses() {
8182
assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo"));
8283
}
8384

85+
@SuppressWarnings("unchecked")
86+
public void testWriteResultsToMapWithTopClasses() {
87+
List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
88+
new ClassificationInferenceResults.TopClassEntry("foo", 0.7),
89+
new ClassificationInferenceResults.TopClassEntry("bar", 0.2),
90+
new ClassificationInferenceResults.TopClassEntry("baz", 0.1));
91+
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
92+
"foo",
93+
entries,
94+
new ClassificationConfig(3, "my_results", "bar"));
95+
Map<String, Object> resultsDoc = result.writeResultToMap("result_field");
96+
97+
List<?> list = (List<?>) MapHelper.dig("result_field.bar", resultsDoc);
98+
assertThat(list.size(), equalTo(3));
99+
100+
for(int i = 0; i < 3; i++) {
101+
Map<String, Object> map = (Map<String, Object>)list.get(i);
102+
assertThat(map, equalTo(entries.get(i).asValueMap()));
103+
}
104+
105+
Object value = MapHelper.dig("result_field.my_results", resultsDoc);
106+
assertThat(value, equalTo("foo"));
107+
}
108+
84109
@Override
85110
protected ClassificationInferenceResults createTestInstance() {
86111
return createRandomResults();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
import org.elasticsearch.common.io.stream.Writeable;
99
import org.elasticsearch.ingest.IngestDocument;
1010
import org.elasticsearch.test.AbstractWireSerializingTestCase;
11-
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
1211
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
1312
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
13+
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1414

1515
import java.util.HashMap;
16+
import java.util.Map;
1617

1718
import static org.hamcrest.Matchers.equalTo;
1819

@@ -31,6 +32,14 @@ public void testWriteResults() {
3132
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
3233
}
3334

35+
public void testWriteResultsToMap() {
36+
RegressionInferenceResults result = new RegressionInferenceResults(0.3, RegressionConfig.EMPTY_PARAMS);
37+
Map<String, Object> doc = result.writeResultToMap("result_field");
38+
39+
Object value = MapHelper.dig("result_field.predicted_value", doc);
40+
assertThat(value, equalTo(0.3));
41+
}
42+
3443
@Override
3544
protected RegressionInferenceResults createTestInstance() {
3645
return createRandomResults();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import org.elasticsearch.common.io.stream.Writeable;
99
import org.elasticsearch.ingest.IngestDocument;
1010
import org.elasticsearch.test.AbstractWireSerializingTestCase;
11+
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1112

1213
import java.util.HashMap;
14+
import java.util.Map;
1315

1416
import static org.hamcrest.Matchers.equalTo;
1517

@@ -27,6 +29,14 @@ public void testWriteResults() {
2729
assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
2830
}
2931

32+
public void testWriteResultToMap() {
33+
WarningInferenceResults result = new WarningInferenceResults("foo");
34+
Map<String, Object> doc = result.writeResultToMap("result_field");
35+
36+
Object field = MapHelper.dig("result_field.warning", doc);
37+
assertThat(field, equalTo("foo"));
38+
}
39+
3040
@Override
3141
protected WarningInferenceResults createTestInstance() {
3242
return createRandomResults();

x-pack/plugin/ml/qa/ml-with-security/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ integTest.runner {
123123
'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
124124
'ml/delete_model_snapshot/Test delete snapshot missing job_id',
125125
'ml/delete_model_snapshot/Test delete with in-use model',
126+
'ml/fetch_inference/Test fetch regression',
127+
'ml/fetch_inference/Test fetch classification',
126128
'ml/filter_crud/Test create filter api with mismatching body ID',
127129
'ml/filter_crud/Test create filter given invalid filter_id',
128130
'ml/filter_crud/Test get filter API with bad ID',

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@
4747
import org.elasticsearch.plugins.IngestPlugin;
4848
import org.elasticsearch.plugins.PersistentTaskPlugin;
4949
import org.elasticsearch.plugins.Plugin;
50+
import org.elasticsearch.plugins.SearchPlugin;
5051
import org.elasticsearch.plugins.SystemIndexPlugin;
5152
import org.elasticsearch.rest.RestController;
5253
import org.elasticsearch.rest.RestHandler;
5354
import org.elasticsearch.script.ScriptService;
55+
import org.elasticsearch.search.fetch.FetchSubPhase;
5456
import org.elasticsearch.threadpool.ExecutorBuilder;
5557
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
5658
import org.elasticsearch.threadpool.ThreadPool;
@@ -213,6 +215,8 @@
213215
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
214216
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
215217
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
218+
import org.elasticsearch.xpack.ml.inference.search.InferencePhase;
219+
import org.elasticsearch.xpack.ml.inference.search.InferenceSearchExtBuilder;
216220
import org.elasticsearch.xpack.ml.job.JobManager;
217221
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
218222
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
@@ -321,7 +325,7 @@
321325

322326
import static java.util.Collections.emptyList;
323327

324-
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
328+
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin {
325329
public static final String NAME = "ml";
326330
public static final String BASE_PATH = "/_ml/";
327331
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
@@ -417,6 +421,7 @@ public Set<DiscoveryNodeRole> getRoles() {
417421
private final SetOnce<DataFrameAnalyticsManager> dataFrameAnalyticsManager = new SetOnce<>();
418422
private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
419423
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
424+
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
420425

421426
public MachineLearning(Settings settings, Path configPath) {
422427
this.settings = settings;
@@ -631,6 +636,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
631636
clusterService,
632637
xContentRegistry,
633638
settings);
639+
this.modelLoadingService.set(modelLoadingService);
634640

635641
// Data frame analytics components
636642
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
@@ -891,6 +897,18 @@ public Map<String, AnalysisProvider<TokenizerFactory>> getTokenizers() {
891897
return Collections.singletonMap(MlClassicTokenizer.NAME, MlClassicTokenizerFactory::new);
892898
}
893899

900+
@Override
901+
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
902+
return Collections.singletonList(new InferencePhase(modelLoadingService));
903+
}
904+
905+
@Override
906+
public List<SearchExtSpec<?>> getSearchExts() {
907+
return Collections.singletonList(
908+
new SearchExtSpec<>(InferenceSearchExtBuilder.NAME, InferenceSearchExtBuilder::new,
909+
InferenceSearchExtBuilder::fromXContent));
910+
}
911+
894912
@Override
895913
public UnaryOperator<Map<String, IndexTemplateMetaData>> getIndexTemplateMetaDataUpgrader() {
896914
return UnaryOperator.identity();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import org.elasticsearch.action.ActionListener;
99
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
1010
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
11+
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
12+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
13+
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
1114
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1215
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
1316
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
1417
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
15-
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
16-
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
17-
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
1818
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1919

2020
import java.util.HashMap;
@@ -50,6 +50,11 @@ public String getModelId() {
5050
return modelId;
5151
}
5252

53+
@Override
54+
public Set<String> getFieldNames() {
55+
return fieldNames;
56+
}
57+
5358
@Override
5459
public String getResultsType() {
5560
switch (trainedModelDefinition.getTrainedModel().targetType()) {
@@ -59,24 +64,32 @@ public String getResultsType() {
5964
return RegressionInferenceResults.NAME;
6065
default:
6166
throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]",
62-
modelId,
63-
trainedModelDefinition.getTrainedModel().targetType());
67+
modelId,
68+
trainedModelDefinition.getTrainedModel().targetType());
6469
}
6570
}
6671

6772
@Override
68-
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
73+
public void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener) {
6974
try {
7075
Model.mapFieldsIfNecessary(fields, defaultFieldMap);
7176
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
7277
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
7378
return;
7479
}
7580

76-
listener.onResponse(trainedModelDefinition.infer(fields, config));
81+
listener.onResponse(trainedModelDefinition.infer(fields, inferenceConfig));
7782
} catch (Exception e) {
7883
listener.onFailure(e);
7984
}
8085
}
8186

87+
@Override
88+
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
89+
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
90+
return new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId));
91+
}
92+
93+
return trainedModelDefinition.infer(fields, config);
94+
}
8295
}

0 commit comments

Comments
 (0)