Skip to content

Commit d32f6fe

Browse files
authored
[ML] inference only persist if there are stats (#54752) (#55121)
Previously, stats documents were needlessly sent to be persisted even when no stats had been recorded. If no stats have been added, we no longer attempt to persist them. This PR also fixes the race condition that caused issue #54786.
1 parent fcd96db commit d32f6fe

File tree

6 files changed

+30
-41
lines changed

6 files changed

+30
-41
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java

+11-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public class InferenceStats implements ToXContentObject, Writeable {
3131
public static final ParseField NODE_ID = new ParseField("node_id");
3232
public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
3333
public static final ParseField TYPE = new ParseField("type");
34-
public static final ParseField TIMESTAMP = new ParseField("time_stamp");
34+
public static final ParseField TIMESTAMP = new ParseField("timestamp");
3535

3636
public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
3737
NAME,
@@ -127,6 +127,10 @@ public Instant getTimeStamp() {
127127
return timeStamp;
128128
}
129129

130+
public boolean hasStats() {
131+
return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0;
132+
}
133+
130134
@Override
131135
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
132136
builder.startObject();
@@ -221,16 +225,19 @@ public Accumulator merge(InferenceStats otherStats) {
221225
return this;
222226
}
223227

224-
public void incMissingFields() {
228+
public Accumulator incMissingFields() {
225229
this.missingFieldsAccumulator.increment();
230+
return this;
226231
}
227232

228-
public void incInference() {
233+
public Accumulator incInference() {
229234
this.inferenceAccumulator.increment();
235+
return this;
230236
}
231237

232-
public void incFailure() {
238+
public Accumulator incFailure() {
233239
this.failureCountAccumulator.increment();
240+
return this;
234241
}
235242

236243
public InferenceStats currentStats() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ public void testPathologicalPipelineCreationAndDeletion() throws Exception {
9696
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
9797
}
9898

99-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786")
10099
public void testPipelineIngest() throws Exception {
101100

102101
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java

+12-6
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.threadpool.ThreadPool;
2828
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
2929
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
30+
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
3031
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
3132
import org.elasticsearch.xpack.ml.MachineLearning;
3233
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
@@ -47,12 +48,17 @@ public class TrainedModelStatsService {
4748
private static final Logger logger = LogManager.getLogger(TrainedModelStatsService.class);
4849
private static final TimeValue PERSISTENCE_INTERVAL = TimeValue.timeValueSeconds(1);
4950

51+
private static final String STATS_UPDATE_SCRIPT_TEMPLATE = "" +
52+
" ctx._source.{0} += params.{0};\n" +
53+
" ctx._source.{1} += params.{1};\n" +
54+
" ctx._source.{2} += params.{2};\n" +
55+
" ctx._source.{3} = params.{3};";
5056
// Script to only update if stats have increased since last persistence
51-
private static final String STATS_UPDATE_SCRIPT = "" +
52-
" ctx._source.missing_all_fields_count += params.missing_all_fields_count;\n" +
53-
" ctx._source.inference_count += params.inference_count;\n" +
54-
" ctx._source.failure_count += params.failure_count;\n" +
55-
" ctx._source.time_stamp = params.time_stamp;";
57+
private static final String STATS_UPDATE_SCRIPT = Messages.getMessage(STATS_UPDATE_SCRIPT_TEMPLATE,
58+
InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(),
59+
InferenceStats.INFERENCE_COUNT.getPreferredName(),
60+
InferenceStats.FAILURE_COUNT.getPreferredName(),
61+
InferenceStats.TIMESTAMP.getPreferredName());
5662
private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
5763
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
5864

@@ -134,7 +140,7 @@ void updateStats() {
134140
List<InferenceStats> stats = new ArrayList<>(statsQueue.size());
135141
for(String k : statsQueue.keySet()) {
136142
InferenceStats inferenceStats = statsQueue.remove(k);
137-
if (inferenceStats != null) {
143+
if (inferenceStats != null && inferenceStats.hasStats()) {
138144
stats.add(inferenceStats);
139145
}
140146
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, A
110110
return;
111111
}
112112
try {
113-
statsAccumulator.get().incInference();
113+
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incInference);
114114
currentInferenceCount.increment();
115115

116116
Model.mapFieldsIfNecessary(fields, defaultFieldMap);
117117

118118
boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
119119
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
120-
statsAccumulator.get().incMissingFields();
120+
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incMissingFields);
121121
if (shouldPersistStats) {
122122
persistStats();
123123
}
@@ -130,7 +130,7 @@ public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, A
130130
}
131131
listener.onResponse(inferenceResults);
132132
} catch (Exception e) {
133-
statsAccumulator.get().incFailure();
133+
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incFailure);
134134
listener.onFailure(e);
135135
}
136136
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.action.support.WriteRequest;
3030
import org.elasticsearch.client.Client;
3131
import org.elasticsearch.common.CheckedBiFunction;
32+
import org.elasticsearch.common.Numbers;
3233
import org.elasticsearch.common.Strings;
3334
import org.elasticsearch.common.bytes.BytesReference;
3435
import org.elasticsearch.common.collect.Tuple;
@@ -551,7 +552,9 @@ private InferenceStats handleMultiNodeStatsResponse(SearchResponse response, Str
551552
failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(),
552553
modelId,
553554
null,
554-
timeStamp == null ? Instant.now() : Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue())
555+
timeStamp == null || (Numbers.isValidDouble(timeStamp.getValue()) == false) ?
556+
Instant.now() :
557+
Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue())
555558
);
556559
}
557560

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java

-26
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
4747
import org.junit.After;
4848
import org.junit.Before;
49-
import org.mockito.ArgumentMatcher;
5049
import org.mockito.Mockito;
5150

5251
import java.io.IOException;
@@ -62,7 +61,6 @@
6261
import static org.hamcrest.Matchers.not;
6362
import static org.hamcrest.Matchers.nullValue;
6463
import static org.mockito.Matchers.any;
65-
import static org.mockito.Matchers.argThat;
6664
import static org.mockito.Matchers.eq;
6765
import static org.mockito.Mockito.atMost;
6866
import static org.mockito.Mockito.doAnswer;
@@ -140,12 +138,6 @@ public void testGetCachedModels() throws Exception {
140138
assertThat(future.get(), is(not(nullValue())));
141139
}
142140

143-
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
144-
@Override
145-
public boolean matches(final Object o) {
146-
return ((InferenceStats)o).getModelId().equals(model3);
147-
}
148-
}));
149141
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
150142
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
151143
// It is not referenced, so called eagerly
@@ -192,24 +184,6 @@ public void testMaxCachedLimitReached() throws Exception {
192184
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
193185
// Only loaded requested once on the initial load from the change event
194186
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
195-
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
196-
@Override
197-
public boolean matches(final Object o) {
198-
return ((InferenceStats)o).getModelId().equals(model1);
199-
}
200-
}));
201-
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
202-
@Override
203-
public boolean matches(final Object o) {
204-
return ((InferenceStats)o).getModelId().equals(model2);
205-
}
206-
}));
207-
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
208-
@Override
209-
public boolean matches(final Object o) {
210-
return ((InferenceStats)o).getModelId().equals(model3);
211-
}
212-
}));
213187

214188
// Load model 3, should invalidate 1
215189
for(int i = 0; i < 10; i++) {

0 commit comments

Comments
 (0)