diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java index 555236cb4c3e0..9a34f21fe42eb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java @@ -31,7 +31,7 @@ public class InferenceStats implements ToXContentObject, Writeable { public static final ParseField NODE_ID = new ParseField("node_id"); public static final ParseField FAILURE_COUNT = new ParseField("failure_count"); public static final ParseField TYPE = new ParseField("type"); - public static final ParseField TIMESTAMP = new ParseField("time_stamp"); + public static final ParseField TIMESTAMP = new ParseField("timestamp"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, @@ -127,6 +127,10 @@ public Instant getTimeStamp() { return timeStamp; } + public boolean hasStats() { + return missingAllFieldsCount > 0 || inferenceCount > 0 || failureCount > 0; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -221,16 +225,19 @@ public Accumulator merge(InferenceStats otherStats) { return this; } - public void incMissingFields() { + public Accumulator incMissingFields() { this.missingFieldsAccumulator.increment(); + return this; } - public void incInference() { + public Accumulator incInference() { this.inferenceAccumulator.increment(); + return this; } - public void incFailure() { + public Accumulator incFailure() { this.failureCountAccumulator.increment(); + return this; } public InferenceStats currentStats() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index eed952cd94018..7600cda1302f6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -96,7 +96,6 @@ public void testPathologicalPipelineCreationAndDeletion() throws Exception { assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10")); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786") public void testPipelineIngest() throws Exception { client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java index 0bf4b2d8c7bfe..2213347c62254 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java @@ -27,6 +27,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; @@ -47,12 +48,17 @@ public class TrainedModelStatsService { private static final Logger logger = LogManager.getLogger(TrainedModelStatsService.class); private static final TimeValue PERSISTENCE_INTERVAL = TimeValue.timeValueSeconds(1); + private static final String STATS_UPDATE_SCRIPT_TEMPLATE = "" + + " ctx._source.{0} += params.{0};\n" + + " ctx._source.{1} += params.{1};\n" + + " ctx._source.{2} += params.{2};\n" + + " ctx._source.{3} = params.{3};"; // Script to only update if stats have increased since last persistence - private static final String STATS_UPDATE_SCRIPT = "" + - " ctx._source.missing_all_fields_count += params.missing_all_fields_count;\n" + - " ctx._source.inference_count += params.inference_count;\n" + - " ctx._source.failure_count += params.failure_count;\n" + - " ctx._source.time_stamp = params.time_stamp;"; + private static final String STATS_UPDATE_SCRIPT = Messages.getMessage(STATS_UPDATE_SCRIPT_TEMPLATE, + InferenceStats.MISSING_ALL_FIELDS_COUNT.getPreferredName(), + InferenceStats.INFERENCE_COUNT.getPreferredName(), + InferenceStats.FAILURE_COUNT.getPreferredName(), + InferenceStats.TIMESTAMP.getPreferredName()); private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS = new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); @@ -134,7 +140,7 @@ void updateStats() { List stats = new ArrayList<>(statsQueue.size()); for(String k : statsQueue.keySet()) { InferenceStats inferenceStats = statsQueue.remove(k); - if (inferenceStats != null) { + if (inferenceStats != null && inferenceStats.hasStats()) { stats.add(inferenceStats); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index b3ecc40d27c2f..eda3286bf9420 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -110,14 +110,14 @@ public void infer(Map fields, InferenceConfigUpdate update, A return; } try { - statsAccumulator.get().incInference(); + statsAccumulator.updateAndGet(InferenceStats.Accumulator::incInference); currentInferenceCount.increment(); Model.mapFieldsIfNecessary(fields, defaultFieldMap); boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0); if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { - statsAccumulator.get().incMissingFields(); + statsAccumulator.updateAndGet(InferenceStats.Accumulator::incMissingFields); if (shouldPersistStats) { persistStats(); } @@ -130,7 +130,7 @@ public void infer(Map fields, InferenceConfigUpdate update, A } listener.onResponse(inferenceResults); } catch (Exception e) { - statsAccumulator.get().incFailure(); + statsAccumulator.updateAndGet(InferenceStats.Accumulator::incFailure); listener.onFailure(e); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 3aef4293075bd..7820eff0c0932 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -29,6 +29,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.CheckedBiFunction; +import org.elasticsearch.common.Numbers; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -551,7 +552,9 @@ private InferenceStats handleMultiNodeStatsResponse(SearchResponse response, Str failures == null ? 0L : Double.valueOf(failures.getValue()).longValue(), modelId, null, - timeStamp == null ? Instant.now() : Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue()) + timeStamp == null || (Numbers.isValidDouble(timeStamp.getValue()) == false) ? + Instant.now() : + Instant.ofEpochMilli(Double.valueOf(timeStamp.getValue()).longValue()) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 2b1a07892e7a7..58600fafdc108 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.junit.After; import org.junit.Before; -import org.mockito.ArgumentMatcher; import org.mockito.Mockito; import java.io.IOException; @@ -62,7 +61,6 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doAnswer; @@ -140,12 +138,6 @@ public void testGetCachedModels() throws Exception { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher() { - @Override - public boolean matches(final Object o) { - return ((InferenceStats)o).getModelId().equals(model3); - } - })); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); // It is not referenced, so called eagerly @@ -192,24 +184,6 @@ public void testMaxCachedLimitReached() throws Exception { verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any()); // Only loaded requested once on the initial load from the change event verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); - verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher() { - @Override - public boolean matches(final Object o) { - return ((InferenceStats)o).getModelId().equals(model1); - } - })); - verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher() { - @Override - public boolean matches(final Object o) { - return ((InferenceStats)o).getModelId().equals(model2); - } - })); - verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher() { - @Override - public boolean matches(final Object o) { - return ((InferenceStats)o).getModelId().equals(model3); - } - })); // Load model 3, should invalidate 1 for(int i = 0; i < 10; i++) {