Skip to content

Commit 7d637b8

Browse files
[ML] Add queue_capacity setting to start deployment API (#79433)
Adds a setting to the start trained model deployment API that allows configuring the capacity of the queueing mechanism that handles inference requests.
1 parent 79260bc commit 7d637b8

14 files changed

+106
-41
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
6060
public static final ParseField WAIT_FOR = new ParseField("wait_for");
6161
public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
6262
public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
63+
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
6364

6465
public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
6566

@@ -69,6 +70,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
6970
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
7071
PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
7172
PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
73+
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
7274
}
7375

7476
public static Request parseRequest(String modelId, XContentParser parser) {
@@ -87,6 +89,7 @@ public static Request parseRequest(String modelId, XContentParser parser) {
8789
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
8890
private int modelThreads = 1;
8991
private int inferenceThreads = 1;
92+
private int queueCapacity = 1024;
9093

9194
private Request() {}
9295

@@ -101,6 +104,7 @@ public Request(StreamInput in) throws IOException {
101104
waitForState = in.readEnum(AllocationStatus.State.class);
102105
modelThreads = in.readVInt();
103106
inferenceThreads = in.readVInt();
107+
queueCapacity = in.readVInt();
104108
}
105109

106110
public final void setModelId(String modelId) {
@@ -144,6 +148,14 @@ public void setInferenceThreads(int inferenceThreads) {
144148
this.inferenceThreads = inferenceThreads;
145149
}
146150

151+
public int getQueueCapacity() {
152+
return queueCapacity;
153+
}
154+
155+
public void setQueueCapacity(int queueCapacity) {
156+
this.queueCapacity = queueCapacity;
157+
}
158+
147159
@Override
148160
public void writeTo(StreamOutput out) throws IOException {
149161
super.writeTo(out);
@@ -152,6 +164,7 @@ public void writeTo(StreamOutput out) throws IOException {
152164
out.writeEnum(waitForState);
153165
out.writeVInt(modelThreads);
154166
out.writeVInt(inferenceThreads);
167+
out.writeVInt(queueCapacity);
155168
}
156169

157170
@Override
@@ -162,6 +175,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
162175
builder.field(WAIT_FOR.getPreferredName(), waitForState);
163176
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
164177
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
178+
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
165179
builder.endObject();
166180
return builder;
167181
}
@@ -183,12 +197,15 @@ public ActionRequestValidationException validate() {
183197
if (inferenceThreads < 1) {
184198
validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
185199
}
200+
if (queueCapacity < 1) {
201+
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
202+
}
186203
return validationException.validationErrors().isEmpty() ? null : validationException;
187204
}
188205

189206
@Override
190207
public int hashCode() {
191-
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads);
208+
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
192209
}
193210

194211
@Override
@@ -204,7 +221,8 @@ public boolean equals(Object obj) {
204221
&& Objects.equals(timeout, other.timeout)
205222
&& Objects.equals(waitForState, other.waitForState)
206223
&& modelThreads == other.modelThreads
207-
&& inferenceThreads == other.inferenceThreads;
224+
&& inferenceThreads == other.inferenceThreads
225+
&& queueCapacity == other.queueCapacity;
208226
}
209227

210228
@Override
@@ -226,16 +244,20 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
226244
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
227245
public static final ParseField MODEL_THREADS = new ParseField("model_threads");
228246
public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
247+
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
248+
229249
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
230250
"trained_model_deployment_params",
231251
true,
232-
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3])
252+
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4])
233253
);
254+
234255
static {
235256
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
236257
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
237258
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
238259
PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
260+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
239261
}
240262

