
Commit cf66338

[ML] optimize the job stats call to do fewer searches (#82362)
This reduces the number of searches required to gather the stats for a closed job from 4 to 2. It does so by merging the separate searches for model_size_stats, timing_stats, and data_counts into a single search. Relates to: #82255
1 parent e8a4664 commit cf66338
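The gist of the change: instead of issuing three separate single-document searches (data_counts, timing_stats, model_size_stats), the provider now issues one search that wraps a keyed filters aggregation around a top_hits sub-aggregation, so each stats document comes back in its own bucket. The snippet below is a minimal standalone sketch of that pattern, not the committed code; the document ids and field names ("example_job", "_timing_stats", "result_type", "log_time") are illustrative stand-ins, and the real implementation is the getDataCountsModelSizeAndTimingStats method added to JobResultsProvider in the diff below.

import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;

import java.util.List;

public class CombinedJobStatsSearchSketch {
    public static void main(String[] args) {
        String jobId = "example_job"; // hypothetical job id, for illustration only

        SearchSourceBuilder source = new SearchSourceBuilder()
            // no raw hits and no total hit count: everything comes back inside the aggregation
            .size(0)
            .trackTotalHits(false)
            .aggregation(
                // one keyed filter per stats document type, all evaluated in a single search
                AggregationBuilders.filters(
                    "results",
                    new FiltersAggregator.KeyedFilter("data_counts", QueryBuilders.idsQuery().addIds(jobId + "_data_counts")),
                    new FiltersAggregator.KeyedFilter("timing_stats", QueryBuilders.idsQuery().addIds(jobId + "_timing_stats")),
                    new FiltersAggregator.KeyedFilter("model_size_stats", QueryBuilders.termQuery("result_type", "model_size_stats"))
                ).subAggregation(
                    // each bucket keeps only its single most relevant document
                    AggregationBuilders.topHits("hits")
                        .size(1)
                        .sorts(List.of(SortBuilders.fieldSort("log_time").order(SortOrder.DESC).unmappedType("long").missing(0L)))
                )
            );

        // Prints the JSON body of the one combined search that replaces three separate lookups.
        System.out.println(source);
    }
}

In the committed version the top_hits sub-aggregation sorts on both DataCounts.LOG_TIME and TimingStats.BUCKET_COUNT, and an empty bucket falls back to a fresh DataCounts, TimingStats, or ModelSizeStats object for the job.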

6 files changed: +199 additions, -118 deletions


x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/DataCounts.java

Lines changed: 1 addition & 1 deletion
@@ -254,7 +254,7 @@ public DataCounts(StreamInput in) throws IOException {
         logTime = in.readOptionalInstant();
     }
 
-    public String getJobid() {
+    public String getJobId() {
         return jobId;
     }

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/DataCountsTests.java

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ public void testSetEarliestRecordTimestamp_doesnotOverwrite() {
 
     public void testDocumentId() {
         DataCounts dataCounts = createTestInstance();
-        String jobId = dataCounts.getJobid();
+        String jobId = dataCounts.getJobId();
         assertEquals(jobId + "_data_counts", DataCounts.documentId(jobId));
     }

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobResultsProviderIT.java

Lines changed: 115 additions & 0 deletions
@@ -29,6 +29,7 @@
 import org.elasticsearch.cluster.service.ClusterApplierService;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.MasterService;
+import org.elasticsearch.common.CheckedSupplier;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.ClusterSettings;
@@ -62,6 +63,7 @@
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.Quantiles;
+import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.TimingStats;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
@@ -88,6 +90,7 @@
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex.createStateIndexAndAliasIfNecessary;
@@ -106,6 +109,7 @@ public class JobResultsProviderIT extends MlSingleNodeTestCase {
 
     private JobResultsProvider jobProvider;
     private ResultsPersisterService resultsPersisterService;
+    private JobResultsPersister jobResultsPersister;
     private AnomalyDetectionAuditor auditor;
 
     @Before
@@ -131,6 +135,7 @@ public void createComponents() throws Exception {
 
         OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN);
         resultsPersisterService = new ResultsPersisterService(tp, originSettingClient, clusterService, builder.build());
+        jobResultsPersister = new JobResultsPersister(originSettingClient, resultsPersisterService);
         auditor = new AnomalyDetectionAuditor(client(), clusterService);
         waitForMlTemplates();
     }
@@ -406,6 +411,96 @@ public void testRemoveJobFromCalendar() throws Exception {
         }
     }
 
+    public void testGetDataCountsModelSizeAndTimingStatsWithNoDocs() throws Exception {
+        Job.Builder job = new Job.Builder("first_job");
+        job.setAnalysisConfig(createAnalysisConfig("by_field_1", Collections.emptyList()));
+        job.setDataDescription(new DataDescription.Builder());
+
+        // Put first job. This should create the results index as it's the first job.
+        client().execute(PutJobAction.INSTANCE, new PutJobAction.Request(job)).actionGet();
+        AtomicReference<DataCounts> dataCountsAtomicReference = new AtomicReference<>();
+        AtomicReference<ModelSizeStats> modelSizeStatsAtomicReference = new AtomicReference<>();
+        AtomicReference<TimingStats> timingStatsAtomicReference = new AtomicReference<>();
+        AtomicReference<Exception> exceptionAtomicReference = new AtomicReference<>();
+
+        getDataCountsModelSizeAndTimingStats(
+            job.getId(),
+            dataCountsAtomicReference::set,
+            modelSizeStatsAtomicReference::set,
+            timingStatsAtomicReference::set,
+            exceptionAtomicReference::set
+        );
+
+        if (exceptionAtomicReference.get() != null) {
+            throw exceptionAtomicReference.get();
+        }
+
+        assertThat(dataCountsAtomicReference.get().getJobId(), equalTo(job.getId()));
+        assertThat(modelSizeStatsAtomicReference.get().getJobId(), equalTo(job.getId()));
+        assertThat(timingStatsAtomicReference.get().getJobId(), equalTo(job.getId()));
+    }
+
+    public void testGetDataCountsModelSizeAndTimingStatsWithSomeDocs() throws Exception {
+        Job.Builder job = new Job.Builder("first_job");
+        job.setAnalysisConfig(createAnalysisConfig("by_field_1", Collections.emptyList()));
+        job.setDataDescription(new DataDescription.Builder());
+
+        // Put first job. This should create the results index as it's the first job.
+        client().execute(PutJobAction.INSTANCE, new PutJobAction.Request(job)).actionGet();
+        AtomicReference<DataCounts> dataCountsAtomicReference = new AtomicReference<>();
+        AtomicReference<ModelSizeStats> modelSizeStatsAtomicReference = new AtomicReference<>();
+        AtomicReference<TimingStats> timingStatsAtomicReference = new AtomicReference<>();
+        AtomicReference<Exception> exceptionAtomicReference = new AtomicReference<>();
+
+        CheckedSupplier<Void, Exception> setOrThrow = () -> {
+            getDataCountsModelSizeAndTimingStats(
+                job.getId(),
+                dataCountsAtomicReference::set,
+                modelSizeStatsAtomicReference::set,
+                timingStatsAtomicReference::set,
+                exceptionAtomicReference::set
+            );
+
+            if (exceptionAtomicReference.get() != null) {
+                throw exceptionAtomicReference.get();
+            }
+            return null;
+        };
+
+        ModelSizeStats storedModelSizeStats = new ModelSizeStats.Builder(job.getId()).setModelBytes(10L).build();
+        jobResultsPersister.persistModelSizeStats(storedModelSizeStats, () -> false);
+        jobResultsPersister.commitResultWrites(job.getId());
+
+        setOrThrow.get();
+        assertThat(dataCountsAtomicReference.get().getJobId(), equalTo(job.getId()));
+        assertThat(modelSizeStatsAtomicReference.get(), equalTo(storedModelSizeStats));
+        assertThat(timingStatsAtomicReference.get().getJobId(), equalTo(job.getId()));
+
+        TimingStats storedTimingStats = new TimingStats(job.getId());
+        storedTimingStats.updateStats(10);
+
+        jobResultsPersister.bulkPersisterBuilder(job.getId()).persistTimingStats(storedTimingStats).executeRequest();
+        jobResultsPersister.commitResultWrites(job.getId());
+
+        setOrThrow.get();
+
+        assertThat(dataCountsAtomicReference.get().getJobId(), equalTo(job.getId()));
+        assertThat(modelSizeStatsAtomicReference.get(), equalTo(storedModelSizeStats));
+        assertThat(timingStatsAtomicReference.get(), equalTo(storedTimingStats));
+
+        DataCounts storedDataCounts = new DataCounts(job.getId());
+        storedDataCounts.incrementInputBytes(1L);
+        storedDataCounts.incrementMissingFieldCount(1L);
+        JobDataCountsPersister jobDataCountsPersister = new JobDataCountsPersister(client(), resultsPersisterService, auditor);
+        jobDataCountsPersister.persistDataCounts(job.getId(), storedDataCounts);
+        jobResultsPersister.commitResultWrites(job.getId());
+
+        setOrThrow.get();
+        assertThat(dataCountsAtomicReference.get(), equalTo(storedDataCounts));
+        assertThat(modelSizeStatsAtomicReference.get(), equalTo(storedModelSizeStats));
+        assertThat(timingStatsAtomicReference.get(), equalTo(storedTimingStats));
+    }
+
     private Map<String, Object> getIndexMappingProperties(String index) {
         GetMappingsRequest request = new GetMappingsRequest().indices(index);
         GetMappingsResponse response = client().execute(GetMappingsAction.INSTANCE, request).actionGet();
@@ -498,6 +593,26 @@ private Calendar getCalendar(String calendarId) throws Exception {
         return calendarHolder.get();
     }
 
+    private void getDataCountsModelSizeAndTimingStats(
+        String jobId,
+        Consumer<DataCounts> dataCountsConsumer,
+        Consumer<ModelSizeStats> modelSizeStatsConsumer,
+        Consumer<TimingStats> timingStatsConsumer,
+        Consumer<Exception> exceptionConsumer
+    ) throws Exception {
+        CountDownLatch latch = new CountDownLatch(1);
+        jobProvider.getDataCountsModelSizeAndTimingStats(jobId, (dataCounts, modelSizeStats, timingStats) -> {
+            dataCountsConsumer.accept(dataCounts);
+            modelSizeStatsConsumer.accept(modelSizeStats);
+            timingStatsConsumer.accept(timingStats);
+            latch.countDown();
+        }, e -> {
+            exceptionConsumer.accept(e);
+            latch.countDown();
+        });
+        latch.await();
+    }
+
     public void testScheduledEventsForJobs() throws Exception {
         Job.Builder jobA = createJob("job_a");
         Job.Builder jobB = createJob("job_b");

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetJobsStatsAction.java

Lines changed: 1 addition & 22 deletions
@@ -16,7 +16,6 @@
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.core.TimeValue;
@@ -191,7 +190,7 @@ void gatherStatsForClosedJobs(
             int slot = i;
             String jobId = closedJobIds.get(i);
             gatherForecastStats(jobId, forecastStats -> {
-                gatherDataCountsModelSizeStatsAndTimingStats(jobId, (dataCounts, modelSizeStats, timingStats) -> {
+                jobResultsProvider.getDataCountsModelSizeAndTimingStats(jobId, (dataCounts, modelSizeStats, timingStats) -> {
                     JobState jobState = MlTasks.getJobState(jobId, tasks);
                     PersistentTasksCustomMetadata.PersistentTask<?> pTask = MlTasks.getJobTask(jobId, tasks);
                     String assignmentExplanation = null;
@@ -238,26 +237,6 @@ void gatherForecastStats(String jobId, Consumer<ForecastStats> handler, Consumer
         jobResultsProvider.getForecastStats(jobId, handler, errorHandler);
     }
 
-    void gatherDataCountsModelSizeStatsAndTimingStats(
-        String jobId,
-        TriConsumer<DataCounts, ModelSizeStats, TimingStats> handler,
-        Consumer<Exception> errorHandler
-    ) {
-        jobResultsProvider.dataCounts(jobId, dataCounts -> {
-            jobResultsProvider.modelSizeStats(
-                jobId,
-                modelSizeStats -> {
-                    jobResultsProvider.timingStats(
-                        jobId,
-                        timingStats -> { handler.apply(dataCounts, modelSizeStats, timingStats); },
-                        errorHandler
-                    );
-                },
-                errorHandler
-            );
-        }, errorHandler);
-    }
-
     static TimeValue durationToTimeValue(Optional<Duration> duration) {
         if (duration.isPresent()) {
             return TimeValue.timeValueSeconds(duration.get().getSeconds());

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java

Lines changed: 81 additions & 24 deletions
@@ -46,6 +46,7 @@
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.MappingMetadata;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.CollectionUtils;
@@ -68,9 +69,12 @@
 import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.filter.Filters;
+import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator;
 import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
 import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
 import org.elasticsearch.search.aggregations.metrics.Stats;
+import org.elasticsearch.search.aggregations.metrics.TopHits;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.search.sort.FieldSortBuilder;
@@ -489,6 +493,77 @@ public void dataCounts(String jobId, Consumer<DataCounts> handler, Consumer<Exce
         );
     }
 
+    public void getDataCountsModelSizeAndTimingStats(
+        String jobId,
+        TriConsumer<DataCounts, ModelSizeStats, TimingStats> handler,
+        Consumer<Exception> errorHandler
+    ) {
+        final String results = "results";
+        final String timingStats = "timing_stats";
+        final String dataCounts = "data_counts";
+        final String modelSizeStats = "model_size_stats";
+        final String topHits = "hits";
+        SearchRequest request = client.prepareSearch(AnomalyDetectorsIndex.jobResultsAliasedName(jobId))
+            .setSize(0)
+            .setTrackTotalHits(false)
+            .setIndicesOptions(IndicesOptions.lenientExpandOpen())
+            .addAggregation(
+                AggregationBuilders.filters(
+                    results,
+                    new FiltersAggregator.KeyedFilter(
+                        dataCounts,
+                        QueryBuilders.idsQuery().addIds(DataCounts.documentId(jobId), DataCounts.v54DocumentId(jobId))
+                    ),
+                    new FiltersAggregator.KeyedFilter(timingStats, QueryBuilders.idsQuery().addIds(TimingStats.documentId(jobId))),
+                    new FiltersAggregator.KeyedFilter(
+                        modelSizeStats,
+                        QueryBuilders.termQuery(Result.RESULT_TYPE.getPreferredName(), ModelSizeStats.RESULT_TYPE_VALUE)
+                    )
+                )
+                    .subAggregation(
+                        AggregationBuilders.topHits(topHits)
+                            .size(1)
+                            .sorts(
+                                List.of(
+                                    SortBuilders.fieldSort(DataCounts.LOG_TIME.getPreferredName())
+                                        .order(SortOrder.DESC)
+                                        .unmappedType(NumberFieldMapper.NumberType.LONG.typeName())
+                                        .missing(0L),
+                                    SortBuilders.fieldSort(TimingStats.BUCKET_COUNT.getPreferredName())
+                                        .order(SortOrder.DESC)
+                                        .unmappedType(NumberFieldMapper.NumberType.LONG.typeName())
+                                        .missing(0L)
+                                )
+                            )
+                    )
+            )
+            .request();
+        executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, request, ActionListener.<SearchResponse>wrap(response -> {
+            Aggregations aggs = response.getAggregations();
+            if (aggs == null) {
+                handler.apply(new DataCounts(jobId), new ModelSizeStats.Builder(jobId).build(), new TimingStats(jobId));
+                return;
+            }
+            Filters filters = aggs.get(results);
+            TopHits dataCountHit = filters.getBucketByKey(dataCounts).getAggregations().get(topHits);
+            DataCounts dataCountsResult = dataCountHit.getHits().getHits().length == 0
+                ? new DataCounts(jobId)
+                : MlParserUtils.parse(dataCountHit.getHits().getHits()[0], DataCounts.PARSER);
+
+            TopHits timingStatsHits = filters.getBucketByKey(timingStats).getAggregations().get(topHits);
+            TimingStats timingStatsResult = timingStatsHits.getHits().getHits().length == 0
+                ? new TimingStats(jobId)
+                : MlParserUtils.parse(timingStatsHits.getHits().getHits()[0], TimingStats.PARSER);
+
+            TopHits modelSizeHits = filters.getBucketByKey(modelSizeStats).getAggregations().get(topHits);
+            ModelSizeStats modelSizeStatsResult = modelSizeHits.getHits().getHits().length == 0
+                ? new ModelSizeStats.Builder(jobId).build()
+                : MlParserUtils.parse(modelSizeHits.getHits().getHits()[0], ModelSizeStats.LENIENT_PARSER).build();
+
+            handler.apply(dataCountsResult, modelSizeStatsResult, timingStatsResult);
+        }, errorHandler), client::search);
+    }
+
     private SearchRequestBuilder createLatestDataCountsSearch(String indexName, String jobId) {
         return client.prepareSearch(indexName)
             .setSize(1)
@@ -507,24 +582,6 @@ private SearchRequestBuilder createLatestDataCountsSearch(String indexName, Stri
             .addSort(SortBuilders.fieldSort(DataCounts.LATEST_RECORD_TIME.getPreferredName()).order(SortOrder.DESC));
     }
 
-    /**
-     * Get the job's timing stats
-     *
-     * @param jobId The job id
-     */
-    public void timingStats(String jobId, Consumer<TimingStats> handler, Consumer<Exception> errorHandler) {
-        String indexName = AnomalyDetectorsIndex.jobResultsAliasedName(jobId);
-        searchSingleResult(
-            jobId,
-            TimingStats.TYPE.getPreferredName(),
-            createLatestTimingStatsSearch(indexName, jobId),
-            TimingStats.PARSER,
-            result -> handler.accept(result.result),
-            errorHandler,
-            () -> new TimingStats(jobId)
-        );
-    }
-
     private SearchRequestBuilder createLatestTimingStatsSearch(String indexName, String jobId) {
         return client.prepareSearch(indexName)
             .setSize(1)
@@ -1698,7 +1755,7 @@ public void getForecastStats(String jobId, Consumer<ForecastStats> handler, Cons
             AggregationBuilders.terms(ForecastStats.Fields.STATUSES).field(ForecastRequestStats.STATUS.getPreferredName())
         );
         sourceBuilder.size(0);
-        sourceBuilder.trackTotalHits(true);
+        sourceBuilder.trackTotalHits(false);
 
         searchRequest.source(sourceBuilder);
 
@@ -1707,19 +1764,19 @@ public void getForecastStats(String jobId, Consumer<ForecastStats> handler, Cons
             ML_ORIGIN,
             searchRequest,
             ActionListener.<SearchResponse>wrap(searchResponse -> {
-                long totalHits = searchResponse.getHits().getTotalHits().value;
                 Aggregations aggregations = searchResponse.getAggregations();
-                if (totalHits == 0 || aggregations == null) {
+                if (aggregations == null) {
                     handler.accept(new ForecastStats());
                     return;
                 }
                 Map<String, Aggregation> aggregationsAsMap = aggregations.asMap();
                 StatsAccumulator memoryStats = StatsAccumulator.fromStatsAggregation(
                     (Stats) aggregationsAsMap.get(ForecastStats.Fields.MEMORY)
                 );
-                StatsAccumulator recordStats = StatsAccumulator.fromStatsAggregation(
-                    (Stats) aggregationsAsMap.get(ForecastStats.Fields.RECORDS)
-                );
+                Stats aggRecordsStats = (Stats) aggregationsAsMap.get(ForecastStats.Fields.RECORDS);
+                // Stats already gives us all the counts and every doc as a "records" field.
+                long totalHits = aggRecordsStats.getCount();
+                StatsAccumulator recordStats = StatsAccumulator.fromStatsAggregation(aggRecordsStats);
                 StatsAccumulator runtimeStats = StatsAccumulator.fromStatsAggregation(
                     (Stats) aggregationsAsMap.get(ForecastStats.Fields.RUNTIME)
                 );
