
[ML] Fix ML memory tracker lockup when inner step fails #44158

Merged

@@ -299,7 +299,18 @@ void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener<Void>
                 }
                 fullRefreshCompletionListeners.clear();
             }
-        }, onCompletion::onFailure);
+        },
+        e -> {
+            synchronized (fullRefreshCompletionListeners) {
+                assert fullRefreshCompletionListeners.isEmpty() == false;
+                for (ActionListener<Void> listener : fullRefreshCompletionListeners) {
+                    listener.onFailure(e);
Member:
I was going to comment on how we don't signal onCompletion any more, but then I saw line 286. This is definitely a sneaky bug.

Contributor Author:
Yes, as well as making it impossible to retry, the bug also meant that only one of the queued listeners was notified of the failure. The others didn't receive any notification at all.

+                }
+                // It's critical that we empty out the current listener list on
+                // error otherwise subsequent retries to refresh will be ignored
+                fullRefreshCompletionListeners.clear();
+            }
+        });
 
         // persistentTasks will be null if there's never been a persistent task created in this cluster
         if (persistentTasks == null) {
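To make the shape of the fix easier to see outside the diff, here is a minimal, self-contained sketch of the listener-queue pattern the new error handler follows. This is not the actual MlMemoryTracker code: the names RefreshCoordinator and doRefresh are hypothetical, and Elasticsearch's ActionListener is replaced by a plain Consumer<Exception> where a null argument signals success.

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Illustrative only: a simplified coordinator showing why the failure path
// must notify every queued listener and then clear the queue.
class RefreshCoordinator {

    // Callers that asked for a refresh while one was already in flight.
    private final List<Consumer<Exception>> pendingListeners = new ArrayList<>();

    void refresh(Consumer<Exception> listener) {
        synchronized (pendingListeners) {
            pendingListeners.add(listener);
            if (pendingListeners.size() > 1) {
                // A refresh is already running; it will notify us when it finishes.
                return;
            }
        }
        doRefresh(e -> {
            // On success (e == null) and on failure alike, drain the whole
            // queue. Before the fix, the failure path notified only a single
            // listener and never cleared the queue, so every later refresh()
            // call took the early return above and its listener was never
            // invoked; that is the lockup this PR fixes.
            synchronized (pendingListeners) {
                for (Consumer<Exception> queued : pendingListeners) {
                    queued.accept(e);
                }
                pendingListeners.clear();
            }
        });
    }

    // Stand-in for the real asynchronous refresh work.
    private void doRefresh(Consumer<Exception> onDone) {
        onDone.accept(null); // succeeds immediately in this sketch
    }
}

The key invariant is that the pending list is emptied on every terminal path; leaving even one stale entry makes the size check treat the next refresh as already in progress, which is exactly the state the new test below drives the tracker into and out of.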
@@ -19,6 +19,7 @@
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
@@ -32,11 +33,13 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.anyString;
 import static org.mockito.Mockito.doAnswer;
@@ -125,6 +128,66 @@ public void testRefreshAll() {
         }
     }
 
+    public void testRefreshAllFailure() {
+
+        Map<String, PersistentTasksCustomMetaData.PersistentTask<?>> tasks = new HashMap<>();
+
+        int numAnomalyDetectorJobTasks = randomIntBetween(2, 5);
+        for (int i = 1; i <= numAnomalyDetectorJobTasks; ++i) {
+            String jobId = "job" + i;
+            PersistentTasksCustomMetaData.PersistentTask<?> task = makeTestAnomalyDetectorTask(jobId);
+            tasks.put(task.getId(), task);
+        }
+
+        int numDataFrameAnalyticsTasks = randomIntBetween(2, 5);
+        for (int i = 1; i <= numDataFrameAnalyticsTasks; ++i) {
+            String id = "analytics" + i;
+            PersistentTasksCustomMetaData.PersistentTask<?> task = makeTestDataFrameAnalyticsTask(id);
+            tasks.put(task.getId(), task);
+        }
+
+        PersistentTasksCustomMetaData persistentTasks =
+            new PersistentTasksCustomMetaData(numAnomalyDetectorJobTasks + numDataFrameAnalyticsTasks, tasks);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            Consumer<Long> listener = (Consumer<Long>) invocation.getArguments()[3];
+            listener.accept(randomLongBetween(1000, 1000000));
+            return null;
+        }).when(jobResultsProvider).getEstablishedMemoryUsage(anyString(), any(), any(), any(Consumer.class), any());
+
+        // First run a refresh using a component that calls the onFailure method of the listener
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<List<DataFrameAnalyticsConfig>> listener =
+                (ActionListener<List<DataFrameAnalyticsConfig>>) invocation.getArguments()[2];
+            listener.onFailure(new IllegalArgumentException("computer says no"));
+            return null;
+        }).when(configProvider).getMultiple(anyString(), anyBoolean(), any(ActionListener.class));
+
+        AtomicBoolean gotErrorResponse = new AtomicBoolean(false);
+        memoryTracker.refresh(persistentTasks,
+            ActionListener.wrap(aVoid -> fail("Expected error response"), e -> gotErrorResponse.set(true)));
+        assertTrue(gotErrorResponse.get());
+
+        // Now run another refresh using a component that calls the onResponse method of the listener - this
+        // proves that the ML memory tracker has not been permanently blocked up by the previous failure
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<List<DataFrameAnalyticsConfig>> listener =
+                (ActionListener<List<DataFrameAnalyticsConfig>>) invocation.getArguments()[2];
+            listener.onResponse(Collections.emptyList());
+            return null;
+        }).when(configProvider).getMultiple(anyString(), anyBoolean(), any(ActionListener.class));
+
+        AtomicBoolean gotSuccessResponse = new AtomicBoolean(false);
+        memoryTracker.refresh(persistentTasks,
+            ActionListener.wrap(aVoid -> gotSuccessResponse.set(true), e -> fail("Expected success response")));
+        assertTrue(gotSuccessResponse.get());
+    }
+
     public void testRefreshOneAnomalyDetectorJob() {
 
         boolean isMaster = randomBoolean();