Skip to content

Commit f449b8f

Browse files
[ML] Improve resuming a DFA job stopped during inference (#67623)
If a DFA job is stopped while in the inference phase, after resuming we should start inference immediately. However, this is currently not the case. Inference is tied into `AnalyticsProcessManager`, and thus we start a process, load data, restore state, etc., before we finally get to start inference. This commit gets rid of this unnecessary delay by factoring inference out as an independent step and ensuring we can resume straight from that phase upon restarting a job.
1 parent 1a05a5a commit f449b8f

17 files changed

+479
-91
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,6 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
764764
analyticsProcessFactory,
765765
dataFrameAnalyticsAuditor,
766766
trainedModelProvider,
767-
modelLoadingService,
768767
resultsPersisterService,
769768
EsExecutors.allocatedProcessors(settings));
770769
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
@@ -773,8 +772,9 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
773772
DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client, xContentRegistry,
774773
dataFrameAnalyticsAuditor);
775774
assert client instanceof NodeClient;
776-
DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client, clusterService,
777-
dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor, indexNameExpressionResolver);
775+
DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(settings, (NodeClient) client, threadPool,
776+
clusterService, dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor,
777+
indexNameExpressionResolver, resultsPersisterService, modelLoadingService);
778778
this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager);
779779

780780
// Components shared by anomaly detection and data frame analytics

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java

