[ML] Wait for model process to stop in stop deployment #83644

Merged (3 commits) on Feb 8, 2022

5 changes: 5 additions & 0 deletions docs/changelog/83644.yaml
@@ -0,0 +1,5 @@
pr: 83644
summary: Wait for model process to stop in stop deployment
area: Machine Learning
type: bug
issues: []
@@ -23,6 +23,7 @@
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
@@ -88,6 +89,11 @@ protected void doExecute(
}, listener::onFailure));
return;
}
if (allocation.getAllocationState() == AllocationState.STOPPING) {
String message = "Trained model [" + deploymentId + "] is STOPPING";
listener.onFailure(ExceptionsHelper.conflictStatusException(message));
return;
}
String[] randomRunningNode = allocation.getStartedNodes();
if (randomRunningNode.length == 0) {
String message = "Trained model [" + deploymentId + "] is not allocated to any nodes";
@@ -39,7 +39,6 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

import java.util.Collections;
@@ -66,7 +65,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct

private final Client client;
private final IngestService ingestService;
private final TrainedModelAllocationService trainedModelAllocationService;
private final TrainedModelAllocationClusterService trainedModelAllocationClusterService;

@Inject
@@ -76,7 +74,6 @@ public TransportStopTrainedModelDeploymentAction(
ActionFilters actionFilters,
Client client,
IngestService ingestService,
TrainedModelAllocationService trainedModelAllocationService,
TrainedModelAllocationClusterService trainedModelAllocationClusterService
) {
super(
@@ -91,7 +88,6 @@ public TransportStopTrainedModelDeploymentAction(
);
this.client = new OriginSettingClient(client, ML_ORIGIN);
this.ingestService = ingestService;
this.trainedModelAllocationService = trainedModelAllocationService;
this.trainedModelAllocationClusterService = trainedModelAllocationClusterService;
}

@@ -150,6 +146,7 @@ protected void doExecute(
}

// NOTE, should only run on Master node
assert clusterService.localNode().isMasterNode();
trainedModelAllocationClusterService.setModelAllocationToStopping(
modelId,
ActionListener.wrap(
@@ -196,30 +193,25 @@ private void normalUndeploy(
) {
request.setNodes(modelAllocation.getNodeRoutingTable().keySet().toArray(String[]::new));
ActionListener<StopTrainedModelDeploymentAction.Response> finalListener = ActionListener.wrap(r -> {
waitForTaskRemoved(modelId, modelAllocation, request, r, ActionListener.wrap(waited -> {
trainedModelAllocationService.deleteModelAllocation(
modelId,
ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
logger.error(
() -> new ParameterizedMessage(
"[{}] failed to delete model allocation after nodes unallocated the deployment",
modelId
),
assert clusterService.localNode().isMasterNode();
trainedModelAllocationClusterService.removeModelAllocation(
modelId,
ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
logger.error(
() -> new ParameterizedMessage(
"[{}] failed to delete model allocation after nodes unallocated the deployment",
modelId
),
deletionFailed
);
listener.onFailure(
ExceptionsHelper.serverError(
"failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
deletionFailed
);
listener.onFailure(
ExceptionsHelper.serverError(
"failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
deletionFailed
)
);
})
);
},
// TODO should we attempt to delete the deployment here?
listener::onFailure
));

)
);
})
);
}, e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) {
// A node has dropped out of the cluster since we started executing the requests.
@@ -235,24 +227,6 @@ private void normalUndeploy(
super.doExecute(task, request, finalListener);
}

void waitForTaskRemoved(
String modelId,
TrainedModelAllocation trainedModelAllocation,
StopTrainedModelDeploymentAction.Request request,
StopTrainedModelDeploymentAction.Response response,
ActionListener<StopTrainedModelDeploymentAction.Response> listener
) {
final Set<String> nodesOfConcern = trainedModelAllocation.getNodeRoutingTable().keySet();
client.admin()
.cluster()
.prepareListTasks(nodesOfConcern.toArray(String[]::new))
.setDetailed(true)
.setWaitForCompletion(true)
.setActions(modelId)

Member:
I think this bug was added in #81259. The actions used to contain the model id. Regardless, the new stopping path is much cleaner.

.setTimeout(request.getTimeout())
.execute(ActionListener.wrap(complete -> listener.onResponse(response), listener::onFailure));
}

@Override
protected StopTrainedModelDeploymentAction.Response newResponse(
StopTrainedModelDeploymentAction.Request request,
@@ -275,7 +249,9 @@ protected void taskOperation(
TrainedModelDeploymentTask task,
ActionListener<StopTrainedModelDeploymentAction.Response> listener
) {
task.stop("undeploy_trained_model (api)");
listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
task.stop(
"undeploy_trained_model (api)",
ActionListener.wrap(r -> listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)), listener::onFailure)

Member:
This is much cleaner and the execution path is more easily read.

);
}
}
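
The review comment above about #81259 concerns the deleted waitForTaskRemoved: it waited on the tasks API with a filter of setActions(modelId), and once deployment task actions no longer contained the model id that filter presumably matched nothing, so the wait completed immediately. A minimal, self-contained sketch of that failure mode, in plain Java with hypothetical action-name strings (these are assumptions for illustration, not the real Elasticsearch task actions):

import java.util.List;

// Hypothetical action names, only to illustrate the filter mismatch; not the
// real Elasticsearch task action strings.
public class ActionFilterSketch {

    private static long countMatches(List<String> taskActions, String filter) {
        return taskActions.stream().filter(action -> action.contains(filter)).count();
    }

    public static void main(String[] args) {
        String modelId = "my-model";

        // If task action names embed the model id, a model-id filter finds the task...
        List<String> actionsWithModelId = List.of("trained_model_deployment[my-model]");
        // ...but if they use a fixed name, the same filter matches nothing.
        List<String> fixedActionNames = List.of("trained_model_allocation[n]");

        System.out.println(countMatches(actionsWithModelId, modelId)); // 1
        System.out.println(countMatches(fixedActionNames, modelId));   // 0
        // A "wait for completion" over an empty filtered task list returns
        // immediately, so the stop API could answer before the process exited.
    }
}
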
@@ -135,7 +135,8 @@ void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionL
if (stopped) {
return;
}
task.stopWithoutNotification(reason);
task.markAsStopped(reason);

threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
try {
deploymentManager.stopDeployment(task);
@@ -204,20 +205,12 @@ void loadQueuedModels() {
loadingModels.addAll(loadingToRetry);
}

public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason) {
public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
ActionListener<Void> notifyDeploymentOfStopped = ActionListener.wrap(
_void -> updateStoredState(
task.getModelId(),
new RoutingStateAndReason(RoutingState.STOPPED, reason),
ActionListener.wrap(s -> {}, failure -> {})
),
_void -> updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener),
failed -> { // if we failed to stop the process, something strange is going on, but we should still notify of stop
logger.warn(() -> new ParameterizedMessage("[{}] failed to stop due to error", task.getModelId()), failed);
updateStoredState(
task.getModelId(),
new RoutingStateAndReason(RoutingState.STOPPED, reason),
ActionListener.wrap(s -> {}, failure -> {})
);
updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener);
}
);
updateStoredState(
@@ -309,7 +302,7 @@ public void clusterChanged(ClusterChangedEvent event) {
&& isResetMode == false) {
prepareModelToLoad(trainedModelAllocation.getTaskParams());
}
// This mode is not routed to the current node at all
// This model is not routed to the current node at all
if (routingStateAndReason == null) {
TrainedModelDeploymentTask task = modelIdToTask.remove(trainedModelAllocation.getTaskParams().getModelId());
if (task != null) {
@@ -9,9 +9,11 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
@@ -80,15 +82,11 @@ public TaskParams getParams() {
return params;
}

public void stop(String reason) {
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());

Member:
I think this still needs to be called. If you are concerned, maybe wrap the listener and call this on response/failure?

Member (Author):
Yeah, I'm not happy with this pattern. The task asks the node service to stop, then the node service calls back to TrainedModelDeploymentTask::markAsStopped from stopDeploymentAsync. This means markAsStopped is a public method, which doesn't make much sense as part of the public API.

The problem is that there are a few ways TrainedModelAllocationNodeService::stopDeploymentAsync can be called, such as the service noticing the deployment has been deleted or the task being cancelled.

I deleted these lines because markAsStopped (formerly stopWithoutNotification) was being called anyway and the same work occurs there.

Member:
> I deleted these lines because markAsStopped (formerly stopWithoutNotification) was being called anyway and the same work occurs there.

Gotcha, we just need to be careful and make sure that stopTracking is being called.

stopped = true;
stoppedReasonHolder.trySet(reason);
trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason);
public void stop(String reason, ActionListener<AcknowledgedResponse> listener) {
trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason, listener);
}

public void stopWithoutNotification(String reason) {
public void markAsStopped(String reason) {
licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
stoppedReasonHolder.trySet(reason);
@@ -106,7 +104,14 @@ public Optional<String> stoppedReason() {
@Override
protected void onCancelled() {
String reason = getReasonCancelled();
stop(reason);
logger.info("[{}] task cancelled due to reason [{}]", getModelId(), reason);
stop(
reason,
ActionListener.wrap(
acknowledgedResponse -> {},
e -> logger.error(new ParameterizedMessage("[{}] error stopping the model after task cancellation", getModelId()), e)
)
);
}

public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {
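
The thread above describes the reworked stop flow: TrainedModelDeploymentTask.stop(reason, listener) now delegates to TrainedModelAllocationNodeService.stopDeploymentAndNotify, which marks the task as stopped via markAsStopped (so stopTracking still runs there), stops the native process, and completes the listener only once the stop has finished. A minimal, self-contained sketch of that shape, in plain Java with simplified stand-in classes and a Consumer<Boolean> in place of ActionListener<AcknowledgedResponse> (an assumption for illustration, not the actual Elasticsearch implementation):

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

// Simplified stand-ins, not the real Elasticsearch classes.
class StopFlowSketch {

    interface LicenseTracker {
        void stopTracking(String featureId);
    }

    static class NodeService {
        void stopDeploymentAndNotify(DeploymentTask task, String reason, Consumer<Boolean> listener) {
            // The same work the deleted stop() lines used to do happens in markAsStopped.
            task.markAsStopped(reason);
            // ... tear down the native process and move the routing state to STOPPED ...
            listener.accept(true); // acknowledge only once the stop has completed
        }
    }

    static class DeploymentTask {
        private final String modelId;
        private final NodeService nodeService;
        private final LicenseTracker tracker;
        private final AtomicBoolean stopped = new AtomicBoolean(false);

        DeploymentTask(String modelId, NodeService nodeService, LicenseTracker tracker) {
            this.modelId = modelId;
            this.nodeService = nodeService;
            this.tracker = tracker;
        }

        // Public entry point: delegate to the node service and report back via the listener.
        void stop(String reason, Consumer<Boolean> acknowledgedListener) {
            nodeService.stopDeploymentAndNotify(this, reason, acknowledgedListener);
        }

        // Called back by the node service; license tracking still ends here.
        void markAsStopped(String reason) {
            tracker.stopTracking("model-" + modelId);
            stopped.set(true);
        }

        boolean isStopped() {
            return stopped.get();
        }
    }

    public static void main(String[] args) {
        NodeService nodeService = new NodeService();
        DeploymentTask task = new DeploymentTask("my-model", nodeService,
            featureId -> System.out.println("stopped tracking " + featureId));
        task.stop("undeploy_trained_model (api)", ack -> System.out.println("acknowledged=" + ack));
        System.out.println("stopped=" + task.isStopped());
    }
}
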
@@ -196,7 +196,7 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception {
// Only one model should be loaded, the other should be stopped
trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad));
trainedModelAllocationNodeService.prepareModelToLoad(newParams(stoppedModelToLoad));
trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing");
trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing", ActionListener.wrap(r -> {}, e -> {}));
trainedModelAllocationNodeService.loadQueuedModels();

assertBusy(() -> {
@@ -7,19 +7,23 @@

package org.elasticsearch.xpack.ml.inference.deployment;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;
import org.mockito.ArgumentCaptor;

import java.util.Map;
import java.util.function.Consumer;

import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_ACTION;
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -29,6 +33,15 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String modelId) {
XPackLicenseState licenseState = mock(XPackLicenseState.class);
LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class);
TrainedModelAllocationNodeService nodeService = mock(TrainedModelAllocationNodeService.class);

ArgumentCaptor<TrainedModelDeploymentTask> taskCaptor = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
ArgumentCaptor<String> reasonCaptur = ArgumentCaptor.forClass(String.class);
doAnswer(invocation -> {
taskCaptor.getValue().markAsStopped(reasonCaptur.getValue());
return null;
}).when(nodeService).stopDeploymentAndNotify(taskCaptor.capture(), reasonCaptur.capture(), any());

TrainedModelDeploymentTask task = new TrainedModelDeploymentTask(
0,
TRAINED_MODEL_ALLOCATION_TASK_TYPE,
@@ -42,7 +55,7 @@ void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String
randomInt(5),
randomInt(5)
),
mock(TrainedModelAllocationNodeService.class),
nodeService,
licenseState,
feature
);
@@ -53,12 +66,12 @@ void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String
verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId);
}

public void testOnStopWithoutNotification() {
assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10));
public void testMarkAsStopped() {
assertTrackingComplete(t -> t.markAsStopped("foo"), randomAlphaOfLength(10));
}

public void testOnStop() {
assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10));
assertTrackingComplete(t -> t.stop("foo", ActionListener.wrap(r -> {}, e -> {})), randomAlphaOfLength(10));
}

public void testCancelled() {