Skip to content

[7.x] [ML] inference only persist if there are stats (#54752) #55121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
public static final ParseField NODE_ID = new ParseField("node_id");
public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
public static final ParseField TYPE = new ParseField("type");
public static final ParseField TIMESTAMP = new ParseField("time_stamp");
public static final ParseField TIMESTAMP = new ParseField("timestamp");

public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
NAME,
Expand Down Expand Up @@ -127,6 +127,10 @@ public Instant getTimeStamp() {
return timeStamp;
}

/**
 * Reports whether any of the counters held by this object is above zero.
 *
 * @return {@code true} if the missing-all-fields count, the inference count, or the
 *         failure count is greater than zero; {@code false} otherwise
 */
public boolean hasStats() {
    if (missingAllFieldsCount > 0) {
        return true;
    }
    if (inferenceCount > 0) {
        return true;
    }
    return failureCount > 0;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down Expand Up @@ -221,16 +225,19 @@ public Accumulator merge(InferenceStats otherStats) {
return this;
}

public void incMissingFields() {
public Accumulator incMissingFields() {
this.missingFieldsAccumulator.increment();
return this;
}

public void incInference() {
public Accumulator incInference() {
this.inferenceAccumulator.increment();
return this;
}

public void incFailure() {
public Accumulator incFailure() {
this.failureCountAccumulator.increment();
return this;
}

public InferenceStats currentStats() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ public void testPathologicalPipelineCreationAndDeletion() throws Exception {
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786")
public void testPipelineIngest() throws Exception {

client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
Expand All @@ -47,12 +48,17 @@ public class TrainedModelStatsService {
private static final Logger logger = LogManager.getLogger(TrainedModelStatsService.class);
private static final TimeValue PERSISTENCE_INTERVAL = TimeValue.timeValueSeconds(1);

private static final String STATS_UPDATE_SCRIPT_TEMPLATE = "" +
" ctx._source.{0} += params.{0};\n" +
" ctx._source.{1} += params.{1};\n" +
" ctx._source.{2} += params.{2};\n" +
" ctx._source.{3} = params.{3};";
// Script to only update if stats have increased since last persistence
private static final String STATS_UPDATE_SCRIPT = "" +
" ctx._source.missing_all_fields_count += params.missing_all_fields_count;\n" +
" ctx._source.inference_count += params.inference_count;\n" +
" ctx._source.failure_count += params.failure_count;\n" +
" ctx._source.time_stamp = params.time_stamp;";
private static final String STATS_UPDATE_SCRIPT = Messages.getMessage(STATS_UPDATE_SCRIPT_TEMPLATE,
InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(),
InferenceStats.INFERENCE_COUNT.getPreferredName(),
InferenceStats.FAILURE_COUNT.getPreferredName(),
InferenceStats.TIMESTAMP.getPreferredName());
private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));

Expand Down Expand Up @@ -134,7 +140,7 @@ void updateStats() {
List<InferenceStats> stats = new ArrayList<>(statsQueue.size());
for(String k : statsQueue.keySet()) {
InferenceStats inferenceStats = statsQueue.remove(k);
if (inferenceStats != null) {
if (inferenceStats != null && inferenceStats.hasStats()) {
stats.add(inferenceStats);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, A
return;
}
try {
statsAccumulator.get().incInference();
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incInference);
currentInferenceCount.increment();

Model.mapFieldsIfNecessary(fields, defaultFieldMap);

boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
statsAccumulator.get().incMissingFields();
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incMissingFields);
if (shouldPersistStats) {
persistStats();
}
Expand All @@ -130,7 +130,7 @@ public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, A
}
listener.onResponse(inferenceResults);
} catch (Exception e) {
statsAccumulator.get().incFailure();
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incFailure);
listener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
Expand Down Expand Up @@ -551,7 +552,9 @@ private InferenceStats handleMultiNodeStatsResponse(SearchResponse response, Str
failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(),
modelId,
null,
timeStamp == null ? Instant.now() : Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue())
timeStamp == null || (Numbers.isValidDouble(timeStamp.getValue()) == false) ?
Instant.now() :
Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentMatcher;
import org.mockito.Mockito;

import java.io.IOException;
Expand All @@ -62,7 +61,6 @@
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -140,12 +138,6 @@ public void testGetCachedModels() throws Exception {
assertThat(future.get(), is(not(nullValue())));
}

verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
// It is not referenced, so called eagerly
Expand Down Expand Up @@ -192,24 +184,6 @@ public void testMaxCachedLimitReached() throws Exception {
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
// Only requested once, on the initial load from the change event
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model1);
}
}));
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model2);
}
}));
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));

// Load model 3, should invalidate 1
for(int i = 0; i < 10; i++) {
Expand Down