Skip to content

Commit 4805d8a

Browse files
authored
[ML][Inference] Adding a warning_field for warning msgs. (#49838) (#50183)
This adds a new field for the inference processor. `warning_field` is where warnings produced by the inference call are written. When there are warnings, no inference result is written. The goal is to indicate that the data provided was too poor or too different for the model to make an accurate prediction. The user can optionally include the `warning_field`; when it is not provided, it is assumed that no warnings should be written. The first of these warnings is raised when ALL of the input fields are missing: if none of the trained fields are present, inference is not run against the model and a warning stating that the fields were missing is provided instead. This also adds checks that disallow duplicated fields during processor creation.
1 parent 41736dd commit 4805d8a

File tree

13 files changed

+316
-17
lines changed

13 files changed

+316
-17
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.inference.results;
7+
8+
import org.elasticsearch.common.ParseField;
9+
import org.elasticsearch.common.io.stream.StreamInput;
10+
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.ingest.IngestDocument;
12+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
13+
14+
import java.io.IOException;
15+
import java.util.Objects;
16+
17+
public class WarningInferenceResults implements InferenceResults {
18+
19+
public static final String NAME = "warning";
20+
public static final ParseField WARNING = new ParseField("warning");
21+
22+
private final String warning;
23+
24+
public WarningInferenceResults(String warning) {
25+
this.warning = warning;
26+
}
27+
28+
public WarningInferenceResults(StreamInput in) throws IOException {
29+
this.warning = in.readString();
30+
}
31+
32+
public String getWarning() {
33+
return warning;
34+
}
35+
36+
@Override
37+
public void writeTo(StreamOutput out) throws IOException {
38+
out.writeString(warning);
39+
}
40+
41+
@Override
42+
public boolean equals(Object object) {
43+
if (object == this) { return true; }
44+
if (object == null || getClass() != object.getClass()) { return false; }
45+
WarningInferenceResults that = (WarningInferenceResults) object;
46+
return Objects.equals(warning, that.warning);
47+
}
48+
49+
@Override
50+
public int hashCode() {
51+
return Objects.hash(warning);
52+
}
53+
54+
@Override
55+
public void writeResult(IngestDocument document, String parentResultField) {
56+
ExceptionsHelper.requireNonNull(document, "document");
57+
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
58+
document.setFieldValue(parentResultField + "." + "warning", warning);
59+
}
60+
61+
@Override
62+
public String getWriteableName() {
63+
return NAME;
64+
}
65+
66+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public final class Messages {
9292
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
9393
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
9494
"Getting model definition is not supported when getting more than one model";
95+
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
9596

9697
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
9798
public static final String JOB_AUDIT_CREATED = "Job created";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.inference.results;
7+
8+
import org.elasticsearch.common.io.stream.Writeable;
9+
import org.elasticsearch.ingest.IngestDocument;
10+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
11+
12+
import java.util.HashMap;
13+
14+
import static org.hamcrest.Matchers.equalTo;
15+
16+
public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase<WarningInferenceResults> {
17+
18+
public static WarningInferenceResults createRandomResults() {
19+
return new WarningInferenceResults(randomAlphaOfLength(10));
20+
}
21+
22+
public void testWriteResults() {
23+
WarningInferenceResults result = new WarningInferenceResults("foo");
24+
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
25+
result.writeResult(document, "result_field");
26+
27+
assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
28+
}
29+
30+
@Override
31+
protected WarningInferenceResults createTestInstance() {
32+
return createRandomResults();
33+
}
34+
35+
@Override
36+
protected Writeable.Reader<WarningInferenceResults> instanceReader() {
37+
return WarningInferenceResults::new;
38+
}
39+
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

+19-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import org.elasticsearch.common.xcontent.json.JsonXContent;
1717
import org.elasticsearch.license.License;
1818
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
19+
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
20+
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
1921
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
2022
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
2123
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
@@ -34,6 +36,8 @@
3436
import java.util.concurrent.CountDownLatch;
3537
import java.util.concurrent.TimeUnit;
3638

39+
import static java.util.stream.Collectors.toList;
40+
3741
public class AnalyticsResultProcessor {
3842

3943
private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);
@@ -163,6 +167,10 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build
163167
Instant createTime = Instant.now();
164168
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
165169
TrainedModelDefinition definition = inferenceModel.build();
170+
String dependentVariable = getDependentVariable();
171+
List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
172+
.filter(f -> f.equals(dependentVariable) == false)
173+
.collect(toList());
166174
return TrainedModelConfig.builder()
167175
.setModelId(modelId)
168176
.setCreatedBy("data-frame-analytics")
@@ -175,11 +183,21 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build
175183
.setEstimatedHeapMemory(definition.ramBytesUsed())
176184
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
177185
.setParsedDefinition(inferenceModel)
178-
.setInput(new TrainedModelInput(fieldNames))
186+
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
179187
.setLicenseLevel(License.OperationMode.PLATINUM.description())
180188
.build();
181189
}
182190

191+
private String getDependentVariable() {
192+
if (analytics.getAnalysis() instanceof Classification) {
193+
return ((Classification)analytics.getAnalysis()).getDependentVariable();
194+
}
195+
if (analytics.getAnalysis() instanceof Regression) {
196+
return ((Regression)analytics.getAnalysis()).getDependentVariable();
197+
}
198+
return null;
199+
}
200+
183201
private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
184202
CountDownLatch latch = new CountDownLatch(1);
185203
ActionListener<Boolean> storeListener = ActionListener.wrap(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

+38-5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import org.elasticsearch.ingest.Processor;
2929
import org.elasticsearch.rest.RestStatus;
3030
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
31+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
32+
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
3133
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
3234
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
3335
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@@ -37,7 +39,9 @@
3739

3840
import java.util.Arrays;
3941
import java.util.HashMap;
42+
import java.util.HashSet;
4043
import java.util.Map;
44+
import java.util.Set;
4145
import java.util.concurrent.atomic.AtomicBoolean;
4246
import java.util.function.BiConsumer;
4347
import java.util.function.Consumer;
@@ -146,7 +150,12 @@ void mutateDocument(InternalInferModelAction.Response response, IngestDocument i
146150
if (response.getInferenceResults().isEmpty()) {
147151
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
148152
}
149-
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
153+
InferenceResults inferenceResults = response.getInferenceResults().get(0);
154+
if (inferenceResults instanceof WarningInferenceResults) {
155+
inferenceResults.writeResult(ingestDocument, this.targetField);
156+
} else {
157+
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
158+
}
150159
ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId);
151160
}
152161

@@ -164,6 +173,10 @@ public static final class Factory implements Processor.Factory, Consumer<Cluster
164173

165174
private static final Logger logger = LogManager.getLogger(Factory.class);
166175

176+
private static final Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
177+
WarningInferenceResults.WARNING.getPreferredName(),
178+
MODEL_ID));
179+
167180
private final Client client;
168181
private final IngestService ingestService;
169182
private final InferenceAuditor auditor;
@@ -235,6 +248,7 @@ public InferenceProcessor create(Map<String, Processor.Factory> processorFactori
235248
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField);
236249
Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
237250
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
251+
238252
return new InferenceProcessor(client,
239253
auditor,
240254
tag,
@@ -252,7 +266,6 @@ void setMaxIngestProcessors(int maxIngestProcessors) {
252266

253267
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
254268
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
255-
256269
if (inferenceConfig.size() != 1) {
257270
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
258271
INFERENCE_CONFIG);
@@ -268,17 +281,38 @@ InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
268281

269282
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
270283
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
271-
return ClassificationConfig.fromMap(valueMap);
284+
ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
285+
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
286+
return config;
272287
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
273288
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
274-
return RegressionConfig.fromMap(valueMap);
289+
RegressionConfig config = RegressionConfig.fromMap(valueMap);
290+
checkFieldUniqueness(config.getResultsField());
291+
return config;
275292
} else {
276293
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
277294
inferenceConfig.keySet(),
278295
Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
279296
}
280297
}
281298

299+
private static void checkFieldUniqueness(String... fieldNames) {
300+
Set<String> duplicatedFieldNames = new HashSet<>();
301+
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
302+
for(String fieldName : fieldNames) {
303+
if (currentFieldNames.contains(fieldName)) {
304+
duplicatedFieldNames.add(fieldName);
305+
} else {
306+
currentFieldNames.add(fieldName);
307+
}
308+
}
309+
if (duplicatedFieldNames.isEmpty() == false) {
310+
throw ExceptionsHelper.badRequestException("Cannot create processor as configured." +
311+
" More than one field is configured as {}",
312+
duplicatedFieldNames);
313+
}
314+
}
315+
282316
void checkSupportedVersion(InferenceConfig config) {
283317
if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
284318
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
@@ -287,6 +321,5 @@ void checkSupportedVersion(InferenceConfig config) {
287321
minNodeVersion));
288322
}
289323
}
290-
291324
}
292325
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,33 @@
66
package org.elasticsearch.xpack.ml.inference.loadingservice;
77

