32
32
import java .util .Iterator ;
33
33
import java .util .List ;
34
34
import java .util .concurrent .ConcurrentHashMap ;
35
+ import java .util .concurrent .Phaser ;
35
36
import java .util .stream .Collectors ;
36
37
37
38
/**
@@ -55,6 +56,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
55
56
private final ClusterService clusterService ;
56
57
private final JobManager jobManager ;
57
58
private final JobResultsProvider jobResultsProvider ;
59
+ private final Phaser stopPhaser ;
58
60
private volatile boolean isMaster ;
59
61
private volatile Instant lastUpdateTime ;
60
62
private volatile Duration reassignmentRecheckInterval ;
@@ -65,6 +67,7 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP
65
67
this .clusterService = clusterService ;
66
68
this .jobManager = jobManager ;
67
69
this .jobResultsProvider = jobResultsProvider ;
70
+ this .stopPhaser = new Phaser (1 );
68
71
setReassignmentRecheckInterval (PersistentTasksClusterService .CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING .get (settings ));
69
72
clusterService .addLocalNodeMasterListener (this );
70
73
clusterService .getClusterSettings ().addSettingsUpdateConsumer (
@@ -89,6 +92,23 @@ public void offMaster() {
89
92
lastUpdateTime = null ;
90
93
}
91
94
95
+ /**
96
+ * Wait for all outstanding searches to complete.
97
+ * After returning, no new searches can be started.
98
+ */
99
+ public void stop () {
100
+ logger .trace ("ML memory tracker stop called" );
101
+ // We never terminate the phaser
102
+ assert stopPhaser .isTerminated () == false ;
103
+ // If there are no registered parties or no unarrived parties then there is a flaw
104
+ // in the register/arrive/unregister logic in another method that uses the phaser
105
+ assert stopPhaser .getRegisteredParties () > 0 ;
106
+ assert stopPhaser .getUnarrivedParties () > 0 ;
107
+ stopPhaser .arriveAndAwaitAdvance ();
108
+ assert stopPhaser .getPhase () > 0 ;
109
+ logger .debug ("ML memory tracker stopped" );
110
+ }
111
+
92
112
@ Override
93
113
public String executorName () {
94
114
return MachineLearning .UTILITY_THREAD_POOL_NAME ;
@@ -146,13 +166,13 @@ public boolean asyncRefresh() {
146
166
try {
147
167
ActionListener <Void > listener = ActionListener .wrap (
148
168
aVoid -> logger .trace ("Job memory requirement refresh request completed successfully" ),
149
- e -> logger .error ("Failed to refresh job memory requirements" , e )
169
+ e -> logger .warn ("Failed to refresh job memory requirements" , e )
150
170
);
151
171
threadPool .executor (executorName ()).execute (
152
172
() -> refresh (clusterService .state ().getMetaData ().custom (PersistentTasksCustomMetaData .TYPE ), listener ));
153
173
return true ;
154
174
} catch (EsRejectedExecutionException e ) {
155
- logger .debug ("Couldn't schedule ML memory update - node might be shutting down" , e );
175
+ logger .warn ("Couldn't schedule ML memory update - node might be shutting down" , e );
156
176
}
157
177
}
158
178
@@ -246,25 +266,43 @@ public void refreshJobMemory(String jobId, ActionListener<Long> listener) {
246
266
return ;
247
267
}
248
268
269
+ // The phaser prevents searches being started after the memory tracker's stop() method has returned
270
+ if (stopPhaser .register () != 0 ) {
271
+ // Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
272
+ stopPhaser .arriveAndDeregister ();
273
+ listener .onFailure (new EsRejectedExecutionException ("Couldn't run ML memory update - node is shutting down" ));
274
+ return ;
275
+ }
276
+ ActionListener <Long > phaserListener = ActionListener .wrap (
277
+ r -> {
278
+ stopPhaser .arriveAndDeregister ();
279
+ listener .onResponse (r );
280
+ },
281
+ e -> {
282
+ stopPhaser .arriveAndDeregister ();
283
+ listener .onFailure (e );
284
+ }
285
+ );
286
+
249
287
try {
250
288
jobResultsProvider .getEstablishedMemoryUsage (jobId , null , null ,
251
289
establishedModelMemoryBytes -> {
252
290
if (establishedModelMemoryBytes <= 0L ) {
253
- setJobMemoryToLimit (jobId , listener );
291
+ setJobMemoryToLimit (jobId , phaserListener );
254
292
} else {
255
293
Long memoryRequirementBytes = establishedModelMemoryBytes + Job .PROCESS_MEMORY_OVERHEAD .getBytes ();
256
294
memoryRequirementByJob .put (jobId , memoryRequirementBytes );
257
- listener .onResponse (memoryRequirementBytes );
295
+ phaserListener .onResponse (memoryRequirementBytes );
258
296
}
259
297
},
260
298
e -> {
261
299
logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
262
- setJobMemoryToLimit (jobId , listener );
300
+ setJobMemoryToLimit (jobId , phaserListener );
263
301
}
264
302
);
265
303
} catch (Exception e ) {
266
304
logger .error ("[" + jobId + "] failed to calculate job established model memory requirement" , e );
267
- setJobMemoryToLimit (jobId , listener );
305
+ setJobMemoryToLimit (jobId , phaserListener );
268
306
}
269
307
}
270
308
0 commit comments