241263
public static TaskParams fromXContent(XContentParser parser) {
@@ -253,28 +275,22 @@ public static TaskParams fromXContent(XContentParser parser) {
253275
private final long modelBytes;
254276
private final int inferenceThreads;
255277
private final int modelThreads;
278+
private final int queueCapacity;
256279

257-
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) {
280+
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
258281
this.modelId = Objects.requireNonNull(modelId);
259282
this.modelBytes = modelBytes;
260-
if (modelBytes < 0) {
261-
throw new IllegalArgumentException("modelBytes must be non-negative");
262-
}
263283
this.inferenceThreads = inferenceThreads;
264-
if (inferenceThreads < 1) {
265-
throw new IllegalArgumentException(INFERENCE_THREADS + " must be positive");
266-
}
267284
this.modelThreads = modelThreads;
268-
if (modelThreads < 1) {
269-
throw new IllegalArgumentException(MODEL_THREADS + " must be positive");
270-
}
285+
this.queueCapacity = queueCapacity;
271286
}
272287

273288
public TaskParams(StreamInput in) throws IOException {
274289
this.modelId = in.readString();
275-
this.modelBytes = in.readVLong();
290+
this.modelBytes = in.readLong();
276291
this.inferenceThreads = in.readVInt();
277292
this.modelThreads = in.readVInt();
293+
this.queueCapacity = in.readVInt();
278294
}
279295

280296
public String getModelId() {
@@ -293,9 +309,10 @@ public Version getMinimalSupportedVersion() {
293309
@Override
294310
public void writeTo(StreamOutput out) throws IOException {
295311
out.writeString(modelId);
296-
out.writeVLong(modelBytes);
312+
out.writeLong(modelBytes);
297313
out.writeVInt(inferenceThreads);
298314
out.writeVInt(modelThreads);
315+
out.writeVInt(queueCapacity);
299316
}
300317

301318
@Override
@@ -305,13 +322,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
305322
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
306323
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
307324
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
325+
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
308326
builder.endObject();
309327
return builder;
310328
}
311329

312330
@Override
313331
public int hashCode() {
314-
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads);
332+
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
315333
}
316334

317335
@Override
@@ -323,7 +341,8 @@ public boolean equals(Object o) {
323341
return Objects.equals(modelId, other.modelId)
324342
&& modelBytes == other.modelBytes
325343
&& inferenceThreads == other.inferenceThreads
326-
&& modelThreads == other.modelThreads;
344+
&& modelThreads == other.modelThreads
345+
&& queueCapacity == other.queueCapacity;
327346
}
328347

329348
@Override
@@ -342,6 +361,15 @@ public int getInferenceThreads() {
342361
public int getModelThreads() {
343362
return modelThreads;
344363
}
364+
365+
public int getQueueCapacity() {
366+
return queueCapacity;
367+
}
368+
369+
@Override
370+
public String toString() {
371+
return Strings.toString(this);
372+
}
345373
}
346374

347375
public interface TaskMatcher {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,7 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire
1414

1515
@Override
1616
protected Request createTestInstance() {
17-
return new Request(
18-
new StartTrainedModelDeploymentAction.TaskParams(
19-
randomAlphaOfLength(10),
20-
randomNonNegativeLong(),
21-
randomIntBetween(1, 8),
22-
randomIntBetween(1, 8)
23-
)
24-
);
17+
return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom());
2518
}
2619

2720
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.io.IOException;
1919

2020
import static org.hamcrest.Matchers.containsString;
21+
import static org.hamcrest.Matchers.equalTo;
2122
import static org.hamcrest.Matchers.is;
2223
import static org.hamcrest.Matchers.not;
2324
import static org.hamcrest.Matchers.nullValue;
@@ -53,6 +54,9 @@ public static Request createRandom() {
5354
if (randomBoolean()) {
5455
request.setModelThreads(randomIntBetween(1, 8));
5556
}
57+
if (randomBoolean()) {
58+
request.setQueueCapacity(randomIntBetween(1, 10000));
59+
}
5660
return request;
5761
}
5862

@@ -95,4 +99,33 @@ public void testValidate_GivenModelThreadsIsNegative() {
9599
assertThat(e, is(not(nullValue())));
96100
assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
97101
}
102+
103+
public void testValidate_GivenQueueCapacityIsZero() {
104+
Request request = createRandom();
105+
request.setQueueCapacity(0);
106+
107+
ActionRequestValidationException e = request.validate();
108+
109+
assertThat(e, is(not(nullValue())));
110+
assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer"));
111+
}
112+
113+
public void testValidate_GivenQueueCapacityIsNegative() {
114+
Request request = createRandom();
115+
request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1));
116+
117+
ActionRequestValidationException e = request.validate();
118+
119+
assertThat(e, is(not(nullValue())));
120+
assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer"));
121+
}
122+
123+
public void testDefaults() {
124+
Request request = new Request(randomAlphaOfLength(10));
125+
assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
126+
assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
127+
assertThat(request.getInferenceThreads(), equalTo(1));
128+
assertThat(request.getModelThreads(), equalTo(1));
129+
assertThat(request.getQueueCapacity(), equalTo(1024));
130+
}
98131
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
3636
randomAlphaOfLength(10),
3737
randomNonNegativeLong(),
3838
randomIntBetween(1, 8),
39-
randomIntBetween(1, 8)
39+
randomIntBetween(1, 8),
40+
randomIntBetween(1, 10000)
4041
);
4142
}
4243
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
import org.elasticsearch.cluster.node.DiscoveryNode;
1414
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
1515
import org.elasticsearch.common.io.stream.Writeable;
16-
import org.elasticsearch.xcontent.XContentParser;
1716
import org.elasticsearch.test.AbstractSerializingTestCase;
17+
import org.elasticsearch.xcontent.XContentParser;
1818
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
19+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests;
1920

