Skip to content

Commit 2ba4e15

Browse files
[ML] Persist progress when setting DFA task to failed (#61782)
When an error occurs and we set the task to failed via the `DataFrameAnalyticsTask.setFailed` method we do not persist progress. If the job is later restarted, this means we do not correctly restore from where we can but instead we start the job from scratch and have to redo the reindexing phase. This commit solves this bug by persisting the progress before setting the task to failed.
1 parent 5f290c2 commit 2ba4e15

File tree

3 files changed

+101
-63
lines changed

3 files changed

+101
-63
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
232232

233233
Exception reindexError = getReindexError(task.getParams().getId(), reindexResponse);
234234
if (reindexError != null) {
235-
task.markAsFailed(reindexError);
235+
task.setFailed(reindexError);
236236
return;
237237
}
238238

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import org.apache.logging.log4j.LogManager;
99
import org.apache.logging.log4j.Logger;
1010
import org.apache.logging.log4j.message.ParameterizedMessage;
11-
import org.apache.lucene.util.SetOnce;
1211
import org.elasticsearch.ResourceNotFoundException;
1312
import org.elasticsearch.action.ActionListener;
1413
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
@@ -38,7 +37,6 @@
3837
import org.elasticsearch.tasks.TaskManager;
3938
import org.elasticsearch.tasks.TaskResult;
4039
import org.elasticsearch.xpack.core.ml.MlTasks;
41-
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
4240
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
4341
import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction;
4442
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
@@ -216,23 +214,25 @@ public void setFailed(Exception error) {
216214
error);
217215
return;
218216
}
219-
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
220-
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
221-
DataFrameAnalyticsTaskState newTaskState =
222-
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
223-
updatePersistentTaskState(
224-
newTaskState,
225-
ActionListener.wrap(
226-
updatedTask -> {
227-
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
217+
persistProgress(client, taskParams.getId(), () -> {
218+
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
219+
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
220+
DataFrameAnalyticsTaskState newTaskState =
221+
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
222+
updatePersistentTaskState(
223+
newTaskState,
224+
ActionListener.wrap(
225+
updatedTask -> {
226+
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
228227
DataFrameAnalyticsState.FAILED, reason);
229-
auditor.info(getParams().getId(), message);
230-
LOGGER.info("[{}] {}", getParams().getId(), message);
231-
},
232-
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
233-
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
234-
)
235-
);
228+
auditor.info(getParams().getId(), message);
229+
LOGGER.info("[{}] {}", getParams().getId(), message);
230+
},
231+
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
232+
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
233+
)
234+
);
235+
});
236236
}
237237

238238
public void updateReindexTaskProgress(ActionListener<Void> listener) {
@@ -285,13 +285,12 @@ private TaskId getReindexTaskId() {
285285
}
286286

287287
// Visible for testing
288-
static void persistProgress(Client client, String jobId, Runnable runnable) {
288+
void persistProgress(Client client, String jobId, Runnable runnable) {
289289
LOGGER.debug("[{}] Persisting progress", jobId);
290290

291291
String progressDocId = StoredProgress.documentId(jobId);
292-
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();
293292

294-
// Step 4: Run the runnable provided as the argument
293+
// Step 3: Run the runnable provided as the argument
295294
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
296295
indexResponse -> {
297296
LOGGER.debug("[{}] Successfully indexed progress document", jobId);
@@ -304,7 +303,7 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
304303
}
305304
);
306305

307-
// Step 3: Create or update the progress document:
306+
// Step 2: Create or update the progress document:
308307
// - if the document did not exist, create the new one in the current write index
309308
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
310309
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
@@ -317,9 +316,10 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
317316
.id(progressDocId)
318317
.setRequireAlias(AnomalyDetectorsIndex.jobStateIndexWriteAlias().equals(indexOrAlias))
319318
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
319+
List<PhaseProgress> progress = statsHolder.getProgressTracker().report();
320320
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
321-
LOGGER.debug("[{}] Persisting progress is: {}", jobId, stats.get().getProgress());
322-
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
321+
LOGGER.debug("[{}] Persisting progress is: {}", jobId, progress);
322+
new StoredProgress(progress).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
323323
indexRequest.source(jsonBuilder);
324324
}
325325
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
@@ -331,28 +331,14 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
331331
}
332332
);
333333