+1
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ private void getStartContext(String id, Task task, ActionListener<StartContext>
273273
break;
274274
case RESUMING_REINDEXING:
275275
case RESUMING_ANALYZING:
276+
case RESUMING_INFERENCE:
276277
toValidateMappingsListener.onResponse(startContext);
277278
break;
278279
case FINISHED:

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

+55-4
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,30 @@
1919
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
2020
import org.elasticsearch.cluster.metadata.MappingMetadata;
2121
import org.elasticsearch.cluster.service.ClusterService;
22+
import org.elasticsearch.common.settings.Settings;
2223
import org.elasticsearch.index.IndexNotFoundException;
24+
import org.elasticsearch.threadpool.ThreadPool;
2325
import org.elasticsearch.xpack.core.ClientHelper;
2426
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
2527
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
2628
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
2729
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
2830
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
31+
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
32+
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
33+
import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
2934
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
3035
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
3136
import org.elasticsearch.xpack.ml.dataframe.steps.AnalysisStep;
3237
import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep;
38+
import org.elasticsearch.xpack.ml.dataframe.steps.FinalStep;
39+
import org.elasticsearch.xpack.ml.dataframe.steps.InferenceStep;
3340
import org.elasticsearch.xpack.ml.dataframe.steps.ReindexingStep;
3441
import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse;
42+
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
43+
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
3544
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
45+
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
3646

3747
import java.util.Objects;
3848
import java.util.concurrent.atomic.AtomicBoolean;
@@ -43,27 +53,36 @@ public class DataFrameAnalyticsManager {
4353

4454
private static final Logger LOGGER = LogManager.getLogger(DataFrameAnalyticsManager.class);
4555

56+
private final Settings settings;
4657
/**
4758
* We need a {@link NodeClient} to get the reindexing task and be able to report progress
4859
*/
4960
private final NodeClient client;
61+
private final ThreadPool threadPool;
5062
private final ClusterService clusterService;
5163
private final DataFrameAnalyticsConfigProvider configProvider;
5264
private final AnalyticsProcessManager processManager;
5365
private final DataFrameAnalyticsAuditor auditor;
5466
private final IndexNameExpressionResolver expressionResolver;
67+
private final ResultsPersisterService resultsPersisterService;
68+
private final ModelLoadingService modelLoadingService;
5569
/** Indicates whether the node is shutting down. */
5670
private final AtomicBoolean nodeShuttingDown = new AtomicBoolean();
5771

58-
public DataFrameAnalyticsManager(NodeClient client, ClusterService clusterService, DataFrameAnalyticsConfigProvider configProvider,
59-
AnalyticsProcessManager processManager, DataFrameAnalyticsAuditor auditor,
60-
IndexNameExpressionResolver expressionResolver) {
72+
public DataFrameAnalyticsManager(Settings settings, NodeClient client, ThreadPool threadPool, ClusterService clusterService,
73+
DataFrameAnalyticsConfigProvider configProvider, AnalyticsProcessManager processManager,
74+
DataFrameAnalyticsAuditor auditor, IndexNameExpressionResolver expressionResolver,
75+
ResultsPersisterService resultsPersisterService, ModelLoadingService modelLoadingService) {
76+
this.settings = Objects.requireNonNull(settings);
6177
this.client = Objects.requireNonNull(client);
78+
this.threadPool = Objects.requireNonNull(threadPool);
6279
this.clusterService = Objects.requireNonNull(clusterService);
6380
this.configProvider = Objects.requireNonNull(configProvider);
6481
this.processManager = Objects.requireNonNull(processManager);
6582
this.auditor = Objects.requireNonNull(auditor);
6683
this.expressionResolver = Objects.requireNonNull(expressionResolver);
84+
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
85+
this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
6786
}
6887

6988
public void execute(DataFrameAnalyticsTask task, ClusterState clusterState) {
@@ -141,6 +160,12 @@ private void determineProgressAndResume(DataFrameAnalyticsTask task, DataFrameAn
141160
case RESUMING_ANALYZING:
142161
executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
143162
break;
163+
case RESUMING_INFERENCE:
164+
buildInferenceStep(task, config, ActionListener.wrap(
165+
inferenceStep -> executeStep(task, config, inferenceStep),
166+
task::setFailed
167+
));
168+
break;
144169
case FINISHED:
145170
default:
146171
task.setFailed(ExceptionsHelper.serverError("Unexpected starting state [" + startingState + "]"));
@@ -162,7 +187,15 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c
162187
executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
163188
break;
164189
case ANALYSIS:
165-
// This is the last step
190+
buildInferenceStep(task, config, ActionListener.wrap(
191+
inferenceStep -> executeStep(task, config, inferenceStep),
192+
task::setFailed
193+
));
194+
break;
195+
case INFERENCE:
196+
executeStep(task, config, new FinalStep(client, task, auditor, config));
197+
break;
198+
case FINAL:
166199
LOGGER.info("[{}] Marking task completed", config.getId());
167200
task.markAsCompleted();
168201
break;
@@ -199,6 +232,24 @@ private void executeJobInMiddleOfReindexing(DataFrameAnalyticsTask task, DataFra
199232
));
200233
}
201234

235+
private void buildInferenceStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, ActionListener<InferenceStep> listener) {
236+
ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
237+
238+
ActionListener<ExtractedFieldsDetector> extractedFieldsDetectorListener = ActionListener.wrap(
239+
extractedFieldsDetector -> {
240+
ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
241+
InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService,
242+
resultsPersisterService, task.getParentTaskId(), config, extractedFields, task.getStatsHolder().getProgressTracker(),
243+
task.getStatsHolder().getDataCountsTracker());
244+
InferenceStep inferenceStep = new InferenceStep(client, task, auditor, config, threadPool, inferenceRunner);
245+
listener.onResponse(inferenceStep);
246+
},
247+
listener::onFailure
248+
);
249+
250+
new ExtractedFieldsDetectorFactory(parentTaskClient).createFromDest(config, extractedFieldsDetectorListener);
251+
}
252+
202253
public boolean isNodeShuttingDown() {
203254
return nodeShuttingDown.get();
204255
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ public void updateTaskProgress(ActionListener<Void> updateProgressListener) {
287287
* {@code FINISHED} means the job had finished.
288288
*/
289289
public enum StartingState {
290-
FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, FINISHED
290+
FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, RESUMING_INFERENCE, FINISHED
291291
}
292292

293293
public StartingState determineStartingState() {
@@ -313,6 +313,9 @@ public static StartingState determineStartingState(String jobId, List<PhaseProgr
313313
if (ProgressTracker.REINDEXING.equals(lastIncompletePhase.getPhase())) {
314314
return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
315315
}
316+
if (ProgressTracker.INFERENCE.equals(lastIncompletePhase.getPhase())) {
317+
return StartingState.RESUMING_INFERENCE;
318+
}
316319
return StartingState.RESUMING_ANALYZING;
317320
}
318321
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public boolean isCancelled() {
9999
}
100100

101101
public void cancel() {
102-
LOGGER.debug("[{}] Data extractor was cancelled", context.jobId);
102+
LOGGER.debug(() -> new ParameterizedMessage("[{}] Data extractor was cancelled", context.jobId));
103103
isCancelled = true;
104104
}
105105

@@ -127,7 +127,7 @@ private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request)
127127
// We've set allow_partial_search_results to false which means if something
128128
// goes wrong the request will throw.
129129
SearchResponse searchResponse = request.get();
130-
LOGGER.debug("[{}] Search response was obtained", context.jobId);
130+
LOGGER.trace(() -> new ParameterizedMessage("[{}] Search response was obtained", context.jobId));
131131

132132
List<Row> rows = processSearchResponse(searchResponse);
133133

@@ -153,7 +153,7 @@ private SearchRequestBuilder buildSearchRequest() {
153153
long from = lastSortKey + 1;
154154
long to = from + context.scrollSize;
155155

156-
LOGGER.debug(() -> new ParameterizedMessage(
156+
LOGGER.trace(() -> new ParameterizedMessage(
157157
"[{}] Searching docs with [{}] in [{}, {})", context.jobId, DestinationIndex.INCREMENTAL_ID, from, to));
158158

159159
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
@@ -283,7 +283,7 @@ private Row createRow(SearchHit hit) {
283283
}
284284
boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
285285
Row row = new Row(extractedValues, hit, isTraining);
286-
LOGGER.debug(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
286+
LOGGER.trace(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
287287
context.jobId, row.getSortKey(), isTraining, Arrays.toString(row.values)));
288288
return row;
289289
}
@@ -306,7 +306,7 @@ public DataSummary collectDataSummary() {
306306
SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
307307
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
308308
long rows = searchResponse.getHits().getTotalHits().value;
309-
LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows);
309+
LOGGER.debug(() -> new ParameterizedMessage("[{}] Data summary rows [{}]", context.jobId, rows));
310310
return new DataSummary(rows, organicFeatures.length + processedFeatures.length);
311311
}
312312

0 commit comments

Comments
 (0)