Skip to content

Commit eed5494

Browse files
[ML] Refactor data frame analytics framework in steps (#67423)
This commit improves the design of the code that manages and runs data frame analytics jobs. First, it splits the code into asynchronous steps. At the moment, there are only two steps: the reindexing step and the analysis step. However, splitting the task into steps allows us to factor out running inference into its own step, which in turn allows us to properly resume a job that failed during inference without having to start a C++ process. I will make that improvement in a follow-up PR. The other main improvement in this commit is that it simplifies the state of a DFA task by getting rid of the `reindexing` and `analyzing` states. Now, once the task goes to `started`, it stays there until it finishes or gets stopped. The removed states are no longer useful. They used to be useful for knowing how to resume a job before progress was added, but currently they serve no purpose at all.
1 parent bbaf4ce commit eed5494

File tree

21 files changed

+886
-596
lines changed

21 files changed

+886
-596
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
public enum DataFrameAnalyticsState implements Writeable {
1717

18+
// States reindexing and analyzing are no longer used.
19+
// However, we need to keep them for BWC as tasks may be
20+
// awaiting assignment in older versioned nodes.
1821
STARTED, REINDEXING, ANALYZING, STOPPING, STOPPED, FAILED, STARTING;
1922

2023
public static DataFrameAnalyticsState fromString(String name) {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -203,46 +203,6 @@ public void testGetDataFrameAnalyticsState_GivenStaleTaskWithStartedState() {
203203
assertThat(state, equalTo(DataFrameAnalyticsState.STARTING));
204204
}
205205

206-
public void testGetDataFrameAnalyticsState_GivenTaskWithReindexingState() {
207-
String jobId = "foo";
208-
PersistentTasksCustomMetadata.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
209-
DataFrameAnalyticsState.REINDEXING, false);
210-
211-
DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
212-
213-
assertThat(state, equalTo(DataFrameAnalyticsState.REINDEXING));
214-
}
215-
216-
public void testGetDataFrameAnalyticsState_GivenStaleTaskWithReindexingState() {
217-
String jobId = "foo";
218-
PersistentTasksCustomMetadata.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
219-
DataFrameAnalyticsState.REINDEXING, true);
220-
221-
DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
222-
223-
assertThat(state, equalTo(DataFrameAnalyticsState.STARTING));
224-
}
225-
226-
public void testGetDataFrameAnalyticsState_GivenTaskWithAnalyzingState() {
227-
String jobId = "foo";
228-
PersistentTasksCustomMetadata.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
229-
DataFrameAnalyticsState.ANALYZING, false);
230-
231-
DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
232-
233-
assertThat(state, equalTo(DataFrameAnalyticsState.ANALYZING));
234-
}
235-
236-
public void testGetDataFrameAnalyticsState_GivenStaleTaskWithAnalyzingState() {
237-
String jobId = "foo";
238-
PersistentTasksCustomMetadata.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
239-
DataFrameAnalyticsState.ANALYZING, true);
240-
241-
DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
242-
243-
assertThat(state, equalTo(DataFrameAnalyticsState.STARTING));
244-
}
245-
246206
public void testGetDataFrameAnalyticsState_GivenTaskWithStoppingState() {
247207
String jobId = "foo";
248208
PersistentTasksCustomMetadata.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
4949
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
5050
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
51+
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
5152
import org.junit.After;
5253
import org.junit.Before;
5354

@@ -496,15 +497,10 @@ public void testStopAndRestart() throws Exception {
496497
NodeAcknowledgedResponse response = startAnalytics(jobId);
497498
assertThat(response.getNode(), not(emptyString()));
498499

499-
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
500+
// Wait until progress for first phase is over 1
500501
assertBusy(() -> {
501-
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
502-
assertThat(
503-
state,
504-
is(anyOf(
505-
equalTo(DataFrameAnalyticsState.REINDEXING),
506-
equalTo(DataFrameAnalyticsState.ANALYZING),
507-
equalTo(DataFrameAnalyticsState.STOPPED))));
502+
List<PhaseProgress> progress = getAnalyticsStats(jobId).getProgress();
503+
assertThat(progress.get(0).getProgressPercent(), greaterThan(1));
508504
});
509505
stopAnalytics(jobId);
510506
waitUntilAnalyticsIsStopped(jobId);

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import org.elasticsearch.action.search.SearchResponse;
1616
import org.elasticsearch.action.support.WriteRequest;
1717
import org.elasticsearch.common.settings.Settings;
18-
import org.elasticsearch.common.unit.TimeValue;
1918
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
2019
import org.elasticsearch.rest.RestStatus;
2120
import org.elasticsearch.search.SearchHit;
@@ -27,14 +26,14 @@
2726
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
2827
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
2928
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
30-
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
3129
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
3230
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
3331
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
3432
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
3533
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
3634
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
3735
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
36+
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
3837
import org.junit.After;
3938

4039
import java.io.IOException;
@@ -47,7 +46,6 @@
4746
import java.util.Set;
4847

4948
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
50-
import static org.hamcrest.Matchers.anyOf;
5149
import static org.hamcrest.Matchers.emptyString;
5250
import static org.hamcrest.Matchers.equalTo;
5351
import static org.hamcrest.Matchers.greaterThan;
@@ -305,15 +303,10 @@ public void testStopAndRestart() throws Exception {
305303
NodeAcknowledgedResponse response = startAnalytics(jobId);
306304
assertThat(response.getNode(), not(emptyString()));
307305

308-
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
306+
// Wait until progress for first phase is over 1
309307
assertBusy(() -> {
310-
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
311-
assertThat(
312-
state,
313-
is(anyOf(
314-
equalTo(DataFrameAnalyticsState.REINDEXING),
315-
equalTo(DataFrameAnalyticsState.ANALYZING),
316-
equalTo(DataFrameAnalyticsState.STOPPED))));
308+
List<PhaseProgress> progress = getAnalyticsStats(jobId).getProgress();
309+
assertThat(progress.get(0).getProgressPercent(), greaterThan(1));
317310
});
318311
stopAnalytics(jobId);
319312
waitUntilAnalyticsIsStopped(jobId);

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
2929
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
3030
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
31+
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
3132
import org.junit.After;
3233
import org.junit.Before;
3334

@@ -610,11 +611,10 @@ public void testOutlierDetectionStopAndRestart() throws Exception {
610611
NodeAcknowledgedResponse response = startAnalytics(id);
611612
assertThat(response.getNode(), not(emptyString()));
612613

613-
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
614+
// Wait until progress for first phase is over 1
614615
assertBusy(() -> {
615-
DataFrameAnalyticsState state = getAnalyticsStats(id).getState();
616-
assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
617-
equalTo(DataFrameAnalyticsState.STOPPED))));
616+
List<PhaseProgress> progress = getAnalyticsStats(id).getProgress();
617+
assertThat(progress.get(0).getProgressPercent(), greaterThan(1));
618618
});
619619
stopAnalytics(id);
620620
waitUntilAnalyticsIsStopped(id);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
773773
DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client, xContentRegistry,
774774
dataFrameAnalyticsAuditor);
775775
assert client instanceof NodeClient;
776-
DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client,
776+
DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client, clusterService,
777777
dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor, indexNameExpressionResolver);
778778
this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager);
779779

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D
105105
ActionListener<QueryPage<Stats>> listener) {
106106
logger.debug("Get stats for running task [{}]", task.getParams().getId());
107107

108-
ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
108+
ActionListener<Void> updateProgressListener = ActionListener.wrap(
109109
aVoid -> {
110110
Stats stats = buildStats(
111111
task.getParams().getId(),
@@ -120,7 +120,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D
120120
);
121121

122122
// We must update the progress of the reindexing task as it might be stale
123-
task.updateReindexTaskProgress(reindexingProgressListener);
123+
task.updateTaskProgress(updateProgressListener);
124124
}
125125

126126
@Override

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import org.elasticsearch.action.ActionListener;
1515
import org.elasticsearch.action.search.SearchAction;
1616
import org.elasticsearch.action.search.SearchRequest;
17+
import org.elasticsearch.action.search.SearchResponse;
1718
import org.elasticsearch.action.support.ActionFilters;
19+
import org.elasticsearch.action.support.IndicesOptions;
1820
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
1921
import org.elasticsearch.client.Client;
2022
import org.elasticsearch.client.ParentTaskAssigningClient;
@@ -31,6 +33,7 @@
3133
import org.elasticsearch.common.unit.ByteSizeValue;
3234
import org.elasticsearch.common.unit.TimeValue;
3335
import org.elasticsearch.index.IndexNotFoundException;
36+
import org.elasticsearch.index.query.QueryBuilders;
3437
import org.elasticsearch.license.License;
3538
import org.elasticsearch.license.LicenseUtils;
3639
import org.elasticsearch.license.XPackLicenseState;
@@ -40,6 +43,7 @@
4043
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
4144
import org.elasticsearch.persistent.PersistentTasksService;
4245
import org.elasticsearch.rest.RestStatus;
46+
import org.elasticsearch.search.SearchHit;
4347
import org.elasticsearch.tasks.Task;
4448
import org.elasticsearch.tasks.TaskId;
4549
import org.elasticsearch.threadpool.ThreadPool;
@@ -71,6 +75,7 @@
7175
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
7276
import org.elasticsearch.xpack.ml.dataframe.MappingsMerger;
7377
import org.elasticsearch.xpack.ml.dataframe.SourceDestValidations;
78+
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
7479
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
7580
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
7681
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
@@ -79,6 +84,7 @@
7984
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
8085
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
8186
import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor;
87+
import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils;
8288

8389
import java.util.List;
8490
import java.util.Map;
@@ -145,6 +151,7 @@ protected ClusterBlockException checkBlock(StartDataFrameAnalyticsAction.Request
145151
@Override
146152
protected void masterOperation(Task task, StartDataFrameAnalyticsAction.Request request, ClusterState state,
147153
ActionListener<NodeAcknowledgedResponse> listener) {
154+
logger.debug(() -> new ParameterizedMessage("[{}] received start request", request.getId()));
148155
if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) == false) {
149156
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
150157
return;
@@ -531,8 +538,6 @@ public boolean test(PersistentTasksCustomMetadata.PersistentTask<?> persistentTa
531538
DataFrameAnalyticsState analyticsState = taskState == null ? DataFrameAnalyticsState.STOPPED : taskState.getState();
532539
switch (analyticsState) {
533540
case STARTED:
534-
case REINDEXING:
535-
case ANALYZING:
536541
node = persistentTask.getExecutorNode();
537542
return true;
538543
case STOPPING:
@@ -586,7 +591,6 @@ public void onFailure(Exception e) {
586591
public static class TaskExecutor extends AbstractJobPersistentTasksExecutor<TaskParams> {
587592

588593
private final Client client;
589-
private final ClusterService clusterService;
590594
private final DataFrameAnalyticsManager manager;
591595
private final DataFrameAnalyticsAuditor auditor;
592596
private final IndexTemplateConfig inferenceIndexTemplate;
@@ -603,7 +607,6 @@ public TaskExecutor(Settings settings, Client client, ClusterService clusterServ
603607
memoryTracker,
604608
resolver);
605609
this.client = Objects.requireNonNull(client);
606-
this.clusterService = Objects.requireNonNull(clusterService);
607610
this.manager = Objects.requireNonNull(manager);
608611
this.auditor = Objects.requireNonNull(auditor);
609612
this.inferenceIndexTemplate = Objects.requireNonNull(inferenceIndexTemplate);
@@ -616,7 +619,7 @@ protected AllocatedPersistentTask createTask(
616619
PersistentTasksCustomMetadata.PersistentTask<TaskParams> persistentTask,
617620
Map<String, String> headers) {
618621
return new DataFrameAnalyticsTask(
619-
id, type, action, parentTaskId, headers, client, clusterService, manager, auditor, persistentTask.getParams());
622+
id, type, action, parentTaskId, headers, client, manager, auditor, persistentTask.getParams());
620623
}
621624

622625
@Override
@@ -650,8 +653,10 @@ public PersistentTasksCustomMetadata.Assignment getAssignment(TaskParams params,
650653

651654
@Override
652655
protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, PersistentTaskState state) {
656+
DataFrameAnalyticsTask dfaTask = (DataFrameAnalyticsTask) task;
653657
DataFrameAnalyticsTaskState analyticsTaskState = (DataFrameAnalyticsTaskState) state;
654-
DataFrameAnalyticsState analyticsState = analyticsTaskState == null ? null : analyticsTaskState.getState();
658+
DataFrameAnalyticsState analyticsState = analyticsTaskState == null ? DataFrameAnalyticsState.STOPPED
659+
: analyticsTaskState.getState();
655660
logger.info("[{}] Starting data frame analytics from state [{}]", params.getId(), analyticsState);
656661

657662
// If we are "stopping" there is nothing to do and we should stop
@@ -665,8 +670,26 @@ protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, Pe
665670
return;
666671
}
667672

673+
ActionListener<StoredProgress> progressListener = ActionListener.wrap(
674+
storedProgress -> {
675+
if (storedProgress != null) {
676+
dfaTask.getStatsHolder().setProgressTracker(storedProgress.get());
677+
}
678+
executeTask(dfaTask);
679+
},
680+
dfaTask::setFailed
681+
);
682+
668683
ActionListener<Boolean> templateCheckListener = ActionListener.wrap(
669-
ok -> executeTask(analyticsTaskState, task),
684+
ok -> {
685+
if (analyticsState != DataFrameAnalyticsState.STOPPED) {
686+
// If the state is not stopped it means the task is reassigning and
687+
// we need to update the progress from the last stored progress doc.
688+
searchProgressFromIndex(params.getId(), progressListener);
689+
} else {
690+
progressListener.onResponse(null);
691+
}
692+
},
670693
error -> {
671694
Throwable cause = ExceptionsHelper.unwrapCause(error);
672695
logger.error(
@@ -675,23 +698,44 @@ protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, Pe
675698
params.getId(),
676699
inferenceIndexTemplate.getTemplateName()),
677700
cause);
678-
task.markAsFailed(error);
701+
dfaTask.setFailed(error);
679702
}
680703
);
681704

682705
MlIndexAndAlias.installIndexTemplateIfRequired(clusterState, client, inferenceIndexTemplate, templateCheckListener);
683706
}
684707

685-
private void executeTask(DataFrameAnalyticsTaskState analyticsTaskState, AllocatedPersistentTask task) {
686-
if (analyticsTaskState == null) {
687-
DataFrameAnalyticsTaskState startedState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.STARTED,
688-
task.getAllocationId(), null);
689-
task.updatePersistentTaskState(startedState, ActionListener.wrap(
690-
response -> manager.execute((DataFrameAnalyticsTask) task, DataFrameAnalyticsState.STARTED, clusterState),
691-
task::markAsFailed));
692-
} else {
693-
manager.execute((DataFrameAnalyticsTask) task, analyticsTaskState.getState(), clusterState);
694-
}
708+
private void searchProgressFromIndex(String jobId, ActionListener<StoredProgress> listener) {
709+
SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
710+
searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
711+
searchRequest.source().size(1);
712+
searchRequest.source().query(QueryBuilders.idsQuery().addIds(StoredProgress.documentId(jobId)));
713+
searchRequest.allowPartialSearchResults(false);
714+
715+
ActionListener<SearchResponse> searchListener = ActionListener.wrap(
716+
searchResponse -> {
717+
SearchHit[] hits = searchResponse.getHits().getHits();
718+
if (hits.length == 0) {
719+
logger.debug(() -> new ParameterizedMessage("[{}] No stored progress found", jobId));
720+
listener.onResponse(null);
721+
} else {
722+
StoredProgress storedProgress = MlParserUtils.parse(hits[0], StoredProgress.PARSER);
723+
logger.debug(() -> new ParameterizedMessage("[{}] Found stored progress {}", jobId, storedProgress.get().get(0)));
724+
listener.onResponse(storedProgress);
725+
}
726+
},
727+
listener::onFailure
728+
);
729+
730+
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchListener);
731+
}
732+
733+
private void executeTask(DataFrameAnalyticsTask task) {
734+
DataFrameAnalyticsTaskState startedState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.STARTED,
735+
task.getAllocationId(), null);
736+
task.updatePersistentTaskState(startedState, ActionListener.wrap(
737+
response -> manager.execute(task, clusterState),
738+
task::markAsFailed));
695739
}
696740

697741
public static String nodeFilter(DiscoveryNode node, TaskParams params) {

0 commit comments

Comments
 (0)