Skip to content

Commit bf94542

Browse files
authored
[ML] Wait for model process to stop in stop deployment (#83644) (#83657)
1 parent d766ffc commit bf94542

File tree

7 files changed

+71
-73
lines changed

7 files changed

+71
-73
lines changed

docs/changelog/83644.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 83644
2+
summary: Wait for model process to stop in stop deployment
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
2424
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
2525
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
26+
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
2627
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
2728
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2829
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
@@ -88,6 +89,11 @@ protected void doExecute(
8889
}, listener::onFailure));
8990
return;
9091
}
92+
if (allocation.getAllocationState() == AllocationState.STOPPING) {
93+
String message = "Trained model [" + deploymentId + "] is STOPPING";
94+
listener.onFailure(ExceptionsHelper.conflictStatusException(message));
95+
return;
96+
}
9197
String[] randomRunningNode = allocation.getStartedNodes();
9298
if (randomRunningNode.length == 0) {
9399
String message = "Trained model [" + deploymentId + "] is not allocated to any nodes";

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java

Lines changed: 23 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
4040
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService;
4141
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
42-
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
4342
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
4443

4544
import java.util.Collections;
@@ -66,7 +65,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
6665

6766
private final Client client;
6867
private final IngestService ingestService;
69-
private final TrainedModelAllocationService trainedModelAllocationService;
7068
private final TrainedModelAllocationClusterService trainedModelAllocationClusterService;
7169

7270
@Inject
@@ -76,7 +74,6 @@ public TransportStopTrainedModelDeploymentAction(
7674
ActionFilters actionFilters,
7775
Client client,
7876
IngestService ingestService,
79-
TrainedModelAllocationService trainedModelAllocationService,
8077
TrainedModelAllocationClusterService trainedModelAllocationClusterService
8178
) {
8279
super(
@@ -91,7 +88,6 @@ public TransportStopTrainedModelDeploymentAction(
9188
);
9289
this.client = new OriginSettingClient(client, ML_ORIGIN);
9390
this.ingestService = ingestService;
94-
this.trainedModelAllocationService = trainedModelAllocationService;
9591
this.trainedModelAllocationClusterService = trainedModelAllocationClusterService;
9692
}
9793

@@ -150,6 +146,7 @@ protected void doExecute(
150146
}
151147

152148
// NOTE, should only run on Master node
149+
assert clusterService.localNode().isMasterNode();
153150
trainedModelAllocationClusterService.setModelAllocationToStopping(
154151
modelId,
155152
ActionListener.wrap(
@@ -196,30 +193,25 @@ private void normalUndeploy(
196193
) {
197194
request.setNodes(modelAllocation.getNodeRoutingTable().keySet().toArray(String[]::new));
198195
ActionListener<StopTrainedModelDeploymentAction.Response> finalListener = ActionListener.wrap(r -> {
199-
waitForTaskRemoved(modelId, modelAllocation, request, r, ActionListener.wrap(waited -> {
200-
trainedModelAllocationService.deleteModelAllocation(
201-
modelId,
202-
ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
203-
logger.error(
204-
() -> new ParameterizedMessage(
205-
"[{}] failed to delete model allocation after nodes unallocated the deployment",
206-
modelId
207-
),
196+
assert clusterService.localNode().isMasterNode();
197+
trainedModelAllocationClusterService.removeModelAllocation(
198+
modelId,
199+
ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
200+
logger.error(
201+
() -> new ParameterizedMessage(
202+
"[{}] failed to delete model allocation after nodes unallocated the deployment",
203+
modelId
204+
),
205+
deletionFailed
206+
);
207+
listener.onFailure(
208+
ExceptionsHelper.serverError(
209+
"failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
208210
deletionFailed
209-
);
210-
listener.onFailure(
211-
ExceptionsHelper.serverError(
212-
"failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
213-
deletionFailed
214-
)
215-
);
216-
})
217-
);
218-
},
219-
// TODO should we attempt to delete the deployment here?
220-
listener::onFailure
221-
));
222-
211+
)
212+
);
213+
})
214+
);
223215
}, e -> {
224216
if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) {
225217
// A node has dropped out of the cluster since we started executing the requests.
@@ -235,24 +227,6 @@ private void normalUndeploy(
235227
super.doExecute(task, request, finalListener);
236228
}
237229

238-
void waitForTaskRemoved(
239-
String modelId,
240-
TrainedModelAllocation trainedModelAllocation,
241-
StopTrainedModelDeploymentAction.Request request,
242-
StopTrainedModelDeploymentAction.Response response,
243-
ActionListener<StopTrainedModelDeploymentAction.Response> listener
244-
) {
245-
final Set<String> nodesOfConcern = trainedModelAllocation.getNodeRoutingTable().keySet();
246-
client.admin()
247-
.cluster()
248-
.prepareListTasks(nodesOfConcern.toArray(String[]::new))
249-
.setDetailed(true)
250-
.setWaitForCompletion(true)
251-
.setActions(modelId)
252-
.setTimeout(request.getTimeout())
253-
.execute(ActionListener.wrap(complete -> listener.onResponse(response), listener::onFailure));
254-
}
255-
256230
@Override
257231
protected StopTrainedModelDeploymentAction.Response newResponse(
258232
StopTrainedModelDeploymentAction.Request request,
@@ -275,7 +249,9 @@ protected void taskOperation(
275249
TrainedModelDeploymentTask task,
276250
ActionListener<StopTrainedModelDeploymentAction.Response> listener
277251
) {
278-
task.stop("undeploy_trained_model (api)");
279-
listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
252+
task.stop(
253+
"undeploy_trained_model (api)",
254+
ActionListener.wrap(r -> listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)), listener::onFailure)
255+
);
280256
}
281257
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionL
135135
if (stopped) {
136136
return;
137137
}
138-
task.stopWithoutNotification(reason);
138+
task.markAsStopped(reason);
139+
139140
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
140141
try {
141142
deploymentManager.stopDeployment(task);
@@ -204,20 +205,12 @@ void loadQueuedModels() {
204205
loadingModels.addAll(loadingToRetry);
205206
}
206207

207-
public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason) {
208+
public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
208209
ActionListener<Void> notifyDeploymentOfStopped = ActionListener.wrap(
209-
_void -> updateStoredState(
210-
task.getModelId(),
211-
new RoutingStateAndReason(RoutingState.STOPPED, reason),
212-
ActionListener.wrap(s -> {}, failure -> {})
213-
),
210+
_void -> updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener),
214211
failed -> { // if we failed to stop the process, something strange is going on, but we should still notify of stop
215212
logger.warn(() -> new ParameterizedMessage("[{}] failed to stop due to error", task.getModelId()), failed);
216-
updateStoredState(
217-
task.getModelId(),
218-
new RoutingStateAndReason(RoutingState.STOPPED, reason),
219-
ActionListener.wrap(s -> {}, failure -> {})
220-
);
213+
updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener);
221214
}
222215
);
223216
updateStoredState(
@@ -309,7 +302,7 @@ public void clusterChanged(ClusterChangedEvent event) {
309302
&& isResetMode == false) {
310303
prepareModelToLoad(trainedModelAllocation.getTaskParams());
311304
}
312-
// This mode is not routed to the current node at all
305+
// This model is not routed to the current node at all
313306
if (routingStateAndReason == null) {
314307
TrainedModelDeploymentTask task = modelIdToTask.remove(trainedModelAllocation.getTaskParams().getModelId());
315308
if (task != null) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12+
import org.apache.logging.log4j.message.ParameterizedMessage;
1213
import org.apache.lucene.util.SetOnce;
1314
import org.elasticsearch.ElasticsearchStatusException;
1415
import org.elasticsearch.action.ActionListener;
16+
import org.elasticsearch.action.support.master.AcknowledgedResponse;
1517
import org.elasticsearch.core.TimeValue;
1618
import org.elasticsearch.license.LicensedFeature;
1719
import org.elasticsearch.license.XPackLicenseState;
@@ -80,15 +82,11 @@ public TaskParams getParams() {
8082
return params;
8183
}
8284

83-
public void stop(String reason) {
84-
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
85-
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
86-
stopped = true;
87-
stoppedReasonHolder.trySet(reason);
88-
trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason);
85+
public void stop(String reason, ActionListener<AcknowledgedResponse> listener) {
86+
trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason, listener);
8987
}
9088

91-
public void stopWithoutNotification(String reason) {
89+
public void markAsStopped(String reason) {
9290
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
9391
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
9492
stoppedReasonHolder.trySet(reason);
@@ -106,7 +104,14 @@ public Optional<String> stoppedReason() {
106104
@Override
107105
protected void onCancelled() {
108106
String reason = getReasonCancelled();
109-
stop(reason);
107+
logger.info("[{}] task cancelled due to reason [{}]", getModelId(), reason);
108+
stop(
109+
reason,
110+
ActionListener.wrap(
111+
acknowledgedResponse -> {},
112+
e -> logger.error(new ParameterizedMessage("[{}] error stopping the model after task cancellation", getModelId()), e)
113+
)
114+
);
110115
}
111116

112117
public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception {
196196
// Only one model should be loaded, the other should be stopped
197197
trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad));
198198
trainedModelAllocationNodeService.prepareModelToLoad(newParams(stoppedModelToLoad));
199-
trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing");
199+
trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing", ActionListener.wrap(r -> {}, e -> {}));
200200
trainedModelAllocationNodeService.loadQueuedModels();
201201

202202
assertBusy(() -> {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,23 @@
77

88
package org.elasticsearch.xpack.ml.inference.deployment;
99

10+
import org.elasticsearch.action.ActionListener;
1011
import org.elasticsearch.license.LicensedFeature;
1112
import org.elasticsearch.license.XPackLicenseState;
1213
import org.elasticsearch.tasks.TaskId;
1314
import org.elasticsearch.test.ESTestCase;
1415
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
1516
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
1617
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;
18+
import org.mockito.ArgumentCaptor;
1719

1820
import java.util.Map;
1921
import java.util.function.Consumer;
2022

2123
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_ACTION;
2224
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
25+
import static org.mockito.ArgumentMatchers.any;
26+
import static org.mockito.Mockito.doAnswer;
2327
import static org.mockito.Mockito.mock;
2428
import static org.mockito.Mockito.times;
2529
import static org.mockito.Mockito.verify;
@@ -29,6 +33,15 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
2933
void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String modelId) {
3034
XPackLicenseState licenseState = mock(XPackLicenseState.class);
3135
LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class);
36+
TrainedModelAllocationNodeService nodeService = mock(TrainedModelAllocationNodeService.class);
37+
38+
ArgumentCaptor<TrainedModelDeploymentTask> taskCaptor = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
39+
ArgumentCaptor<String> reasonCaptur = ArgumentCaptor.forClass(String.class);
40+
doAnswer(invocation -> {
41+
taskCaptor.getValue().markAsStopped(reasonCaptur.getValue());
42+
return null;
43+
}).when(nodeService).stopDeploymentAndNotify(taskCaptor.capture(), reasonCaptur.capture(), any());
44+
3245
TrainedModelDeploymentTask task = new TrainedModelDeploymentTask(
3346
0,
3447
TRAINED_MODEL_ALLOCATION_TASK_TYPE,
@@ -42,7 +55,7 @@ void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String
4255
randomInt(5),
4356
randomInt(5)
4457
),
45-
mock(TrainedModelAllocationNodeService.class),
58+
nodeService,
4659
licenseState,
4760
feature
4861
);
@@ -53,12 +66,12 @@ void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String
5366
verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId);
5467
}
5568

56-
public void testOnStopWithoutNotification() {
57-
assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10));
69+
public void testMarkAsStopped() {
70+
assertTrackingComplete(t -> t.markAsStopped("foo"), randomAlphaOfLength(10));
5871
}
5972

6073
public void testOnStop() {
61-
assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10));
74+
assertTrackingComplete(t -> t.stop("foo", ActionListener.wrap(r -> {}, e -> {})), randomAlphaOfLength(10));
6275
}
6376

6477
public void testCancelled() {

0 commit comments

Comments
 (0)