[7.x] [ML][Inference] Adding a warning_field for warning msgs. (#49838) #50183

Merged 1 commit on Dec 13, 2019
WarningInferenceResults.java (new file)
@@ -0,0 +1,66 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class WarningInferenceResults implements InferenceResults {

    public static final String NAME = "warning";
    public static final ParseField WARNING = new ParseField("warning");

    private final String warning;

    public WarningInferenceResults(String warning) {
        this.warning = warning;
    }

    public WarningInferenceResults(StreamInput in) throws IOException {
        this.warning = in.readString();
    }

    public String getWarning() {
        return warning;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(warning);
    }

    @Override
    public boolean equals(Object object) {
        if (object == this) { return true; }
        if (object == null || getClass() != object.getClass()) { return false; }
        WarningInferenceResults that = (WarningInferenceResults) object;
        return Objects.equals(warning, that.warning);
    }

    @Override
    public int hashCode() {
        return Objects.hash(warning);
    }

    @Override
    public void writeResult(IngestDocument document, String parentResultField) {
        ExceptionsHelper.requireNonNull(document, "document");
        ExceptionsHelper.requireNonNull(parentResultField, "resultField");
        document.setFieldValue(parentResultField + "." + "warning", warning);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

}
Messages.java
@@ -92,6 +92,7 @@ public final class Messages {
    public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
    public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
        "Getting model definition is not supported when getting more than one model";
    public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";

    public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
    public static final String JOB_AUDIT_CREATED = "Job created";
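
Editor's note, not part of the diff: like the surrounding message constants, the new INFERENCE_WARNING_ALL_FIELDS_MISSING string uses a {0} placeholder that Messages.getMessage fills in with the model id (the helper is the same one the diff calls in LocalModel below). A minimal sketch of the substitution; the "my_model" id is just an illustrative value:

    String warning = Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "my_model");
    // warning == "Model [my_model] could not be inferred as all fields were missing"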
WarningInferenceResultsTests.java (new file)
@@ -0,0 +1,39 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.HashMap;

import static org.hamcrest.Matchers.equalTo;

public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase<WarningInferenceResults> {

    public static WarningInferenceResults createRandomResults() {
        return new WarningInferenceResults(randomAlphaOfLength(10));
    }

    public void testWriteResults() {
        WarningInferenceResults result = new WarningInferenceResults("foo");
        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
        result.writeResult(document, "result_field");

        assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
    }

    @Override
    protected WarningInferenceResults createTestInstance() {
        return createRandomResults();
    }

    @Override
    protected Writeable.Reader<WarningInferenceResults> instanceReader() {
        return WarningInferenceResults::new;
    }
}
AnalyticsResultProcessor.java
@@ -16,6 +16,8 @@
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
@@ -34,6 +36,8 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static java.util.stream.Collectors.toList;

public class AnalyticsResultProcessor {

    private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);
@@ -163,6 +167,10 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build
        Instant createTime = Instant.now();
        String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
        TrainedModelDefinition definition = inferenceModel.build();
        String dependentVariable = getDependentVariable();
        List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
            .filter(f -> f.equals(dependentVariable) == false)
            .collect(toList());
        return TrainedModelConfig.builder()
            .setModelId(modelId)
            .setCreatedBy("data-frame-analytics")
@@ -175,11 +183,21 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build
            .setEstimatedHeapMemory(definition.ramBytesUsed())
            .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
            .setParsedDefinition(inferenceModel)
-           .setInput(new TrainedModelInput(fieldNames))
+           .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .build();
    }

    private String getDependentVariable() {
        if (analytics.getAnalysis() instanceof Classification) {
            return ((Classification)analytics.getAnalysis()).getDependentVariable();
        }
        if (analytics.getAnalysis() instanceof Regression) {
            return ((Regression)analytics.getAnalysis()).getDependentVariable();
        }
        return null;
    }

    private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
        CountDownLatch latch = new CountDownLatch(1);
        ActionListener<Boolean> storeListener = ActionListener.wrap(
InferenceProcessor.java
@@ -28,6 +28,8 @@
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@@ -37,7 +39,9 @@

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@@ -146,7 +150,12 @@ void mutateDocument(InternalInferModelAction.Response response, IngestDocument i
        if (response.getInferenceResults().isEmpty()) {
            throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
        }
-       response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
+       InferenceResults inferenceResults = response.getInferenceResults().get(0);
+       if (inferenceResults instanceof WarningInferenceResults) {
+           inferenceResults.writeResult(ingestDocument, this.targetField);
+       } else {
+           response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
+       }
        ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId);
    }

@@ -164,6 +173,10 @@ public static final class Factory implements Processor.Factory, Consumer<Cluster

        private static final Logger logger = LogManager.getLogger(Factory.class);

        private static final Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
            WarningInferenceResults.WARNING.getPreferredName(),
            MODEL_ID));

        private final Client client;
        private final IngestService ingestService;
        private final InferenceAuditor auditor;
@@ -235,6 +248,7 @@ public InferenceProcessor create(Map<String, Processor.Factory> processorFactori
            String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField);
            Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
            InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));

            return new InferenceProcessor(client,
                auditor,
                tag,
@@ -252,7 +266,6 @@ void setMaxIngestProcessors(int maxIngestProcessors) {

        InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
            ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);

            if (inferenceConfig.size() != 1) {
                throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
                    INFERENCE_CONFIG);
@@ -268,17 +281,38 @@ InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {

            if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
                checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
-               return ClassificationConfig.fromMap(valueMap);
+               ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
+               checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
+               return config;
            } else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
                checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
-               return RegressionConfig.fromMap(valueMap);
+               RegressionConfig config = RegressionConfig.fromMap(valueMap);
+               checkFieldUniqueness(config.getResultsField());
+               return config;
            } else {
                throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
                    inferenceConfig.keySet(),
                    Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
            }
        }

        private static void checkFieldUniqueness(String... fieldNames) {
            Set<String> duplicatedFieldNames = new HashSet<>();
            Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
            for(String fieldName : fieldNames) {
                if (currentFieldNames.contains(fieldName)) {
                    duplicatedFieldNames.add(fieldName);
                } else {
                    currentFieldNames.add(fieldName);
                }
            }
            if (duplicatedFieldNames.isEmpty() == false) {
                throw ExceptionsHelper.badRequestException("Cannot create processor as configured." +
                    " More than one field is configured as {}",
                    duplicatedFieldNames);
            }
        }

        void checkSupportedVersion(InferenceConfig config) {
            if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
                throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
@@ -287,6 +321,5 @@ void checkSupportedVersion(InferenceConfig config) {
                    minNodeVersion));
            }
        }

    }
}
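
Editor's sketch, not part of the diff: a self-contained rendering of the rule that checkFieldUniqueness enforces above. The class and method names below are invented for illustration; the reserved names "warning" and "model_id" correspond to WarningInferenceResults.WARNING and the processor's model_id field, which the processor writes under the target field alongside any configured results fields.

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical standalone sketch of the duplicate/reserved field check.
public class FieldUniquenessSketch {

    private static final Set<String> RESERVED = new HashSet<>(Arrays.asList("warning", "model_id"));

    // Returns the configured field names that collide with a reserved name or with each other.
    static Set<String> duplicates(String... configuredFields) {
        Set<String> seen = new HashSet<>(RESERVED);
        Set<String> duplicated = new HashSet<>();
        for (String field : configuredFields) {
            if (seen.add(field) == false) {
                duplicated.add(field);
            }
        }
        return duplicated;
    }

    public static void main(String[] args) {
        System.out.println(duplicates("my_prediction", "my_top_classes")); // [] -> processor is created
        System.out.println(duplicates("warning"));                         // [warning] -> creation is rejected
    }
}

The real check throws ExceptionsHelper.badRequestException instead of returning a set, so a colliding configuration fails at pipeline creation time rather than at ingest time.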
LocalModel.java
@@ -6,23 +6,33 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;

public class LocalModel implements Model {

    private final TrainedModelDefinition trainedModelDefinition;
    private final String modelId;
    private final Set<String> fieldNames;

-   public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) {
+   public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input) {
        this.trainedModelDefinition = trainedModelDefinition;
        this.modelId = modelId;
        this.fieldNames = new HashSet<>(input.getFieldNames());
    }

    long ramBytesUsed() {
@@ -51,6 +61,11 @@ public String getResultsType() {
    @Override
    public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
        try {
            if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) {
                listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
                return;
            }

            listener.onResponse(trainedModelDefinition.infer(fields, config));
        } catch (Exception e) {
            listener.onFailure(e);
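
Editor's sketch, not part of the diff: how a caller observes the new warning path in LocalModel.infer. WarningPathSketch, inferWithMissingFields, and the field name below are invented for illustration; the sketch assumes a LocalModel whose TrainedModelInput declares fields that are all absent from the supplied document, so the Sets.haveEmptyIntersection check above fires.

import java.util.Collections;
import java.util.Map;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;

class WarningPathSketch {

    // `model` is assumed to be loaded elsewhere (for example via the model loading service).
    static void inferWithMissingFields(LocalModel model) {
        Map<String, Object> docFields = Collections.singletonMap("some_unrelated_field", 42);
        model.infer(docFields, RegressionConfig.EMPTY_PARAMS, ActionListener.wrap(
            (InferenceResults result) -> {
                // The warning arrives through onResponse rather than onFailure, so the ingest
                // pipeline keeps running and the processor writes "<target_field>.warning".
                if (result instanceof WarningInferenceResults) {
                    System.out.println(((WarningInferenceResults) result).getWarning());
                }
            },
            e -> { throw new AssertionError("unexpected failure", e); }));
    }
}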
ModelLoadingService.java
@@ -141,7 +141,8 @@ public void getModel(String modelId, ActionListener<Model> modelActionListener)
                trainedModelConfig ->
                    modelActionListener.onResponse(new LocalModel(
                        trainedModelConfig.getModelId(),
-                       trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition())),
+                       trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
+                       trainedModelConfig.getInput())),
                modelActionListener::onFailure
            ));
        } else {
@@ -198,7 +199,8 @@ private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelCo
        Queue<ActionListener<Model>> listeners;
        LocalModel loadedModel = new LocalModel(
            trainedModelConfig.getModelId(),
-           trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition());
+           trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
+           trainedModelConfig.getInput());
        synchronized (loadingListeners) {
            listeners = loadingListeners.remove(modelId);
            // If there is no loadingListener that means the loading was canceled and the listener was already notified as such
AnalyticsResultProcessorTests.java
@@ -171,7 +171,7 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
        assertThat(storedModel.getTags(), contains(JOB_ID));
        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
        assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
-       assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
+       assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar", "baz")));
        assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
        assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
        Map<String, Object> metadata = storedModel.getMetadata();
InferenceProcessorFactoryTests.java
@@ -240,6 +240,29 @@ public void testCreateProcessor() {
        }
    }

    public void testCreateProcessorWithDuplicateFields() {
        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
            clusterService,
            Settings.EMPTY,
            ingestService);

        Map<String, Object> regression = new HashMap<String, Object>() {{
            put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
            put(InferenceProcessor.MODEL_ID, "my_model");
            put(InferenceProcessor.TARGET_FIELD, "ml");
            put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME,
                Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
        }};

        try {
            processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
            fail("should not have succeeded creating with duplicate fields");
        } catch (Exception ex) {
            assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " +
                "More than one field is configured as [warning]"));
        }
    }

    private static ClusterState buildClusterState(MetaData metaData) {
        return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build();
    }