Skip to content

Commit 2547cfb

Browse files
[7.x][ML] Persist progress when setting DFA task to failed (#61782) (#61792)
When an error occurs and we set the task to failed via the `DataFrameAnalyticsTask.setFailed` method we do not persist progress. If the job is later restarted, this means we do not correctly restore from where we can but instead we start the job from scratch and have to redo the reindexing phase. This commit solves this bug by persisting the progress before setting the task to failed. Backport of #61782
1 parent d52ee17 commit 2547cfb

File tree

3 files changed

+102
-62
lines changed

3 files changed

+102
-62
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
232232

233233
Exception reindexError = getReindexError(task.getParams().getId(), reindexResponse);
234234
if (reindexError != null) {
235-
task.markAsFailed(reindexError);
235+
task.setFailed(reindexError);
236236
return;
237237
}
238238

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import org.apache.logging.log4j.LogManager;
99
import org.apache.logging.log4j.Logger;
1010
import org.apache.logging.log4j.message.ParameterizedMessage;
11-
import org.apache.lucene.util.SetOnce;
1211
import org.elasticsearch.ResourceNotFoundException;
1312
import org.elasticsearch.action.ActionListener;
1413
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
@@ -38,7 +37,6 @@
3837
import org.elasticsearch.tasks.TaskManager;
3938
import org.elasticsearch.tasks.TaskResult;
4039
import org.elasticsearch.xpack.core.ml.MlTasks;
41-
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
4240
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
4341
import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction;
4442
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
@@ -216,23 +214,25 @@ public void setFailed(Exception error) {
216214
error);
217215
return;
218216
}
219-
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
220-
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
221-
DataFrameAnalyticsTaskState newTaskState =
222-
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
223-
updatePersistentTaskState(
224-
newTaskState,
225-
ActionListener.wrap(
226-
updatedTask -> {
227-
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
217+
persistProgress(client, taskParams.getId(), () -> {
218+
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
219+
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
220+
DataFrameAnalyticsTaskState newTaskState =
221+
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
222+
updatePersistentTaskState(
223+
newTaskState,
224+
ActionListener.wrap(
225+
updatedTask -> {
226+
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
228227
DataFrameAnalyticsState.FAILED, reason);
229-
auditor.info(getParams().getId(), message);
230-
LOGGER.info("[{}] {}", getParams().getId(), message);
231-
},
232-
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
233-
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
234-
)
235-
);
228+
auditor.info(getParams().getId(), message);
229+
LOGGER.info("[{}] {}", getParams().getId(), message);
230+
},
231+
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
232+
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
233+
)
234+
);
235+
});
236236
}
237237

238238
public void updateReindexTaskProgress(ActionListener<Void> listener) {
@@ -285,13 +285,12 @@ private TaskId getReindexTaskId() {
285285
}
286286

287287
// Visible for testing
288-
static void persistProgress(Client client, String jobId, Runnable runnable) {
288+
void persistProgress(Client client, String jobId, Runnable runnable) {
289289
LOGGER.debug("[{}] Persisting progress", jobId);
290290

291291
String progressDocId = StoredProgress.documentId(jobId);
292-
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();
293292

294-
// Step 4: Run the runnable provided as the argument
293+
// Step 3: Run the runnable provided as the argument
295294
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
296295
indexResponse -> {
297296
LOGGER.debug("[{}] Successfully indexed progress document", jobId);
@@ -304,7 +303,7 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
304303
}
305304
);
306305

307-
// Step 3: Create or update the progress document:
306+
// Step 2: Create or update the progress document:
308307
// - if the document did not exist, create the new one in the current write index
309308
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
310309
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
@@ -317,8 +316,10 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
317316
.id(progressDocId)
318317
.setRequireAlias(AnomalyDetectorsIndex.jobStateIndexWriteAlias().equals(indexOrAlias))
319318
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
319+
List<PhaseProgress> progress = statsHolder.getProgressTracker().report();
320320
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
321-
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
321+
LOGGER.debug("[{}] Persisting progress is: {}", jobId, progress);
322+
new StoredProgress(progress).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
322323
indexRequest.source(jsonBuilder);
323324
}
324325
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
@@ -330,28 +331,14 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
330331
}
331332
);
332333