2021
import java.io.IOException;
2122
import java.util.List;
@@ -31,9 +32,7 @@
3132
public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {
3233

3334
public static TrainedModelAllocation randomInstance() {
34-
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
35-
new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1)
36-
);
35+
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams());
3736
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
3837
for (String node : nodes) {
3938
if (randomBoolean()) {
@@ -249,7 +248,7 @@ private static DiscoveryNode buildNode() {
249248
}
250249

251250
private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
252-
return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1);
251+
return StartTrainedModelDeploymentTaskParamsTests.createRandom();
253252
}
254253

255254
private static void assertUnchanged(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.elasticsearch.cluster.service.ClusterService;
2727
import org.elasticsearch.common.inject.Inject;
2828
import org.elasticsearch.common.settings.Settings;
29-
import org.elasticsearch.xcontent.NamedXContentRegistry;
3029
import org.elasticsearch.core.TimeValue;
3130
import org.elasticsearch.license.LicenseUtils;
3231
import org.elasticsearch.license.XPackLicenseState;
@@ -35,6 +34,7 @@
3534
import org.elasticsearch.tasks.Task;
3635
import org.elasticsearch.threadpool.ThreadPool;
3736
import org.elasticsearch.transport.TransportService;
37+
import org.elasticsearch.xcontent.NamedXContentRegistry;
3838
import org.elasticsearch.xpack.core.XPackField;
3939
import org.elasticsearch.xpack.core.ml.MachineLearningField;
4040
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
@@ -161,7 +161,8 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ
161161
trainedModelConfig.getModelId(),
162162
modelBytes,
163163
request.getInferenceThreads(),
164-
request.getModelThreads()
164+
request.getModelThreads(),
165+
request.getQueueCapacity()
165166
);
166167
PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(
167168
PersistentTasksCustomMetadata.TYPE);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ TrainedModelDeploymentTask getTask(String modelId) {
332332
}
333333

334334
void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) {
335+
logger.debug(() -> new ParameterizedMessage("[{}] preparing to load model with task params: {}",
336+
taskParams.getModelId(), taskParams));
335337
TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) taskManager.register(
336338
TRAINED_MODEL_ALLOCATION_TASK_TYPE,
337339
TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + taskParams.getModelId(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,11 @@ class ProcessContext {
392392
this.task = Objects.requireNonNull(task);
393393
resultProcessor = new PyTorchResultProcessor(task.getModelId());
394394
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
395-
this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
395+
this.executorService = new ProcessWorkerExecutorService(
396+
threadPool.getThreadContext(),
397+
"pytorch_inference",
398+
task.getParams().getQueueCapacity()
399+
);
396400
}
397401

398402
PyTorchResultProcessor getResultProcessor() {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ public class ProcessWorkerExecutorService extends AbstractExecutorService {
4545
/**
4646
* @param contextHolder the thread context holder
4747
* @param processName the name of the process to be used in logging
48-
* @param queueSize the size of the queue holding operations. If an operation is added
48+
* @param queueCapacity the capacity of the queue holding operations. If an operation is added
4949
* for execution when the queue is full a 429 error is thrown.
5050
*/
5151
@SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
52-
public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueSize) {
52+
public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueCapacity) {
5353
this.contextHolder = Objects.requireNonNull(contextHolder);
5454
this.processName = Objects.requireNonNull(processName);
55-
this.queue = new LinkedBlockingQueue<>(queueSize);
55+
this.queue = new LinkedBlockingQueue<>(queueCapacity);
5656
}
5757

5858
@Override

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import static org.elasticsearch.rest.RestRequest.Method.POST;
2424
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
2525
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
26+
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
2627
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
2728
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;
2829

@@ -59,6 +60,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
5960
));
6061
request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
6162
request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
63+
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
6264
}
6365

6466
return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));

0 commit comments

Comments
 (0)