Skip to content

Commit 117e74b

Browse files
authored
[ML] updating node memory load for new allocation service (#76046)
Trained model deployment memory usage can no longer be determined from persistent tasks; it must now be read from the trained model allocation metadata. This PR updates the node memory load calculation accordingly and removes some unused code. Relates: #75778
1 parent 18a39be commit 117e74b

File tree

14 files changed

+110
-312
lines changed

14 files changed

+110
-312
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
1616
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
1717
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
18-
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
19-
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
2018
import org.elasticsearch.xpack.core.ml.job.config.JobState;
2119
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
2220
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
@@ -30,11 +28,13 @@
3028

3129
public final class MlTasks {
3230

31+
public static final String TRAINED_MODEL_ALLOCATION_TASK_TYPE = "trained_model_allocation";
32+
public static final String TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX = "xpack/ml/allocation-";
33+
3334
public static final String JOB_TASK_NAME = "xpack/ml/job";
3435
public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed";
3536
public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics";
3637
public static final String JOB_SNAPSHOT_UPGRADE_TASK_NAME = "xpack/ml/job/snapshot/upgrade";
37-
public static final String TRAINED_MODEL_DEPLOYMENT_TASK_NAME = "xpack/ml/trained_model/deployment";
3838

3939
public static final String JOB_TASK_ID_PREFIX = "job-";
4040
public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-";
@@ -225,31 +225,6 @@ public static DataFrameAnalyticsState getDataFrameAnalyticsState(@Nullable Persi
225225
return state;
226226
}
227227

228-
public static TrainedModelDeploymentState getTrainedModelDeploymentState(PersistentTasksCustomMetadata.PersistentTask<?> task) {
229-
if (task == null) {
230-
return TrainedModelDeploymentState.STOPPED;
231-
}
232-
TrainedModelDeploymentTaskState taskState = (TrainedModelDeploymentTaskState) task.getState();
233-
if (taskState == null) {
234-
return TrainedModelDeploymentState.STARTING;
235-
}
236-
237-
TrainedModelDeploymentState state = taskState.getState();
238-
if (taskState.isStatusStale(task)) {
239-
if (state == TrainedModelDeploymentState.STOPPING) {
240-
// previous executor node failed while the job was stopping - it won't
241-
// be restarted on another node, so consider it STOPPED for reassignment purposes
242-
return TrainedModelDeploymentState.STOPPED;
243-
}
244-
if (state != TrainedModelDeploymentState.FAILED) {
245-
// we are relocating at the moment
246-
// TODO Revisit this in the new allocation framework as there won't necessarily be a concept of relocation.
247-
return TrainedModelDeploymentState.STARTING;
248-
}
249-
}
250-
return state;
251-
}
252-
253228
/**
254229
* The job Ids of anomaly detector job tasks.
255230
* All anomaly detector jobs are returned regardless of the status of the
@@ -435,8 +410,6 @@ public static MemoryTrackedTaskState getMemoryTrackedTaskState(PersistentTasksCu
435410
return taskState == null ? SnapshotUpgradeState.LOADING_OLD_STATE : taskState.getState();
436411
case DATA_FRAME_ANALYTICS_TASK_NAME:
437412
return getDataFrameAnalyticsState(task);
438-
case TRAINED_MODEL_DEPLOYMENT_TASK_NAME:
439-
return getTrainedModelDeploymentState(task);
440413
default:
441414
throw new IllegalStateException("unexpected task type [" + task.getTaskName() + "]");
442415
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.action.support.master.MasterNodeRequest;
1414
import org.elasticsearch.cluster.node.DiscoveryNode;
1515
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
16+
import org.elasticsearch.common.io.stream.Writeable;
1617
import org.elasticsearch.common.unit.ByteSizeValue;
1718
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1819
import org.elasticsearch.common.xcontent.ParseField;
@@ -23,7 +24,6 @@
2324
import org.elasticsearch.core.TimeValue;
2425
import org.elasticsearch.common.xcontent.ToXContentObject;
2526
import org.elasticsearch.common.xcontent.XContentBuilder;
26-
import org.elasticsearch.persistent.PersistentTaskParams;
2727
import org.elasticsearch.tasks.Task;
2828
import org.elasticsearch.xpack.core.ml.MlTasks;
2929
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -122,13 +122,13 @@ public String toString() {
122122
}
123123
}
124124

125-
public static class TaskParams implements PersistentTaskParams, MlTaskParams {
125+
public static class TaskParams implements MlTaskParams, Writeable, ToXContentObject {
126126

127127
// TODO add support for other roles? If so, it may have to be an instance method...
128128
// NOTE, whatever determines allocation should not be dynamically set on the node
129129
// Otherwise allocation logic might fail
130130
public static boolean mayAllocateToNode(DiscoveryNode node) {
131-
return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE);
131+
return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE) && node.getVersion().onOrAfter(VERSION_INTRODUCED);
132132
}
133133

134134
public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
@@ -187,12 +187,6 @@ public long estimateMemoryUsageBytes() {
187187
return MEMORY_OVERHEAD.getBytes() + 2 * modelBytes;
188188
}
189189

190-
@Override
191-
public String getWriteableName() {
192-
return MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME;
193-
}
194-
195-
@Override
196190
public Version getMinimalSupportedVersion() {
197191
return VERSION_INTRODUCED;
198192
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77

88
package org.elasticsearch.xpack.core.ml.inference.allocation;
99

10+
import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
11+
12+
import java.util.Arrays;
1013
import java.util.Locale;
1114

12-
public enum RoutingState {
15+
public enum RoutingState implements MemoryTrackedTaskState {
1316
STARTING,
1417
STARTED,
1518
STOPPING,
@@ -20,8 +23,25 @@ public static RoutingState fromString(String value) {
2023
return valueOf(value.toUpperCase(Locale.ROOT));
2124
}
2225

26+
/**
27+
* @return {@code true} if state matches none of the given {@code candidates}
28+
*/
29+
public boolean isNoneOf(RoutingState... candidates) {
30+
return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
31+
}
32+
2333
@Override
2434
public String toString() {
2535
return name().toLowerCase(Locale.ROOT);
2636
}
37+
38+
@Override
39+
public boolean consumesMemory() {
40+
return isNoneOf(FAILED, STOPPED);
41+
}
42+
43+
@Override
44+
public boolean isAllocating() {
45+
return this == STARTING;
46+
}
2747
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java

Lines changed: 0 additions & 52 deletions
This file was deleted.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java

Lines changed: 0 additions & 117 deletions
This file was deleted.

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@
1616
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
1717
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
1818
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
19-
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
2019
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
2120
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
2221
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
23-
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
24-
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
2522
import org.elasticsearch.xpack.core.ml.job.config.JobState;
2623
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
2724

2825
import java.net.InetAddress;
29-
import java.util.Arrays;
3026

3127
import static org.hamcrest.Matchers.contains;
3228
import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -308,41 +304,6 @@ public void testGetDataFrameAnalyticsState_GivenStaleTaskWithFailedState() {
308304
assertThat(state, equalTo(DataFrameAnalyticsState.FAILED));
309305
}
310306

311-
public void testGetTrainedModelDeploymentState_GivenNull() {
312-
assertThat(MlTasks.getTrainedModelDeploymentState(null), equalTo(TrainedModelDeploymentState.STOPPED));
313-
}
314-
315-
public void testGetTrainedModelDeploymentState_GivenTaskStateIsNull() {
316-
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(null, false);
317-
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
318-
}
319-
320-
public void testGetTrainedModelDeploymentState_GivenTaskStateIsNotNullAndNotStale() {
321-
TrainedModelDeploymentState state = randomFrom(TrainedModelDeploymentState.values());
322-
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, false);
323-
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(state));
324-
}
325-
326-
public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndStopping() {
327-
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.STOPPING, true);
328-
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STOPPED));
329-
}
330-
331-
public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndFailed() {
332-
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.FAILED, true);
333-
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.FAILED));
334-
}
335-
336-
public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndNotFailedNorStopping() {
337-
TrainedModelDeploymentState state = randomFrom(
338-
Arrays.stream(TrainedModelDeploymentState.values())
339-
.filter(s -> s != TrainedModelDeploymentState.FAILED && s != TrainedModelDeploymentState.STOPPING)
340-
.toArray(TrainedModelDeploymentState[]::new)
341-
);
342-
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, true);
343-
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
344-
}
345-
346307
private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAnalyticsTask(String jobId, String nodeId,
347308
DataFrameAnalyticsState state,
348309
boolean isStale) {
@@ -358,18 +319,4 @@ private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAn
358319
return tasks.getTask(MlTasks.dataFrameAnalyticsTaskId(jobId));
359320
}
360321

361-
private static PersistentTasksCustomMetadata.PersistentTask<?> createTrainedModelTask(TrainedModelDeploymentState state,
362-
boolean isStale) {
363-
String id = randomAlphaOfLength(10);
364-
PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
365-
builder.addTask(MlTasks.trainedModelDeploymentTaskId(id), MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME,
366-
new StartTrainedModelDeploymentAction.TaskParams(id, randomAlphaOfLength(10), randomNonNegativeLong()),
367-
new PersistentTasksCustomMetadata.Assignment(randomAlphaOfLength(10), "test assignment"));
368-
if (state != null) {
369-
builder.updateTaskState(MlTasks.trainedModelDeploymentTaskId(id),
370-
new TrainedModelDeploymentTaskState(state, builder.getLastAllocationId() - (isStale ? 1 : 0), null));
371-
}
372-
PersistentTasksCustomMetadata tasks = builder.build();
373-
return tasks.getTask(MlTasks.trainedModelDeploymentTaskId(id));
374-
}
375322
}

0 commit comments

Comments
 (0)