Skip to content

Commit ed98ee0

Browse files
authored
[ML] Stop the ML memory tracker before closing node (#39111)
The ML memory tracker does searches against ML results and config indices. These searches can be asynchronous, and if they are running while the node is closing then they can cause problems for other components. This change adds a stop() method to the MlMemoryTracker that waits for in-flight searches to complete. Once stop() has returned the MlMemoryTracker will not kick off any new searches. The MlLifeCycleService now calls MlMemoryTracker.stop() before stopping stopping the node. Fixes #37117
1 parent 92ef753 commit ed98ee0

File tree

4 files changed

+68
-13
lines changed

4 files changed

+68
-13
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,10 +435,10 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
435435
DatafeedManager datafeedManager = new DatafeedManager(threadPool, client, clusterService, datafeedJobBuilder,
436436
System::currentTimeMillis, auditor, autodetectProcessManager);
437437
this.datafeedManager.set(datafeedManager);
438-
MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
439-
autodetectProcessManager);
440438
MlMemoryTracker memoryTracker = new MlMemoryTracker(settings, clusterService, threadPool, jobManager, jobResultsProvider);
441439
this.memoryTracker.set(memoryTracker);
440+
MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
441+
autodetectProcessManager, memoryTracker);
442442

443443
// This object's constructor attaches to the license state, so there's no need to retain another reference to it
444444
new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.elasticsearch.common.component.LifecycleListener;
1010
import org.elasticsearch.env.Environment;
1111
import org.elasticsearch.xpack.ml.datafeed.DatafeedManager;
12+
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
1213
import org.elasticsearch.xpack.ml.process.NativeController;
1314
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
1415
import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
@@ -20,16 +21,14 @@ public class MlLifeCycleService {
2021
private final Environment environment;
2122
private final DatafeedManager datafeedManager;
2223
private final AutodetectProcessManager autodetectProcessManager;
23-
24-
public MlLifeCycleService(Environment environment, ClusterService clusterService) {
25-
this(environment, clusterService, null, null);
26-
}
24+
private final MlMemoryTracker memoryTracker;
2725

2826
public MlLifeCycleService(Environment environment, ClusterService clusterService, DatafeedManager datafeedManager,
29-
AutodetectProcessManager autodetectProcessManager) {
27+
AutodetectProcessManager autodetectProcessManager, MlMemoryTracker memoryTracker) {
3028
this.environment = environment;
3129
this.datafeedManager = datafeedManager;
3230
this.autodetectProcessManager = autodetectProcessManager;
31+
this.memoryTracker = memoryTracker;
3332
clusterService.addLifecycleListener(new LifecycleListener() {
3433
@Override
3534
public void beforeStop() {
@@ -59,5 +58,8 @@ public synchronized void stop() {
5958
} catch (IOException e) {
6059
// We're stopping anyway, so don't let this complicate the shutdown sequence
6160
}
61+
if (memoryTracker != null) {
62+
memoryTracker.stop();
63+
}
6264
}
6365
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.Iterator;
3333
import java.util.List;
3434
import java.util.concurrent.ConcurrentHashMap;
35+
import java.util.concurrent.Phaser;
3536
import java.util.stream.Collectors;
3637

3738
/**
@@ -55,6 +56,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
5556
private final ClusterService clusterService;
5657
private final JobManager jobManager;
5758
private final JobResultsProvider jobResultsProvider;
59+
private final Phaser stopPhaser;
5860
private volatile boolean isMaster;
5961
private volatile Instant lastUpdateTime;
6062
private volatile Duration reassignmentRecheckInterval;
@@ -65,6 +67,7 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP
6567
this.clusterService = clusterService;
6668
this.jobManager = jobManager;
6769
this.jobResultsProvider = jobResultsProvider;
70+
this.stopPhaser = new Phaser(1);
6871
setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings));
6972
clusterService.addLocalNodeMasterListener(this);
7073
clusterService.getClusterSettings().addSettingsUpdateConsumer(
@@ -89,6 +92,23 @@ public void offMaster() {
8992
lastUpdateTime = null;
9093
}
9194

95+
/**
96+
* Wait for all outstanding searches to complete.
97+
* After returning, no new searches can be started.
98+
*/
99+
public void stop() {
100+
logger.trace("ML memory tracker stop called");
101+
// We never terminate the phaser
102+
assert stopPhaser.isTerminated() == false;
103+
// If there are no registered parties or no unarrived parties then there is a flaw
104+
// in the register/arrive/unregister logic in another method that uses the phaser
105+
assert stopPhaser.getRegisteredParties() > 0;
106+
assert stopPhaser.getUnarrivedParties() > 0;
107+
stopPhaser.arriveAndAwaitAdvance();
108+
assert stopPhaser.getPhase() > 0;
109+
logger.debug("ML memory tracker stopped");
110+
}
111+
92112
@Override
93113
public String executorName() {
94114
return MachineLearning.UTILITY_THREAD_POOL_NAME;
@@ -146,13 +166,13 @@ public boolean asyncRefresh() {
146166
try {
147167
ActionListener<Void> listener = ActionListener.wrap(
148168
aVoid -> logger.trace("Job memory requirement refresh request completed successfully"),
149-
e -> logger.error("Failed to refresh job memory requirements", e)
169+
e -> logger.warn("Failed to refresh job memory requirements", e)
150170
);
151171
threadPool.executor(executorName()).execute(
152172
() -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), listener));
153173
return true;
154174
} catch (EsRejectedExecutionException e) {
155-
logger.debug("Couldn't schedule ML memory update - node might be shutting down", e);
175+
logger.warn("Couldn't schedule ML memory update - node might be shutting down", e);
156176
}
157177
}
158178

@@ -246,25 +266,43 @@ public void refreshJobMemory(String jobId, ActionListener<Long> listener) {
246266
return;
247267
}
248268

269+
// The phaser prevents searches being started after the memory tracker's stop() method has returned
270+
if (stopPhaser.register() != 0) {
271+
// Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
272+
stopPhaser.arriveAndDeregister();
273+
listener.onFailure(new EsRejectedExecutionException("Couldn't run ML memory update - node is shutting down"));
274+
return;
275+
}
276+
ActionListener<Long> phaserListener = ActionListener.wrap(
277+
r -> {
278+
stopPhaser.arriveAndDeregister();
279+
listener.onResponse(r);
280+
},
281+
e -> {
282+
stopPhaser.arriveAndDeregister();
283+
listener.onFailure(e);
284+
}
285+
);
286+
249287
try {
250288
jobResultsProvider.getEstablishedMemoryUsage(jobId, null, null,
251289
establishedModelMemoryBytes -> {
252290
if (establishedModelMemoryBytes <= 0L) {
253-
setJobMemoryToLimit(jobId, listener);
291+
setJobMemoryToLimit(jobId, phaserListener);
254292
} else {
255293
Long memoryRequirementBytes = establishedModelMemoryBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes();
256294
memoryRequirementByJob.put(jobId, memoryRequirementBytes);
257-
listener.onResponse(memoryRequirementBytes);
295+
phaserListener.onResponse(memoryRequirementBytes);
258296
}
259297
},
260298
e -> {
261299
logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e);
262-
setJobMemoryToLimit(jobId, listener);
300+
setJobMemoryToLimit(jobId, phaserListener);
263301
}
264302
);
265303
} catch (Exception e) {
266304
logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e);
267-
setJobMemoryToLimit(jobId, listener);
305+
setJobMemoryToLimit(jobId, phaserListener);
268306
}
269307
}
270308

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.common.settings.ClusterSettings;
1111
import org.elasticsearch.common.settings.Settings;
1212
import org.elasticsearch.common.unit.ByteSizeUnit;
13+
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
1314
import org.elasticsearch.persistent.PersistentTasksClusterService;
1415
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
1516
import org.elasticsearch.test.ESTestCase;
@@ -29,6 +30,7 @@
2930
import java.util.concurrent.atomic.AtomicReference;
3031
import java.util.function.Consumer;
3132

33+
import static org.hamcrest.CoreMatchers.instanceOf;
3234
import static org.mockito.Matchers.any;
3335
import static org.mockito.Matchers.eq;
3436
import static org.mockito.Mockito.anyString;
@@ -157,6 +159,19 @@ public void testRefreshOne() {
157159
assertNull(memoryTracker.getJobMemoryRequirement(jobId));
158160
}
159161

162+
public void testStop() {
163+
164+
memoryTracker.onMaster();
165+
memoryTracker.stop();
166+
167+
AtomicReference<Exception> exception = new AtomicReference<>();
168+
memoryTracker.refreshJobMemory("job", ActionListener.wrap(ESTestCase::assertNull, exception::set));
169+
170+
assertNotNull(exception.get());
171+
assertThat(exception.get(), instanceOf(EsRejectedExecutionException.class));
172+
assertEquals("Couldn't run ML memory update - node is shutting down", exception.get().getMessage());
173+
}
174+
160175
private PersistentTasksCustomMetaData.PersistentTask<OpenJobAction.JobParams> makeTestTask(String jobId) {
161176
return new PersistentTasksCustomMetaData.PersistentTask<>("job-" + jobId, MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams(jobId),
162177
0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT);

0 commit comments

Comments
 (0)