[ML] Adding new trained model allocation service #75778
Pinging @elastic/ml-core (Team:ML)
"\"transient\" : {\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.allocation\" : \"TRACE\",\n" +
" \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\"\n" +
" }" +
This is nice, and I think it should be kept for now, as it has uncovered some weird conditions while running.
this.trainedModelAllocationClusterService = trainedModelAllocationClusterService;
// Here we create our singleton for the node service
clusterService.addListener(
    new TrainedModelAllocationNodeService(
        trainedModelAllocationService,
        clusterService,
        deploymentManager,
        transportService.getTaskManager(),
        threadPool
    )
This is pretty much what persistent tasks do. It's OK, as transport classes are singletons and are created on every node where the plugin is loaded.
// TODO Do better routing for inference calls
int nodeIndex = Randomness.get().nextInt(randomRunningNode.length);
request.setNodes(randomRunningNode[nodeIndex]);
super.doExecute(task, request, listener);
Because we still use tasks, and still use the TrainedModelDeploymentTask object, all we need to do is set the node we care about, and the internal routing does the rest (filtering to the correct node, finding the right task, and allowing us to infer against it).
Nice!
++ I like how simple this has turned out
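The random selection discussed in this thread can be illustrated outside Elasticsearch with a minimal sketch. Everything here is an assumption made for illustration: the `State` enum (including its `STARTED` value), the `pickRunningNode` helper, and the plain `Map` routing table are hypothetical stand-ins, not the PR's classes.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;

final class RandomRunningNodeSketch {

    // Hypothetical stand-in for the PR's RoutingState; the STARTED value is an assumption.
    enum State { INITIALIZING, STARTED, FAILED }

    /** Returns a randomly chosen node whose route is STARTED, or empty if there is none. */
    static Optional<String> pickRunningNode(Map<String, State> routingTable, Random random) {
        List<String> running = new ArrayList<>();
        for (Map.Entry<String, State> route : routingTable.entrySet()) {
            if (route.getValue() == State.STARTED) {
                running.add(route.getKey());
            }
        }
        if (running.isEmpty()) {
            return Optional.empty();
        }
        // Uniform random choice; no load awareness, matching the "do better routing" TODO.
        return Optional.of(running.get(random.nextInt(running.size())));
    }

    public static void main(String[] args) {
        Map<String, State> routingTable = new LinkedHashMap<>();
        routingTable.put("node-a", State.STARTED);
        routingTable.put("node-b", State.FAILED);
        routingTable.put("node-c", State.STARTED);
        // Prints node-a or node-c, never node-b.
        System.out.println(pickRunningNode(routingTable, new Random()).orElse("<none>"));
    }
}
```

Returning an empty `Optional` when no node is running keeps the "no running node" case explicit instead of failing on `nextInt(0)`.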
if (RoutingState.FAILED.equals(nodeIdAndState.getValue().getState())) {
    nodeFailuresAndReasons.put(nodeIdAndState.getKey(), nodeIdAndState.getValue().getReason());
If any fail, they all fail. This can be easily changed later. We may want to support "partial allocations".
It seems to me that it would be beneficial to our users to allow flexibility here and support partial allocations. But I'm happy to deal with this after this PR is merged. Let's raise an issue though so we don't forget.
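To make the two policies in this thread concrete, here is a rough sketch contrasting the current all-or-nothing check with one possible "partial allocation" rule. The `State` and `StateAndReason` types are hypothetical stand-ins for `RoutingState` and `RoutingStateAndReason`, and the partial rule (succeed when at least one route is `STARTED`) is only an assumption about what such a feature might mean.

```java
import java.util.Map;
import java.util.TreeMap;

final class AllocationFailureSketch {

    // Hypothetical stand-ins for RoutingState and RoutingStateAndReason.
    enum State { INITIALIZING, STARTED, FAILED }

    static final class StateAndReason {
        final State state;
        final String reason;

        StateAndReason(State state, String reason) {
            this.state = state;
            this.reason = reason;
        }
    }

    /** Collects node ID -> failure reason for every route in the FAILED state. */
    static Map<String, String> failedNodes(Map<String, StateAndReason> routes) {
        Map<String, String> failures = new TreeMap<>();
        for (Map.Entry<String, StateAndReason> route : routes.entrySet()) {
            if (route.getValue().state == State.FAILED) {
                failures.put(route.getKey(), route.getValue().reason);
            }
        }
        return failures;
    }

    /** All-or-nothing: a single failed route fails the whole deployment. */
    static boolean startSucceeded(Map<String, StateAndReason> routes) {
        return failedNodes(routes).isEmpty();
    }

    /** Assumed "partial allocation" rule: succeed if at least one route is STARTED. */
    static boolean startSucceededPartially(Map<String, StateAndReason> routes) {
        for (StateAndReason route : routes.values()) {
            if (route.state == State.STARTED) {
                return true;
            }
        }
        return false;
    }
}
```

Keeping `failedNodes` as its own method also means the failure reasons stay available to whichever policy is applied.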
// TODO: Do we want to remove from the modelIdToTask map? This would cause it to be reloaded by state updates on INITIALIZING
modelIdToTask.remove(task.getModelId());
If the node's routing state in the cluster state somehow got set back to initializing (which currently can only happen by recreating the allocation or deleting the route), we would start the model again, which may be exactly what we want.
client.execute(UpdateTrainedModelAllocationStateAction.INSTANCE, request, ActionListener.wrap(listener::onResponse, failure -> {
    if (isMasterChannelException(failure)) {
        logger.info(
            "[{}] master channel exception will retry on new master node for allocation state update [{}]",
            request.getModelId(),
            request.getRoutingState().getState()
        );
        waitForNewMasterAndRetry(observer, UpdateTrainedModelAllocationStateAction.INSTANCE, request, listener, changePredicate);
        return;
    }
    listener.onFailure(failure);
}));
For routing table updates, this is important. We don't want a spurious failure to leave the routing table stale and cause inference issues. So, if the failure stems from an intermittent master node communication issue, we should retry.
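The retry rule described here (retry only when the failure is a master communication problem, surface everything else) can be sketched in a simplified, synchronous form. This is not how the PR does it: the real code waits on a cluster state observer for a new master, whereas this sketch uses a bounded loop. `MasterChannelException` and `updateWithRetry` are hypothetical names.

```java
import java.util.function.Supplier;

final class MasterRetrySketch {

    // Hypothetical marker for "the master channel went away"; not an Elasticsearch class.
    static final class MasterChannelException extends RuntimeException {
        MasterChannelException(String message) {
            super(message);
        }
    }

    /**
     * Runs the update, retrying only on MasterChannelException.
     * Any other exception propagates immediately to the caller.
     */
    static <T> T updateWithRetry(Supplier<T> update, int maxAttempts) {
        if (maxAttempts < 1) {
            throw new IllegalArgumentException("maxAttempts must be >= 1");
        }
        MasterChannelException last = null;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            try {
                return update.get();
            } catch (MasterChannelException e) {
                last = e; // transient master problem: try again
            }
        }
        throw last; // every attempt hit a master channel failure
    }
}
```

The point carried over from the comment is the split between the two failure classes: transient master failures must not leave the routing table stale, while genuine errors still reach the listener.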
@@ -176,7 +168,7 @@ public void stopDeployment(TrainedModelDeploymentTask task) {
     public void infer(TrainedModelDeploymentTask task,
                       String input, TimeValue timeout,
                       ActionListener<InferenceResults> listener) {
-        ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
+        ProcessContext processContext = processContextByAllocation.get(task.getId());
Task IDs are monotonically increasing for the lifetime of the TaskManager, so this is safe.
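As a small illustration of why a monotonically increasing ID is a safe map key, consider the sketch below. The `AtomicLong` counter is only a stand-in for the TaskManager's ID generation, and `TaskIdKeySketch` with its methods is hypothetical.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

final class TaskIdKeySketch {

    // Stand-in for the TaskManager's ID counter: it only ever increases.
    private final AtomicLong idGenerator = new AtomicLong();
    private final Map<Long, String> contextByTaskId = new ConcurrentHashMap<>();

    /** Registers a process context under a fresh, never-reused ID. */
    long register(String processContext) {
        long id = idGenerator.incrementAndGet();
        contextByTaskId.put(id, processContext);
        return id;
    }

    String get(long id) {
        return contextByTaskId.get(id);
    }

    String remove(long id) {
        return contextByTaskId.remove(id);
    }
}
```

Because an ID is never handed out twice within the generator's lifetime, a lookup with an old ID after a stop and restart cannot resolve to the new deployment's context.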
Posting these comments so they don't get lost. I'm halfway through, so I'll come back with more. Also, some of the comments I've made might make no sense once I understand this fully, so bear with me.
    );
}

private final StartTrainedModelDeploymentAction.TaskParams taskParams;
I wonder if it makes sense to keep these as TaskParams. I would consider getting rid of that object and having the model ID, index, and model bytes as flat variables in this class.
Alternatively, those could be renamed to ModelParams or something like that, if we intend to capture information about the model in a single object.
FWIW, these params are still used to create a Task object. Though I do think bringing them out of StartTrainedModelDeploymentAction might make sense.
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
    logger.trace("updated model allocations based on node changes in the cluster");
Shall we also print the new routing table here? I think it would be useful when debugging.
// TODO this has a weird side-effect for allocating to nodes
// If the event indicates there were nodes added/removed, this method only looks at the current state and has
// no previous knowledge of existing nodes. Consequently, if a model was manually removed (task-kill) from a node
// it may get re-allocated to that node when another node is added/removed...
return addRemoveAllocationNodes(currentState);
I don't think we should worry about this for now. If a user manually cancels a task on a node, we can only expect they've done this during troubleshooting with our support and guidance. The idea of not then allowing the model to be allocated back to that node is a new feature that would cover the requirement to "manually designate suitable nodes for a model to be allocated on". We do not have that requirement currently.
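The side effect described in the TODO is easy to reproduce in a toy model. The sketch below is not the PR's `addRemoveAllocationNodes`; `reconcileRoutes` is a hypothetical helper that, like the TODO describes, looks only at the current node set and the current routes, with no memory of why a route is missing.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

final class ReallocationSketch {

    /**
     * Rebuilds routes from the current state only: drops routes for nodes that left,
     * and adds an INITIALIZING route for every current node without one.
     */
    static Map<String, String> reconcileRoutes(Set<String> currentMlNodes, Map<String, String> routes) {
        Map<String, String> updated = new TreeMap<>(routes);
        updated.keySet().retainAll(currentMlNodes);
        for (String nodeId : currentMlNodes) {
            // No history is consulted, so a manually removed route is indistinguishable
            // from a route that never existed.
            updated.putIfAbsent(nodeId, "INITIALIZING");
        }
        return updated;
    }

    public static void main(String[] args) {
        Map<String, String> routes = new TreeMap<>();
        routes.put("node-a", "STARTED"); // node-b's route was removed by hand
        Set<String> nodes = new HashSet<>(Arrays.asList("node-a", "node-b", "node-c"));
        // node-c joining the cluster also brings node-b back into the routing table.
        System.out.println(reconcileRoutes(nodes, routes));
    }
}
```

Avoiding this would require tracking which nodes were deliberately excluded, which is the "manually designate suitable nodes" feature the comment above says is not currently a requirement.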
The design is good, and I can see how advanced features would be built on this.
In many ways it is a simplification, as a lot of code is removed.
    this.taskParams = taskParams;
}

public Builder addNewRoutingEntry(String nodeId) {
I don't understand how you can get to FAILED without going through INITIALIZING first.
PersistentTasksCustomMetadata.Assignment assignment = persistentTask.getAssignment();

String reason = "__unknown__";
final Set<Map.Entry<String, RoutingStateAndReason>> nodesAndState = trainedModelAllocation
This and the logic below to get failed nodes could be a member method on TrainedModelAllocation. Easier to test and reuse.
I couldn't figure out a clean way to get them both in a single loop (with routing reasons intact).
// TODO this has a weird side-effect for allocating to nodes
// If the event indicates there were nodes added/removed, this method only looks at the current state and has
// no previous knowledge of existing nodes. Consequently, if a model was manually removed (task-kill) from a node
// it may get re-allocated to that node when another node is added/removed...
return addRemoveAllocationNodes(currentState);
I would argue this 'works as intended', at least while allocations are all nodes or none.
@elasticmachine update branch
@@ -101,7 +102,17 @@ protected void doExecute(Task task, StopTrainedModelDeploymentAction.Request req
             listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
             return;
         }
-        normalUndeploy(task, models.get(0).getModelId(), maybeAllocation.get(), request, listener);
+        final String modelId = models.get(0).getModelId();
+        trainedModelAllocationService.stopModelAllocation(modelId, ActionListener.wrap(
Instead of adding a whole new action for this, would it be an option to call the update allocation action and update its state to stopping through that?
> would it be an option to call the update allocation action and update its state to stopping through that?
Maybe, but there is no "update allocation" action. The only updates that occur throughout the lifetime of the allocation are route updates.
This is the one I meant: UpdateTrainedModelAllocationStateAction
> UpdateTrainedModelAllocationStateAction
The focus of that request is only routes. Possibly I should rename it, but right now that action is only requested by the node service to update a route in the cluster service.
I updated the deployment stop action to update the allocation state to stopping.
@@ -249,6 +253,10 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
         if (existingAllocation == null) {
             throw new ResourceNotFoundException("allocation for model with id [" + modelId + "] not found");
         }
+        // If we are stopping, don't update anything
+        if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
+            return currentState;
A debug log statement might be helpful here.
👍
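The guard in this hunk follows a common pattern for state updates: return the same instance to signal "nothing changed". Here is a minimal sketch of that pattern under stated assumptions; `Allocation`, `updateRoute`, and the `STARTED` enum value are hypothetical, and only `STOPPING` comes from the diff.

```java
import java.util.Collections;
import java.util.Map;
import java.util.TreeMap;

final class StoppingGuardSketch {

    // STOPPING appears in the diff; STARTED is an assumed counterpart.
    enum AllocationState { STARTED, STOPPING }

    /** Hypothetical immutable allocation: an overall state plus per-node route states. */
    static final class Allocation {
        final AllocationState state;
        final Map<String, String> routes;

        Allocation(AllocationState state, Map<String, String> routes) {
            this.state = state;
            this.routes = Collections.unmodifiableMap(new TreeMap<>(routes));
        }
    }

    /** Applies a route update unless the allocation is stopping, in which case it is a no-op. */
    static Allocation updateRoute(Allocation existing, String nodeId, String routeState) {
        if (existing.state == AllocationState.STOPPING) {
            // Same instance back: callers can detect "no change" with a reference check.
            return existing;
        }
        Map<String, String> routes = new TreeMap<>(existing.routes);
        routes.put(nodeId, routeState);
        return new Allocation(existing.state, routes);
    }
}
```

A debug log inside the `STOPPING` branch, as suggested above, would make ignored late route updates visible without changing this behavior.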
@@ -152,7 +152,7 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS
         });
     }

-    public void stopModelAllocation(String modelId, ActionListener<AcknowledgedResponse> listener) {
+    public void setModelAllocationToStopping(String modelId, ActionListener<AcknowledgedResponse> listener) {
         clusterService.submitStateUpdateTask("stop model allocation", new ClusterStateUpdateTask() {
Also update the name of the source.
LGTM 🚀
 */
static boolean isNodeShuttingDown(final ClusterState state, final String nodeId) {
    // Right now we make no distinction between the type of shutdown, but maybe in the future we might?
    return NodesShutdownMetadata.getShutdowns(state)
nit: this method is sometimes called in a loop, so it would be worth locally caching the result of getAllNodeMetadataMap().
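The caching suggested in this nit amounts to hoisting the lookup out of the loop. A rough sketch follows; `ShutdownMetadata` is a hypothetical stand-in that merely counts calls to a method named after `getAllNodeMetadataMap()`, and `nodesNotShuttingDown` is an invented helper.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

final class ShutdownLookupSketch {

    /** Hypothetical stand-in for the shutdown metadata; counts how often the map is fetched. */
    static final class ShutdownMetadata {
        private final Map<String, String> shutdownsByNodeId;
        int lookups = 0;

        ShutdownMetadata(Map<String, String> shutdownsByNodeId) {
            this.shutdownsByNodeId = shutdownsByNodeId;
        }

        Map<String, String> getAllNodeMetadataMap() {
            lookups++;
            return shutdownsByNodeId;
        }
    }

    /** Filters out shutting-down nodes, fetching the shutdown map once for the whole loop. */
    static List<String> nodesNotShuttingDown(ShutdownMetadata metadata, List<String> nodeIds) {
        Set<String> shuttingDown = metadata.getAllNodeMetadataMap().keySet(); // cached locally
        List<String> result = new ArrayList<>();
        for (String nodeId : nodeIds) {
            if (!shuttingDown.contains(nodeId)) {
                result.add(nodeId);
            }
        }
        return result;
    }
}
```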
TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
if (allocation == null) {
    throw new ResourceNotFoundException(
        "unable to add node [{}] to model [{}] routing table as allocation does not exist",
-        "unable to add node [{}] to model [{}] routing table as allocation does not exist",
+        "unable to add failed node [{}] to model [{}] routing table as allocation does not exist",
private TaskAwareRequest taskAwareRequest(StartTrainedModelDeploymentAction.TaskParams params) {
    final TrainedModelAllocationNodeService trainedModelAllocationNodeService = this;
    return new TaskAwareRequest() {
        final TaskId parentTaskId = new TaskId(nodeId, taskIdGenerator.incrementAndGet());
What is the parent task? Why not use TaskId#EMPTY_TASK_ID, as suggested in the comment in TaskAwareRequest.java?
observer.waitForNextChange(new ClusterStateObserver.Listener() {
    @Override
    public void onNewClusterState(ClusterState state) {
        listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null));
-        listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null));
+        listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(state, modelId).orElse(null));
Shouldn't the cluster state passed as the method parameter be used here?
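The bug flagged here is a stale capture: the callback reads a variable from the enclosing scope instead of the fresh value it was handed. A self-contained sketch of the difference is below; `Observer`, `waitBuggy`, and `waitFixed` are hypothetical, and a plain `Map` stands in for the cluster state.

```java
import java.util.Map;
import java.util.function.Consumer;

final class StaleClusterStateSketch {

    /** Hypothetical one-shot observer: holds a listener and hands it the next published state. */
    static final class Observer {
        private Consumer<Map<String, String>> listener;

        void waitForNextChange(Consumer<Map<String, String>> listener) {
            this.listener = listener;
        }

        void publish(Map<String, String> newState) {
            if (listener != null) {
                listener.accept(newState);
            }
        }
    }

    /** Buggy: ignores the new state and reads the captured, pre-change clusterState. */
    static void waitBuggy(Observer observer, Map<String, String> clusterState, String modelId,
                          Consumer<String> listener) {
        observer.waitForNextChange(state -> listener.accept(clusterState.get(modelId)));
    }

    /** Fixed: reads the allocation from the state passed into the callback. */
    static void waitFixed(Observer observer, String modelId, Consumer<String> listener) {
        observer.waitForNextChange(state -> listener.accept(state.get(modelId)));
    }
}
```

With an old state that lacks the model and a new state that contains it, `waitBuggy` reports `null` while `waitFixed` reports the new allocation, which is the behavior the suggested change restores.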
LGTM2
Trained model deployment memory usage can no longer be determined via persistent tasks. The new way is to look at the trained model allocation metadata. This PR makes that change and removes some unused code. Relates: #75778
Adds a new service for trained model allocation to nodes.
Initially, this only supports PyTorch models and simply allocates
to nodes with the ML roles.
Design is fairly simple:
This type of service splits the difference between the logic of shard allocation and that of persistent tasks. Neither fully addressed the need here.