
[ML] Adding new trained model allocation service #75778


Merged

Conversation

benwtrent (Member)

Adds a new service for trained model allocation to nodes.

Initially, this only supports PyTorch models and simply allocates them to
nodes with ML roles.

The design is fairly simple:

  • A master node service allows new allocations to be created, updated, and deleted in cluster state.
  • A node service listens for cluster state updates that reference the local node and any models allocated to it, and updates its local deployments accordingly.

This type of service splits the difference between the logic of shard allocation and persistent tasks; neither fully addressed the need here.
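
For orientation, here is a minimal sketch of the node-side half (hypothetical class name, not the PR's actual code): a ClusterStateListener registered with the ClusterService that reconciles local deployments against the allocation metadata referencing the local node.

import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterStateListener;

// Illustration only: the real TrainedModelAllocationNodeService also manages tasks and
// reports routing state back to the master, but the registration/reaction shape is this.
class LocalNodeAllocationListener implements ClusterStateListener {

    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        String localNodeId = event.state().nodes().getLocalNodeId();
        // 1. Read the trained model allocation metadata from event.state().
        // 2. Find routes that reference localNodeId.
        // 3. Start models that should be allocated here but are not loaded yet,
        //    and stop models that are loaded but no longer routed to this node.
    }
}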

@elasticmachine added the Team:ML (Meta label for the ML team) label on Jul 28, 2021
@elasticmachine (Collaborator):

Pinging @elastic/ml-core (Team:ML)

Comment on lines +64 to +67
"\"transient\" : {\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.allocation\" : \"TRACE\",\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\"\n" +
" }" +
benwtrent (Member Author):

This is nice and I think it should be kept for now, as it has uncovered some weird conditions while running.

Comment on lines +57 to +66
this.trainedModelAllocationClusterService = trainedModelAllocationClusterService;
// Here we create our singleton for the node service
clusterService.addListener(
new TrainedModelAllocationNodeService(
trainedModelAllocationService,
clusterService,
deploymentManager,
transportService.getTaskManager(),
threadPool
)
benwtrent (Member Author):

This is pretty much what persistent tasks do. It's OK because transport classes are singletons and are created on every node where the plugin is loaded.

Comment on lines +67 to +70
// TODO Do better routing for inference calls
int nodeIndex = Randomness.get().nextInt(randomRunningNode.length);
request.setNodes(randomRunningNode[nodeIndex]);
super.doExecute(task, request, listener);
benwtrent (Member Author):

Because we still use tasks and the TrainedModelDeploymentTask object, all we need to do is set the node we care about and the internal routing does the rest (filtering to the correct node, finding the right task, allowing us to infer against it).

Contributor:

Nice!

Member:

++ I like how simple this has turned out

Comment on lines +264 to +265
if (RoutingState.FAILED.equals(nodeIdAndState.getValue().getState())) {
nodeFailuresAndReasons.put(nodeIdAndState.getKey(), nodeIdAndState.getValue().getReason());
benwtrent (Member Author):

If any fail, they all fail. This can easily be changed later; we may want to support "partial allocations".

Contributor:

It seems to me that it would be beneficial to our users to allow flexibility here and support partial allocations. But I'm happy to deal with this after this PR is merged. Let's raise an issue though so we don't forget.

Comment on lines +379 to +380
// TODO: Do we want to remove from the modelIdToTask map? This would cause it to be reloaded by state updates on INITIALIZING
modelIdToTask.remove(task.getModelId());
benwtrent (Member Author):

If the node state in the cluster state somehow got set back to INITIALIZING (which currently can only happen by recreating the allocation or deleting the route), we would start the model again, which may be exactly what we want.

Comment on lines +71 to +82
client.execute(UpdateTrainedModelAllocationStateAction.INSTANCE, request, ActionListener.wrap(listener::onResponse, failure -> {
if (isMasterChannelException(failure)) {
logger.info(
"[{}] master channel exception will retry on new master node for allocation state update [{}]",
request.getModelId(),
request.getRoutingState().getState()
);
waitForNewMasterAndRetry(observer, UpdateTrainedModelAllocationStateAction.INSTANCE, request, listener, changePredicate);
return;
}
listener.onFailure(failure);
}));
benwtrent (Member Author):

For routing table updates, this is important. We don't want a spurious failure to leave the routing table stale and cause inference issues. So, if the failure is due to an intermittent master node communication issue, we should retry.
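
For reference, the waitForNewMasterAndRetry pattern referenced above generally takes the shape sketched below. This is modelled on the persistent tasks service rather than copied from this PR; the client and clusterService fields and the usual imports are assumed.

private <Req extends ActionRequest, Resp extends ActionResponse> void waitForNewMasterAndRetry(
    ClusterStateObserver observer,
    ActionType<Resp> action,
    Req request,
    ActionListener<Resp> listener,
    Predicate<ClusterState> changePredicate
) {
    observer.waitForNextChange(new ClusterStateObserver.Listener() {
        @Override
        public void onNewClusterState(ClusterState state) {
            // a new master has been elected; re-send the request
            client.execute(action, request, listener);
        }

        @Override
        public void onClusterServiceClose() {
            listener.onFailure(new NodeClosedException(clusterService.localNode()));
        }

        @Override
        public void onTimeout(TimeValue timeout) {
            // should not happen when no timeout is set on the observer
            listener.onFailure(new IllegalStateException("timed out waiting for a new master node"));
        }
    }, changePredicate);
}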

@@ -176,7 +168,7 @@ public void stopDeployment(TrainedModelDeploymentTask task) {
public void infer(TrainedModelDeploymentTask task,
String input, TimeValue timeout,
ActionListener<InferenceResults> listener) {
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
ProcessContext processContext = processContextByAllocation.get(task.getId());
benwtrent (Member Author):

Task IDs are monotonically increasing for the lifetime of the TaskManager, so this is safe.

@dimitris-athanasiou (Contributor) left a comment:

Posting these comments so they don't get lost. I'm halfway through, so I'll come back with more. Also, some of the comments I've made might make no sense once I understand this fully, so bear with me.

);
}

private final StartTrainedModelDeploymentAction.TaskParams taskParams;
Contributor:

I wonder if it makes sense to keep these as TaskParams. I would consider getting rid of that object and having the model ID, index, and model bytes as flat fields in this class.

Contributor:

Alternatively, those could be renamed to ModelParams or something like that, if we intend to capture information about the model into a single object.

benwtrent (Member Author):

FWIW, these params are still used to create a Task object. Though I do think bringing them out of StartTrainedModelDeploymentAction makes sense.
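
As a purely illustrative sketch of the alternative being discussed (the ModelParams name comes from the suggestion above; everything else is assumed):

// Hypothetical flat params holder instead of reusing the action's TaskParams.
public class ModelParams {
    private final String modelId;
    private final String index;
    private final long modelBytes;

    public ModelParams(String modelId, String index, long modelBytes) {
        this.modelId = modelId;
        this.index = index;
        this.modelBytes = modelBytes;
    }
}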




@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
logger.trace("updated model allocations based on node changes in the cluster");
Contributor:

Shall we also print the new routing table here? I think it would be useful in debug.

Comment on lines +86 to +90
// TODO this has a weird side-effect for allocating to nodes
// If the event indicates there were nodes added/removed, this method only looks at the current state and has
// no previous knowledge of existing nodes. Consequently, if a model was manually removed (task-kill) from a node
// it may get re-allocated to that node when another node is added/removed...
return addRemoveAllocationNodes(currentState);
Contributor:

I don't think we should worry about this for now. If a user manually cancels a task on a node, we can only expect they've done this during troubleshooting with our support and guidance. The idea of not then allowing the model to be allocated back to that node is a new feature that would cover the requirement to "manually designate suitable nodes for a model to be allocated on". We do not have that requirement currently.

@davidkyle (Member) left a comment:

The design is good and I can see how advanced features would be built on this.
In many ways it is a simplification, as a lot of code is removed.

this.taskParams = taskParams;
}

public Builder addNewRoutingEntry(String nodeId) {
Member:

I don't understand how you can get to FAILED without going through INITIALIZING first.


PersistentTasksCustomMetadata.Assignment assignment = persistentTask.getAssignment();

String reason = "__unknown__";
final Set<Map.Entry<String, RoutingStateAndReason>> nodesAndState = trainedModelAllocation
Member:

This and the logic below to get failed nodes could be a member method on TrainedModelAllocation. Easier to test and reuse.

benwtrent (Member Author):

I couldn't figure out a nice way to get them both in one loop (with routing reasons intact).
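
For concreteness, the kind of member method being suggested could look roughly like the sketch below. It is hypothetical, and assumes TrainedModelAllocation exposes the node-id to RoutingStateAndReason map via a getNodeRoutingTable() accessor.

// Hypothetical helper on TrainedModelAllocation, not code from this PR.
public Map<String, String> failedRoutesByNodeId() {
    Map<String, String> failures = new HashMap<>();
    for (Map.Entry<String, RoutingStateAndReason> routeEntry : getNodeRoutingTable().entrySet()) {
        if (RoutingState.FAILED.equals(routeEntry.getValue().getState())) {
            failures.put(routeEntry.getKey(), routeEntry.getValue().getReason());
        }
    }
    return failures;
}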

Comment on lines +86 to +90
// TODO this has a weird side-effect for allocating to nodes
// If the event indicates there were nodes added/removed, this method only looks at the current state and has
// no previous knowledge of existing nodes. Consequently, if a model was manually removed (task-kill) from a node
// it may get re-allocated to that node when another node is added/removed...
return addRemoveAllocationNodes(currentState);
Member:

I would argue this 'works as intended', at least while allocations are all nodes or none.

@benwtrent (Member Author):

@elasticmachine update branch

@@ -101,7 +102,17 @@ protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request req
listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
return;
}
normalUndeploy(task, models.get(0).getModelId(), maybeAllocation.get(), request, listener);
final String modelId = models.get(0).getModelId();
trainedModelAllocationService.stopModelAllocation(modelId, ActionListener.wrap(
Contributor:

Instead of adding a whole new action for this, would it be an option to call the update allocation action and update its state to stopping through that?

benwtrent (Member Author):

> would it be an option to call the update allocation action and update its state to stopping through that?

Maybe, but there is no "update allocation" action. The only updates that occur throughout the lifetime of the allocation are route updates.

Contributor:

This is the one I meant: UpdateTrainedModelAllocationStateAction.

@benwtrent (Member Author), Aug 2, 2021:

> UpdateTrainedModelAllocationStateAction

The focus of that request is only routes. Possibly I should rename it, but right now that action is only requested by the node service to update a route in the cluster service.

benwtrent (Member Author):

I updated the deployment stop action to update the allocation state to stopping

@@ -249,6 +253,10 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
if (existingAllocation == null) {
throw new ResourceNotFoundException("allocation for model with id [" + modelId + "] not found");
}
// If we are stopping, don't update anything
if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
return currentState;
Contributor:

A debug log statement might be helpful here.

benwtrent (Member Author):

👍
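
A rough sketch of what the agreed debug statement could look like (wording hypothetical; logger is assumed to be the class's logger):

// If we are stopping, don't update anything
if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
    logger.debug(() -> new ParameterizedMessage(
        "[{}] requested routing update [{}] while allocation is stopping; ignoring update",
        modelId,
        request.getRoutingState()));
    return currentState;
}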

…com:benwtrent/elasticsearch into feature/ml-trained-model-allocation-service
@@ -152,7 +152,7 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS
});
}

public void stopModelAllocation(String modelId, ActionListener<AcknowledgedResponse> listener) {
public void setModelAllocationToStopping(String modelId, ActionListener<AcknowledgedResponse> listener) {
clusterService.submitStateUpdateTask("stop model allocation", new ClusterStateUpdateTask() {
Contributor:

Also update the name of the source

@dimitris-athanasiou (Contributor) left a comment:

LGTM 🚀

*/
static boolean isNodeShuttingDown(final ClusterState state, final String nodeId) {
// Right now we make no distinction between the type of shutdown, but maybe in the future we might?
return NodesShutdownMetadata.getShutdowns(state)
Member:

nit: sometimes this method is called in a loop; it would be worth locally caching the result of getAllNodeMetadataMap().
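
A sketch of the caching being suggested, from a hypothetical caller; the accessor names come from the code and comment above, and the value type is assumed.

// Fetch the shutdown map once per cluster state instead of once per node.
Map<String, SingleNodeShutdownMetadata> shutdowns = Optional.ofNullable(NodesShutdownMetadata.getShutdowns(state))
    .map(NodesShutdownMetadata::getAllNodeMetadataMap)
    .orElse(Collections.emptyMap());
for (DiscoveryNode node : state.nodes()) {
    boolean shuttingDown = shutdowns.containsKey(node.getId());
    // ... use shuttingDown when deciding whether the node can receive new model routes ...
}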

TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
if (allocation == null) {
throw new ResourceNotFoundException(
"unable to add node [{}] to model [{}] routing table as allocation does not exist",
Member:

Suggested change:
- "unable to add node [{}] to model [{}] routing table as allocation does not exist",
+ "unable to add failed node [{}] to model [{}] routing table as allocation does not exist",

private TaskAwareRequest taskAwareRequest(StartTrainedModelDeploymentAction.TaskParams params) {
final TrainedModelAllocationNodeService trainedModelAllocationNodeService = this;
return new TaskAwareRequest() {
final TaskId parentTaskId = new TaskId(nodeId, taskIdGenerator.incrementAndGet());
Member:

What is the parent task? Why not use TaskId#EMPTY_TASK_ID, as suggested in the comment in TaskAwareRequest.java?

observer.waitForNextChange(new ClusterStateObserver.Listener() {
@Override
public void onNewClusterState(ClusterState state) {
listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null));
Member:

Suggested change:
- listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null));
+ listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(state, modelId).orElse(null));

Shouldn't the cluster state from the method parameter be used here?

@benwtrent requested a review from davidkyle on August 3, 2021 13:57
@davidkyle (Member) left a comment:

LGTM2 :shipit:

@benwtrent merged commit b11c15b into elastic:master on Aug 3, 2021
@benwtrent deleted the feature/ml-trained-model-allocation-service branch on August 3, 2021 17:06
benwtrent added a commit that referenced this pull request Aug 4, 2021
Trained model deployment memory usage is no longer determinable via persistent tasks.

The new way is to look into the trained model allocation metadata.

This PR updates this and removes some unused code.

relates: #75778
Labels: :ml (Machine learning), >non-issue, Team:ML (Meta label for the ML team), v8.0.0-alpha1