334-
// Step 2: Search for existing progress document in .ml-state*
335-
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
336-
statsResponse -> {
337-
stats.set(statsResponse.getResponse().results().get(0));
338-
SearchRequest searchRequest =
339-
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
340-
.source(
341-
new SearchSourceBuilder()
342-
.size(1)
343-
.query(new IdsQueryBuilder().addIds(progressDocId)));
344-
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
345-
},
346-
e -> {
347-
LOGGER.error(new ParameterizedMessage(
348-
"[{}] cannot persist progress as an error occurred while retrieving stats", jobId), e);
349-
runnable.run();
350-
}
351-
);
352-
353-
// Step 1: Fetch progress to be persisted
354-
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(jobId);
355-
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
334+
// Step 1: Search for existing progress document in .ml-state*
335+
SearchRequest searchRequest =
336+
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
337+
.source(
338+
new SearchSourceBuilder()
339+
.size(1)
340+
.query(new IdsQueryBuilder().addIds(progressDocId)));
341+
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
356342
}
357343

358344
/**

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,21 @@
1616
import org.elasticsearch.cluster.service.ClusterService;
1717
import org.elasticsearch.common.settings.Settings;
1818
import org.elasticsearch.common.util.concurrent.ThreadContext;
19+
import org.elasticsearch.common.xcontent.DeprecationHandler;
20+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
21+
import org.elasticsearch.common.xcontent.XContentParser;
22+
import org.elasticsearch.common.xcontent.json.JsonXContent;
1923
import org.elasticsearch.persistent.PersistentTasksService;
2024
import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction;
2125
import org.elasticsearch.search.SearchHit;
2226
import org.elasticsearch.search.SearchHits;
2327
import org.elasticsearch.tasks.TaskManager;
2428
import org.elasticsearch.test.ESTestCase;
2529
import org.elasticsearch.threadpool.ThreadPool;
26-
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
27-
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
2830
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
2931
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
3032
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
33+
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
3134
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
3235
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
3336
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
@@ -36,6 +39,7 @@
3639
import org.mockito.InOrder;
3740
import org.mockito.stubbing.Answer;
3841

42+
import java.io.IOException;
3943
import java.util.Arrays;
4044
import java.util.Collections;
4145
import java.util.List;
@@ -126,14 +130,24 @@ public void testDetermineStartingState_GivenEmptyProgress() {
126130
assertThat(startingState, equalTo(StartingState.FINISHED));
127131
}
128132

129-
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) {
133+
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) throws IOException {
130134
Client client = mock(Client.class);
131135
ThreadPool threadPool = mock(ThreadPool.class);
132136
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
133137
when(client.threadPool()).thenReturn(threadPool);
134138

135-
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse(1);
136-
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
139+
ClusterService clusterService = mock(ClusterService.class);
140+
DataFrameAnalyticsManager analyticsManager = mock(DataFrameAnalyticsManager.class);
141+
DataFrameAnalyticsAuditor auditor = mock(DataFrameAnalyticsAuditor.class);
142+
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);
143+
144+
List<PhaseProgress> progress = List.of(
145+
new PhaseProgress(ProgressTracker.REINDEXING, 100),
146+
new PhaseProgress(ProgressTracker.LOADING_DATA, 50),
147+
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0));
148+
149+
StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(
150+
"task_id", Version.CURRENT, progress, false);
137151

138152
SearchResponse searchResponse = mock(SearchResponse.class);
139153
when(searchResponse.getHits()).thenReturn(searchHits);
@@ -142,14 +156,20 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
142156
IndexResponse indexResponse = mock(IndexResponse.class);
143157
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
144158

159+
TaskManager taskManager = mock(TaskManager.class);
160+
145161
Runnable runnable = mock(Runnable.class);
146162

147-
DataFrameAnalyticsTask.persistProgress(client, "task_id", runnable);
163+
DataFrameAnalyticsTask task =
164+
new DataFrameAnalyticsTask(
165+
123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams);
166+
task.init(persistentTasksService, taskManager, "task-id", 42);
167+
168+
task.persistProgress(client, "task_id", runnable);
148169

149170
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
150171

151172
InOrder inOrder = inOrder(client, runnable);
152-
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
153173
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
154174
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
155175
inOrder.verify(runnable).run();
@@ -158,27 +178,33 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
158178
IndexRequest indexRequest = indexRequestCaptor.getValue();
159179
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
160180
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
181+
182+
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
183+
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
184+
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
185+
assertThat(parsedProgress.get(), equalTo(progress));
186+
}
161187
}
162188

163-
public void testPersistProgress_ProgressDocumentCreated() {
189+
public void testPersistProgress_ProgressDocumentCreated() throws IOException {
164190
testPersistProgress(SearchHits.empty(), ".ml-state-write");
165191
}
166192

167-
public void testPersistProgress_ProgressDocumentUpdated() {
193+
public void testPersistProgress_ProgressDocumentUpdated() throws IOException {
168194
testPersistProgress(
169195
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Map.of("_index", ".ml-state-dummy")) }, null, 0.0f),
170196
".ml-state-dummy");
171197
}
172198

173-
public void testSetFailed() {
199+
public void testSetFailed() throws IOException {
174200
testSetFailed(false);
175201
}
176202

177-
public void testSetFailedDuringNodeShutdown() {
203+
public void testSetFailedDuringNodeShutdown() throws IOException {
178204
testSetFailed(true);
179205
}
180206

181-
private void testSetFailed(boolean nodeShuttingDown) {
207+
private void testSetFailed(boolean nodeShuttingDown) throws IOException {
182208
ThreadPool threadPool = mock(ThreadPool.class);
183209
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
184210
Client client = mock(Client.class);
@@ -190,15 +216,25 @@ private void testSetFailed(boolean nodeShuttingDown) {
190216
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);
191217
TaskManager taskManager = mock(TaskManager.class);
192218

219+
List<PhaseProgress> progress = List.of(
220+
new PhaseProgress(ProgressTracker.REINDEXING, 100),
221+
new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
222+
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 30));
223+
193224
StartDataFrameAnalyticsAction.TaskParams taskParams =
194225
new StartDataFrameAnalyticsAction.TaskParams(
195226
"job-id",
196227
Version.CURRENT,
197-
List.of(
198-
new PhaseProgress(ProgressTracker.REINDEXING, 0),
199-
new PhaseProgress(ProgressTracker.LOADING_DATA, 0),
200-
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0)),
228+
progress,
201229
false);
230+
231+
SearchResponse searchResponse = mock(SearchResponse.class);
232+
when(searchResponse.getHits()).thenReturn(SearchHits.empty());
233+
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
234+
235+
IndexResponse indexResponse = mock(IndexResponse.class);
236+
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
237+
202238
DataFrameAnalyticsTask task =
203239
new DataFrameAnalyticsTask(
204240
123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams);
@@ -210,7 +246,23 @@ private void testSetFailed(boolean nodeShuttingDown) {
210246
verify(analyticsManager).isNodeShuttingDown();
211247
verify(client, atLeastOnce()).settings();
212248
verify(client, atLeastOnce()).threadPool();
249+
213250
if (nodeShuttingDown == false) {
251+
// Verify progress was persisted
252+
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
253+
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
254+
verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
255+
256+
IndexRequest indexRequest = indexRequestCaptor.getValue();
257+
assertThat(indexRequest.index(), equalTo(AnomalyDetectorsIndex.jobStateIndexWriteAlias()));
258+
assertThat(indexRequest.id(), equalTo("data_frame_analytics-job-id-progress"));
259+
260+
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
261+
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
262+
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
263+
assertThat(parsedProgress.get(), equalTo(progress));
264+
}
265+
214266
verify(client).execute(
215267
same(UpdatePersistentTaskStatusAction.INSTANCE),
216268
eq(new UpdatePersistentTaskStatusAction.Request(

0 commit comments

Comments
 (0)