Skip to content

[ML] Persist progress when setting DFA task to failed #61782

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number | Diff line number | Diff line change
Expand Up @@ -232,7 +232,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF

Exception reindexError = getReindexError(task.getParams().getId(), reindexResponse);
if (reindexError != null) {
task.markAsFailed(reindexError);
task.setFailed(reindexError);
return;
}

Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
Expand Down Expand Up @@ -38,7 +37,6 @@
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.tasks.TaskResult;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
Expand Down Expand Up @@ -216,23 +214,25 @@ public void setFailed(Exception error) {
error);
return;
}
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
DataFrameAnalyticsTaskState newTaskState =
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
updatePersistentTaskState(
newTaskState,
ActionListener.wrap(
updatedTask -> {
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
persistProgress(client, taskParams.getId(), () -> {
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
DataFrameAnalyticsTaskState newTaskState =
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
updatePersistentTaskState(
newTaskState,
ActionListener.wrap(
updatedTask -> {
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
DataFrameAnalyticsState.FAILED, reason);
auditor.info(getParams().getId(), message);
LOGGER.info("[{}] {}", getParams().getId(), message);
},
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
)
);
auditor.info(getParams().getId(), message);
LOGGER.info("[{}] {}", getParams().getId(), message);
},
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
)
);
});
}

public void updateReindexTaskProgress(ActionListener<Void> listener) {
Expand Down Expand Up @@ -285,13 +285,12 @@ private TaskId getReindexTaskId() {
}

// Visible for testing
static void persistProgress(Client client, String jobId, Runnable runnable) {
void persistProgress(Client client, String jobId, Runnable runnable) {
LOGGER.debug("[{}] Persisting progress", jobId);

String progressDocId = StoredProgress.documentId(jobId);
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();

// Step 4: Run the runnable provided as the argument
// Step 3: Run the runnable provided as the argument
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
indexResponse -> {
LOGGER.debug("[{}] Successfully indexed progress document", jobId);
Expand All @@ -304,7 +303,7 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
}
);

// Step 3: Create or update the progress document:
// Step 2: Create or update the progress document:
// - if the document did not exist, create the new one in the current write index
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
Expand All @@ -317,9 +316,10 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
.id(progressDocId)
.setRequireAlias(AnomalyDetectorsIndex.jobStateIndexWriteAlias().equals(indexOrAlias))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<PhaseProgress> progress = statsHolder.getProgressTracker().report();
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
LOGGER.debug("[{}] Persisting progress is: {}", jobId, stats.get().getProgress());
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
LOGGER.debug("[{}] Persisting progress is: {}", jobId, progress);
new StoredProgress(progress).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
indexRequest.source(jsonBuilder);
}
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
Expand All @@ -331,28 +331,14 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
}
);

// Step 2: Search for existing progress document in .ml-state*
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
statsResponse -> {
stats.set(statsResponse.getResponse().results().get(0));
SearchRequest searchRequest =
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
.source(
new SearchSourceBuilder()
.size(1)
.query(new IdsQueryBuilder().addIds(progressDocId)));
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
},
e -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while retrieving stats", jobId), e);
runnable.run();
}
);

// Step 1: Fetch progress to be persisted
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(jobId);
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
// Step 1: Search for existing progress document in .ml-state*
SearchRequest searchRequest =
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
.source(
new SearchSourceBuilder()
.size(1)
.query(new IdsQueryBuilder().addIds(progressDocId)));
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
}

/**
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,18 +16,21 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
Expand All @@ -36,6 +39,7 @@
import org.mockito.InOrder;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -126,14 +130,24 @@ public void testDetermineStartingState_GivenEmptyProgress() {
assertThat(startingState, equalTo(StartingState.FINISHED));
}

private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) {
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) throws IOException {
Client client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(client.threadPool()).thenReturn(threadPool);

GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse(1);
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
ClusterService clusterService = mock(ClusterService.class);
DataFrameAnalyticsManager analyticsManager = mock(DataFrameAnalyticsManager.class);
DataFrameAnalyticsAuditor auditor = mock(DataFrameAnalyticsAuditor.class);
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);

List<PhaseProgress> progress = List.of(
new PhaseProgress(ProgressTracker.REINDEXING, 100),
new PhaseProgress(ProgressTracker.LOADING_DATA, 50),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0));

StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(
"task_id", Version.CURRENT, progress, false);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(searchHits);
Expand All @@ -142,14 +156,20 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());

TaskManager taskManager = mock(TaskManager.class);

Runnable runnable = mock(Runnable.class);

DataFrameAnalyticsTask.persistProgress(client, "task_id", runnable);
DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams);
task.init(persistentTasksService, taskManager, "task-id", 42);

task.persistProgress(client, "task_id", runnable);

ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);

InOrder inOrder = inOrder(client, runnable);
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
inOrder.verify(runnable).run();
Expand All @@ -158,27 +178,33 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));

try (XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
assertThat(parsedProgress.get(), equalTo(progress));
}
}

public void testPersistProgress_ProgressDocumentCreated() {
public void testPersistProgress_ProgressDocumentCreated() throws IOException {
testPersistProgress(SearchHits.empty(), ".ml-state-write");
}

public void testPersistProgress_ProgressDocumentUpdated() {
public void testPersistProgress_ProgressDocumentUpdated() throws IOException {
testPersistProgress(
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Map.of("_index", ".ml-state-dummy")) }, null, 0.0f),
".ml-state-dummy");
}

public void testSetFailed() {
public void testSetFailed() throws IOException {
testSetFailed(false);
}

public void testSetFailedDuringNodeShutdown() {
public void testSetFailedDuringNodeShutdown() throws IOException {
testSetFailed(true);
}

private void testSetFailed(boolean nodeShuttingDown) {
private void testSetFailed(boolean nodeShuttingDown) throws IOException {
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
Client client = mock(Client.class);
Expand All @@ -190,15 +216,25 @@ private void testSetFailed(boolean nodeShuttingDown) {
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);
TaskManager taskManager = mock(TaskManager.class);

List<PhaseProgress> progress = List.of(
new PhaseProgress(ProgressTracker.REINDEXING, 100),
new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 30));

StartDataFrameAnalyticsAction.TaskParams taskParams =
new StartDataFrameAnalyticsAction.TaskParams(
"job-id",
Version.CURRENT,
List.of(
new PhaseProgress(ProgressTracker.REINDEXING, 0),
new PhaseProgress(ProgressTracker.LOADING_DATA, 0),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0)),
progress,
false);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(SearchHits.empty());
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());

IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());

DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams);
Expand All @@ -210,7 +246,23 @@ private void testSetFailed(boolean nodeShuttingDown) {
verify(analyticsManager).isNodeShuttingDown();
verify(client, atLeastOnce()).settings();
verify(client, atLeastOnce()).threadPool();

if (nodeShuttingDown == false) {
// Verify progress was persisted
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());

IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(AnomalyDetectorsIndex.jobStateIndexWriteAlias()));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-job-id-progress"));

try (XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
assertThat(parsedProgress.get(), equalTo(progress));
}

verify(client).execute(
same(UpdatePersistentTaskStatusAction.INSTANCE),
eq(new UpdatePersistentTaskStatusAction.Request(
Expand Down