88
import org.elasticsearch.action.ActionListener;
9+
import org.elasticsearch.common.util.set.Sets;
910
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
11+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
12+
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1013
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
14+
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
1115
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1216
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
1317
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
1418
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
1519

20+
import java.util.HashSet;
1621
import java.util.Map;
22+
import java.util.Set;
23+
24+
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
1725

1826
public class LocalModel implements Model {
1927

2028
private final TrainedModelDefinition trainedModelDefinition;
2129
private final String modelId;
30+
private final Set<String> fieldNames;
2231

23-
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) {
32+
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input) {
2433
this.trainedModelDefinition = trainedModelDefinition;
2534
this.modelId = modelId;
35+
this.fieldNames = new HashSet<>(input.getFieldNames());
2636
}
2737

2838
long ramBytesUsed() {
@@ -51,6 +61,11 @@ public String getResultsType() {
5161
@Override
5262
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
5363
try {
64+
if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) {
65+
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
66+
return;
67+
}
68+
5469
listener.onResponse(trainedModelDefinition.infer(fields, config));
5570
} catch (Exception e) {
5671
listener.onFailure(e);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ public void getModel(String modelId, ActionListener<Model> modelActionListener)
141141
trainedModelConfig ->
142142
modelActionListener.onResponse(new LocalModel(
143143
trainedModelConfig.getModelId(),
144-
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition())),
144+
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
145+
trainedModelConfig.getInput())),
145146
modelActionListener::onFailure
146147
));
147148
} else {
@@ -198,7 +199,8 @@ private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelCo
198199
Queue<ActionListener<Model>> listeners;
199200
LocalModel loadedModel = new LocalModel(
200201
trainedModelConfig.getModelId(),
201-
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition());
202+
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
203+
trainedModelConfig.getInput());
202204
synchronized (loadingListeners) {
203205
listeners = loadingListeners.remove(modelId);
204206
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
171171
assertThat(storedModel.getTags(), contains(JOB_ID));
172172
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
173173
assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
174-
assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
174+
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar", "baz")));
175175
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
176176
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
177177
Map<String, Object> metadata = storedModel.getMetadata();

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java

+23
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,29 @@ public void testCreateProcessor() {
240240
}
241241
}
242242

243+
public void testCreateProcessorWithDuplicateFields() {
244+
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
245+
clusterService,
246+
Settings.EMPTY,
247+
ingestService);
248+
249+
Map<String, Object> regression = new HashMap<String, Object>() {{
250+
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
251+
put(InferenceProcessor.MODEL_ID, "my_model");
252+
put(InferenceProcessor.TARGET_FIELD, "ml");
253+
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME,
254+
Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
255+
}};
256+
257+
try {
258+
processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
259+
fail("should not have succeeded creating with duplicate fields");
260+
} catch (Exception ex) {
261+
assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " +
262+
"More than one field is configured as [warning]"));
263+
}
264+
}
265+
243266
private static ClusterState buildClusterState(MetaData metaData) {
244267
return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build();
245268
}

0 commit comments

Comments
 (0)