Skip to content

[ML] Fix master node deadlock during ML daily maintenance #31836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
Expand Down Expand Up @@ -57,8 +58,8 @@ private void deleteExpiredData(ActionListener<DeleteExpiredDataAction.Response>
Auditor auditor = new Auditor(client, clusterService.nodeName());
List<MlDataRemover> dataRemovers = Arrays.asList(
new ExpiredResultsRemover(client, clusterService, auditor),
new ExpiredForecastsRemover(client),
new ExpiredModelSnapshotsRemover(client, clusterService),
new ExpiredForecastsRemover(client, threadPool),
new ExpiredModelSnapshotsRemover(client, threadPool, clusterService),
new UnusedStateRemover(client, clusterService)
);
Iterator<MlDataRemover> dataRemoversIterator = new VolatileCursorIterator<>(dataRemovers);
Expand All @@ -69,9 +70,15 @@ private void deleteExpiredData(Iterator<MlDataRemover> mlDataRemoversIterator,
ActionListener<DeleteExpiredDataAction.Response> listener) {
if (mlDataRemoversIterator.hasNext()) {
MlDataRemover remover = mlDataRemoversIterator.next();
remover.remove(ActionListener.wrap(
booleanResponse -> deleteExpiredData(mlDataRemoversIterator, listener),
listener::onFailure));
ActionListener<Boolean> nextListener = ActionListener.wrap(
booleanResponse -> deleteExpiredData(mlDataRemoversIterator, listener), listener::onFailure);
// Removing expired ML data and artifacts requires multiple operations.
// These are queued up and executed sequentially in the action listener.
// The chained calls must all run on the ML utility thread pool, NOT on the
// thread the previous action returned on, which in the case of a
// transport_client_boss thread is a disaster.
remover.remove(new ThreadedActionListener<>(logger, threadPool, MachineLearning.UTILITY_THREAD_POOL_NAME, nextListener,
false));
} else {
logger.info("Completed deletion of expired data");
listener.onResponse(new DeleteExpiredDataAction.Response(true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
Expand All @@ -27,11 +28,13 @@
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.results.Forecast;
import org.elasticsearch.xpack.core.ml.job.results.ForecastRequestStats;
import org.elasticsearch.xpack.core.ml.job.results.Result;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.joda.time.DateTime;
import org.joda.time.chrono.ISOChronology;

Expand All @@ -57,10 +60,12 @@ public class ExpiredForecastsRemover implements MlDataRemover {
private static final String RESULTS_INDEX_PATTERN = AnomalyDetectorsIndex.jobResultsIndexPrefix() + "*";

private final Client client;
private final ThreadPool threadPool;
private final long cutoffEpochMs;

public ExpiredForecastsRemover(Client client) {
public ExpiredForecastsRemover(Client client, ThreadPool threadPool) {
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);
this.cutoffEpochMs = DateTime.now(ISOChronology.getInstance()).getMillis();
}

Expand All @@ -79,7 +84,8 @@ public void remove(ActionListener<Boolean> listener) {

SearchRequest searchRequest = new SearchRequest(RESULTS_INDEX_PATTERN);
searchRequest.source(source);
client.execute(SearchAction.INSTANCE, searchRequest, forecastStatsHandler);
client.execute(SearchAction.INSTANCE, searchRequest, new ThreadedActionListener<>(LOGGER, threadPool,
MachineLearning.UTILITY_THREAD_POOL_NAME, forecastStatsHandler, false));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the removeDataBefore() method in ExpiredModelSnapshotsRemover should also use a ThreadedActionListener in exactly the same way this class does. It also has the problem of doing a (potentially) large amount of parsing on the network thread.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I pushed a commit to fix that too.

}

private void deleteForecasts(SearchResponse searchResponse, ActionListener<Boolean> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,21 @@
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshotField;
import org.elasticsearch.xpack.ml.MachineLearning;

import java.util.ArrayList;
import java.util.Iterator;
Expand Down Expand Up @@ -51,10 +54,12 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover
private static final int MODEL_SNAPSHOT_SEARCH_SIZE = 10000;

private final Client client;
private final ThreadPool threadPool;

public ExpiredModelSnapshotsRemover(Client client, ClusterService clusterService) {
public ExpiredModelSnapshotsRemover(Client client, ThreadPool threadPool, ClusterService clusterService) {
super(clusterService);
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);
}

@Override
Expand Down Expand Up @@ -84,7 +89,12 @@ protected void removeDataBefore(Job job, long cutoffEpochMs, ActionListener<Bool

searchRequest.source(new SearchSourceBuilder().query(query).size(MODEL_SNAPSHOT_SEARCH_SIZE));

client.execute(SearchAction.INSTANCE, searchRequest, new ActionListener<SearchResponse>() {
client.execute(SearchAction.INSTANCE, searchRequest, new ThreadedActionListener<>(LOGGER, threadPool,
MachineLearning.UTILITY_THREAD_POOL_NAME, expiredSnapshotsListener(job.getId(), listener), false));
}

private ActionListener<SearchResponse> expiredSnapshotsListener(String jobId, ActionListener<Boolean> listener) {
return new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse searchResponse) {
try {
Expand All @@ -100,9 +110,9 @@ public void onResponse(SearchResponse searchResponse) {

@Override
public void onFailure(Exception e) {
listener.onFailure(new ElasticsearchException("[" + job.getId() + "] Search for expired snapshots failed", e));
listener.onFailure(new ElasticsearchException("[" + jobId + "] Search for expired snapshots failed", e));
}
});
};
}

private void deleteModelSnapshots(Iterator<ModelSnapshot> modelSnapshotIterator, ActionListener<Boolean> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,25 @@
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.mock.orig.Mockito;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MLMetadataField;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobTests;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.junit.After;
import org.junit.Before;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
Expand All @@ -38,24 +43,27 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {

private Client client;
private ThreadPool threadPool;
private ClusterService clusterService;
private ClusterState clusterState;
private List<SearchRequest> capturedSearchRequests;
private List<DeleteModelSnapshotAction.Request> capturedDeleteModelSnapshotRequests;
private List<SearchResponse> searchResponsesPerCall;
private ActionListener<Boolean> listener;
private TestListener listener;

@Before
public void setUpTests() {
Expand All @@ -66,7 +74,19 @@ public void setUpTests() {
clusterState = mock(ClusterState.class);
when(clusterService.state()).thenReturn(clusterState);
client = mock(Client.class);
listener = mock(ActionListener.class);
listener = new TestListener();

// Init thread pool
Settings settings = Settings.builder()
.put("node.name", "expired_model_snapshots_remover_test")
.build();
threadPool = new ThreadPool(settings,
new FixedExecutorBuilder(settings, MachineLearning.UTILITY_THREAD_POOL_NAME, 1, 1000, ""));
}

/**
 * Runs after each test and hands the thread pool built in {@link #setUpTests()}
 * to {@code terminate}, so a fresh pool is not left behind per test.
 * NOTE(review): {@code terminate} is presumably the inherited ESTestCase helper
 * that shuts the pool down and waits for it -- confirm.
 *
 * @throws InterruptedException declared because {@code terminate} may be interrupted while waiting
 */
@After
public void shutdownThreadPool() throws InterruptedException {
terminate(threadPool);
}

public void testRemove_GivenJobsWithoutRetentionPolicy() {
Expand All @@ -78,7 +98,8 @@ public void testRemove_GivenJobsWithoutRetentionPolicy() {

createExpiredModelSnapshotsRemover().remove(listener);

verify(listener).onResponse(true);
listener.waitToCompletion();
assertThat(listener.success, is(true));
Mockito.verifyNoMoreInteractions(client);
}

Expand All @@ -88,7 +109,8 @@ public void testRemove_GivenJobWithoutActiveSnapshot() {

createExpiredModelSnapshotsRemover().remove(listener);

verify(listener).onResponse(true);
listener.waitToCompletion();
assertThat(listener.success, is(true));
Mockito.verifyNoMoreInteractions(client);
}

Expand All @@ -108,6 +130,9 @@ public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException

createExpiredModelSnapshotsRemover().remove(listener);

listener.waitToCompletion();
assertThat(listener.success, is(true));

assertThat(capturedSearchRequests.size(), equalTo(2));
SearchRequest searchRequest = capturedSearchRequests.get(0);
assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));
Expand All @@ -124,8 +149,6 @@ public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException
deleteSnapshotRequest = capturedDeleteModelSnapshotRequests.get(2);
assertThat(deleteSnapshotRequest.getJobId(), equalTo("snapshots-2"));
assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-2_1"));

verify(listener).onResponse(true);
}

public void testRemove_GivenClientSearchRequestsFail() throws IOException {
Expand All @@ -144,13 +167,14 @@ public void testRemove_GivenClientSearchRequestsFail() throws IOException {

createExpiredModelSnapshotsRemover().remove(listener);

listener.waitToCompletion();
assertThat(listener.success, is(false));

assertThat(capturedSearchRequests.size(), equalTo(1));
SearchRequest searchRequest = capturedSearchRequests.get(0);
assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));

assertThat(capturedDeleteModelSnapshotRequests.size(), equalTo(0));

verify(listener).onFailure(any());
}

public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOException {
Expand All @@ -169,6 +193,9 @@ public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOExceptio

createExpiredModelSnapshotsRemover().remove(listener);

listener.waitToCompletion();
assertThat(listener.success, is(false));

assertThat(capturedSearchRequests.size(), equalTo(1));
SearchRequest searchRequest = capturedSearchRequests.get(0);
assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));
Expand All @@ -177,8 +204,6 @@ public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOExceptio
DeleteModelSnapshotAction.Request deleteSnapshotRequest = capturedDeleteModelSnapshotRequests.get(0);
assertThat(deleteSnapshotRequest.getJobId(), equalTo("snapshots-1"));
assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-1_1"));

verify(listener).onFailure(any());
}

private void givenJobs(List<Job> jobs) {
Expand All @@ -192,7 +217,7 @@ private void givenJobs(List<Job> jobs) {
}

private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover() {
return new ExpiredModelSnapshotsRemover(client, clusterService);
return new ExpiredModelSnapshotsRemover(client, threadPool, clusterService);
}

private static ModelSnapshot createModelSnapshot(String jobId, String snapshotId) {
Expand Down Expand Up @@ -230,7 +255,7 @@ private void givenClientRequests(boolean shouldSearchRequestsSucceed, boolean sh
int callCount = 0;

@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
public Void answer(InvocationOnMock invocationOnMock) {
SearchRequest searchRequest = (SearchRequest) invocationOnMock.getArguments()[1];
capturedSearchRequests.add(searchRequest);
ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
Expand All @@ -244,7 +269,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
}).when(client).execute(same(SearchAction.INSTANCE), any(), any());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
public Void answer(InvocationOnMock invocationOnMock) {
capturedDeleteModelSnapshotRequests.add((DeleteModelSnapshotAction.Request) invocationOnMock.getArguments()[1]);
ActionListener<DeleteModelSnapshotAction.Response> listener =
(ActionListener<DeleteModelSnapshotAction.Response>) invocationOnMock.getArguments()[2];
Expand All @@ -257,4 +282,30 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
}
}).when(client).execute(same(DeleteModelSnapshotAction.INSTANCE), any(), any());
}

/**
 * Listener used by the tests to block until the remover under test has
 * reported its outcome. The remover now completes on another thread, so
 * tests call {@link #waitToCompletion()} before asserting on {@code success}.
 */
private class TestListener implements ActionListener<Boolean> {

    // Written by the completing thread before countDown(); read by the test
    // thread only after await() returns, which gives the needed visibility.
    private boolean success;
    private final CountDownLatch latch = new CountDownLatch(1);

    /** Records the reported result and releases the waiting test thread. */
    @Override
    public void onResponse(Boolean aBoolean) {
        success = aBoolean;
        latch.countDown();
    }

    /** Leaves {@code success} as {@code false} and releases the waiting test thread. */
    @Override
    public void onFailure(Exception e) {
        latch.countDown();
    }

    /**
     * Blocks for up to 10 seconds until the listener has been completed.
     * Fails the test if the wait times out or the thread is interrupted.
     */
    public void waitToCompletion() {
        try {
            // await() returns false on timeout; without this check a timeout
            // is indistinguishable from a reported failure (success == false).
            if (latch.await(10, TimeUnit.SECONDS) == false) {
                fail("listener timed out before completing");
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the test framework can observe it.
            Thread.currentThread().interrupt();
            fail("interrupted while waiting for listener to complete");
        }
    }
}

}