333-
// Step 2: Search for existing progress document in .ml-state*
334-
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
335-
statsResponse -> {
336-
stats.set(statsResponse.getResponse().results().get(0));
337-
SearchRequest searchRequest =
338-
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
339-
.source(
340-
new SearchSourceBuilder()
341-
.size(1)
342-
.query(new IdsQueryBuilder().addIds(progressDocId)));
343-
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
344-
},
345-
e -> {
346-
LOGGER.error(new ParameterizedMessage(
347-
"[{}] cannot persist progress as an error occurred while retrieving stats", jobId), e);
348-
runnable.run();
349-
}
350-
);
351-
352-
// Step 1: Fetch progress to be persisted
353-
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(jobId);
354-
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
334+
// Step 1: Search for existing progress document in .ml-state*
335+
SearchRequest searchRequest =
336+
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
337+
.source(
338+
new SearchSourceBuilder()
339+
.size(1)
340+
.query(new IdsQueryBuilder().addIds(progressDocId)));
341+
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
355342
}
356343

357344
/**

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,21 @@
1616
import org.elasticsearch.cluster.service.ClusterService;
1717
import org.elasticsearch.common.settings.Settings;
1818
import org.elasticsearch.common.util.concurrent.ThreadContext;
19+
import org.elasticsearch.common.xcontent.DeprecationHandler;
20+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
21+
import org.elasticsearch.common.xcontent.XContentParser;
22+
import org.elasticsearch.common.xcontent.json.JsonXContent;
1923
import org.elasticsearch.persistent.PersistentTasksService;
2024
import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction;
2125
import org.elasticsearch.search.SearchHit;
2226
import org.elasticsearch.search.SearchHits;
2327
import org.elasticsearch.tasks.TaskManager;
2428
import org.elasticsearch.test.ESTestCase;
2529
import org.elasticsearch.threadpool.ThreadPool;
26-
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
27-
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
2830
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
2931
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
3032
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
33+
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
3134
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
3235
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
3336
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
@@ -36,6 +39,7 @@
3639
import org.mockito.InOrder;
3740
import org.mockito.stubbing.Answer;
3841

42+
import java.io.IOException;
3943
import java.util.Arrays;
4044
import java.util.Collections;
4145
import java.util.List;
@@ -125,14 +129,25 @@ public void testDetermineStartingState_GivenEmptyProgress() {
125129
assertThat(startingState, equalTo(StartingState.FINISHED));
126130
}
127131

128-
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) {
132+
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) throws IOException {
129133
Client client = mock(Client.class);
134+
when(client.settings()).thenReturn(Settings.EMPTY);
130135
ThreadPool threadPool = mock(ThreadPool.class);
131136
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
132137
when(client.threadPool()).thenReturn(threadPool);
133138

134-
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse(1);
135-
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
139+
ClusterService clusterService = mock(ClusterService.class);
140+
DataFrameAnalyticsManager analyticsManager = mock(DataFrameAnalyticsManager.class);
141+
DataFrameAnalyticsAuditor auditor = mock(DataFrameAnalyticsAuditor.class);
142+
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, threadPool, client);
143+
144+
List<PhaseProgress> progress = Arrays.asList(
145+
new PhaseProgress(ProgressTracker.REINDEXING, 100),
146+
new PhaseProgress(ProgressTracker.LOADING_DATA, 50),
147+
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0));
148+
149+
StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(
150+
"task_id", Version.CURRENT, progress, false);
136151

137152
SearchResponse searchResponse = mock(SearchResponse.class);
138153
when(searchResponse.getHits()).thenReturn(searchHits);
@@ -141,14 +156,20 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
141156
IndexResponse indexResponse = mock(IndexResponse.class);
142157
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
143158

159+
TaskManager taskManager = mock(TaskManager.class);
160+
144161
Runnable runnable = mock(Runnable.class);
145162

146-
DataFrameAnalyticsTask.persistProgress(client, "task_id", runnable);
163+
DataFrameAnalyticsTask task =
164+
new DataFrameAnalyticsTask(
165+
123, "type", "action", null, Collections.emptyMap(), client, clusterService, analyticsManager, auditor, taskParams);
166+
task.init(persistentTasksService, taskManager, "task-id", 42);
167+
168+
task.persistProgress(client, "task_id", runnable);
147169

148170
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
149171

150172
InOrder inOrder = inOrder(client, runnable);
151-
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
152173
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
153174
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
154175
inOrder.verify(runnable).run();
@@ -157,27 +178,33 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
157178
IndexRequest indexRequest = indexRequestCaptor.getValue();
158179
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
159180
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
181+
182+
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
183+
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
184+
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
185+
assertThat(parsedProgress.get(), equalTo(progress));
186+
}
160187
}
161188

162-
public void testPersistProgress_ProgressDocumentCreated() {
189+
public void testPersistProgress_ProgressDocumentCreated() throws IOException {
163190
testPersistProgress(SearchHits.empty(), ".ml-state-write");
164191
}
165192

166-
public void testPersistProgress_ProgressDocumentUpdated() {
193+
public void testPersistProgress_ProgressDocumentUpdated() throws IOException {
167194
testPersistProgress(
168195
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f),
169196
".ml-state-dummy");
170197
}
171198

172-
public void testSetFailed() {
199+
public void testSetFailed() throws IOException {
173200
testSetFailed(false);
174201
}
175202

176-
public void testSetFailedDuringNodeShutdown() {
203+
public void testSetFailedDuringNodeShutdown() throws IOException {
177204
testSetFailed(true);
178205
}
179206

180-
private void testSetFailed(boolean nodeShuttingDown) {
207+
private void testSetFailed(boolean nodeShuttingDown) throws IOException {
181208
ThreadPool threadPool = mock(ThreadPool.class);
182209
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
183210
Client client = mock(Client.class);
@@ -190,15 +217,25 @@ private void testSetFailed(boolean nodeShuttingDown) {
190217
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);
191218
TaskManager taskManager = mock(TaskManager.class);
192219

220+
List<PhaseProgress> progress = Arrays.asList(
221+
new PhaseProgress(ProgressTracker.REINDEXING, 100),
222+
new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
223+
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 30));
224+
193225
StartDataFrameAnalyticsAction.TaskParams taskParams =
194226
new StartDataFrameAnalyticsAction.TaskParams(
195227
"job-id",
196228
Version.CURRENT,
197-
Arrays.asList(
198-
new PhaseProgress(ProgressTracker.REINDEXING, 0),
199-
new PhaseProgress(ProgressTracker.LOADING_DATA, 0),
200-
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0)),
229+
progress,
201230
false);
231+
232+
SearchResponse searchResponse = mock(SearchResponse.class);
233+
when(searchResponse.getHits()).thenReturn(SearchHits.empty());
234+
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
235+
236+
IndexResponse indexResponse = mock(IndexResponse.class);
237+
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
238+
202239
DataFrameAnalyticsTask task =
203240
new DataFrameAnalyticsTask(
204241
123, "type", "action", null, Collections.emptyMap(), client, clusterService, analyticsManager, auditor, taskParams);
@@ -210,7 +247,23 @@ private void testSetFailed(boolean nodeShuttingDown) {
210247
verify(analyticsManager).isNodeShuttingDown();
211248
verify(client, atLeastOnce()).settings();
212249
verify(client, atLeastOnce()).threadPool();
250+
213251
if (nodeShuttingDown == false) {
252+
// Verify progress was persisted
253+
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
254+
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
255+
verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
256+
257+
IndexRequest indexRequest = indexRequestCaptor.getValue();
258+
assertThat(indexRequest.index(), equalTo(AnomalyDetectorsIndex.jobStateIndexWriteAlias()));
259+
assertThat(indexRequest.id(), equalTo("data_frame_analytics-job-id-progress"));
260+
261+
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
262+
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
263+
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
264+
assertThat(parsedProgress.get(), equalTo(progress));
265+
}
266+
214267
verify(client).execute(
215268
same(UpdatePersistentTaskStatusAction.INSTANCE),
216269
eq(new UpdatePersistentTaskStatusAction.Request(

0 commit comments

Comments
 (0)