Skip to content

Commit 18735de

Browse files
committed
[ML] Check total_definition_length is consistent on start deployment (elastic#80553)
Before a model deployment is started, check that the individual model part documents have a consistent value for total_definition_length. Any inconsistency is an unrecoverable error.
1 parent 9ae2496 commit 18735de

File tree

4 files changed

+101
-10
lines changed

4 files changed

+101
-10
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ public final class Messages {
120120
"Unable to delete model [{0}] as it is required by machine learning";
121121
public static final String MODEL_DEFINITION_TRUNCATED =
122122
"Model definition truncated. Unable to deserialize trained model definition [{0}]";
123+
public static final String UNABLE_TO_DEPLOY_MODEL_BAD_PARTS = "Unable to deploy model, please delete and recreate the model definition";
123124
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
124125
public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
125126
"Getting model definition is not supported when getting more than one model";

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,60 @@ private static TrainedModel buildRegression() {
320320
.build();
321321
}
322322

323+
public void testStartDeploymentWithInconsistentTotalLengths() throws IOException {
324+
String modelId = "inconsistent-size-model";
325+
putPyTorchModel(modelId);
326+
327+
putModelDefinitionPart(modelId, 500, 3, 0);
328+
putModelDefinitionPart(modelId, 500, 3, 1);
329+
putModelDefinitionPart(modelId, 600, 3, 2);
330+
331+
ResponseException responseException = expectThrows(ResponseException.class, () -> startDeployment(modelId));
332+
assertThat(
333+
responseException.getMessage(),
334+
containsString(
335+
"[total_definition_length] must be the same in all model definition parts. "
336+
+ "The value [600] in model definition part [2] does not match the value [500] in part [0]"
337+
)
338+
);
339+
340+
}
341+
342+
private void putPyTorchModel(String modelId) throws IOException {
343+
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
344+
request.setJsonEntity(
345+
"{ "
346+
+ " \"description\": \"simple model for testing\",\n"
347+
+ " \"model_type\": \"pytorch\",\n"
348+
+ " \"inference_config\": {\n"
349+
+ " \"pass_through\": {\n"
350+
+ " }\n"
351+
+ " }\n"
352+
+ "}"
353+
);
354+
client().performRequest(request);
355+
}
356+
357+
private void putModelDefinitionPart(String modelId, int totalSize, int numParts, int partNumber) throws IOException {
358+
Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/" + partNumber);
359+
request.setJsonEntity(
360+
"{ "
361+
+ "\"total_definition_length\": "
362+
+ totalSize
363+
+ ","
364+
+ "\"definition\": \"UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZW==\","
365+
+ "\"total_parts\": "
366+
+ numParts
367+
+ "}"
368+
);
369+
client().performRequest(request);
370+
}
371+
372+
private void startDeployment(String modelId) throws IOException {
373+
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_start?timeout=40s");
374+
client().performRequest(request);
375+
}
376+
323377
@After
324378
public void clearMlState() throws Exception {
325379
new MlRestTestStateCleaner(logger, adminClient()).resetFeatures();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,20 @@ private void validateModelDefinition(TrainedModelConfig config, ActionListener<V
316316
listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
317317
return;
318318
}
319+
long firstTotalLength = ((Number) hits[0].getSourceAsMap()
320+
.get(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())).longValue();
321+
319322
long summedLengths = 0;
320323
for (SearchHit hit : hits) {
321324
Map<String, Object> fields = hit.getSourceAsMap();
322325
if (fields == null) {
323326
listener.onFailure(
324327
ExceptionsHelper.badRequestException(
325-
"[{}] model definition [{}] is missing required fields {}, unable to be deployed",
328+
"[{}] model definition [{}] is missing required fields {}. {}",
326329
modelId,
327330
TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())),
328-
List.of(requiredSourceFields)
331+
List.of(requiredSourceFields),
332+
Messages.UNABLE_TO_DEPLOY_MODEL_BAD_PARTS
329333
)
330334
);
331335
return;
@@ -334,21 +338,40 @@ private void validateModelDefinition(TrainedModelConfig config, ActionListener<V
334338
if (diff.isEmpty() == false) {
335339
listener.onFailure(
336340
ExceptionsHelper.badRequestException(
337-
"[{}] model definition [{}] is missing required fields {}, unable to be deployed",
341+
"[{}] model definition [{}] is missing required fields {}. {}",
338342
modelId,
339343
TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())),
340-
diff
344+
diff,
345+
Messages.UNABLE_TO_DEPLOY_MODEL_BAD_PARTS
341346
)
342347
);
343348
return;
344349
}
345350
summedLengths += ((Number) fields.get(TrainedModelDefinitionDoc.DEFINITION_LENGTH.getPreferredName())).longValue();
351+
long totalLength = ((Number) fields.get(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()))
352+
.longValue();
353+
if (totalLength != firstTotalLength) {
354+
listener.onFailure(
355+
ExceptionsHelper.badRequestException(
356+
"[{}] [total_definition_length] must be the same in all model definition parts. "
357+
+ "The value [{}] in model definition part [{}] does not match the value [{}] in part [{}]. "
358+
+ Messages.UNABLE_TO_DEPLOY_MODEL_BAD_PARTS,
359+
modelId,
360+
totalLength,
361+
TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())),
362+
firstTotalLength,
363+
TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hits[0].getId()))
364+
)
365+
);
366+
return;
367+
}
368+
346369
}
347-
long totalLength = ((Number) hits[hits.length - 1].getSourceAsMap()
348-
.get(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())).longValue();
349370
Boolean eos = (Boolean) hits[hits.length - 1].getSourceAsMap().get(TrainedModelDefinitionDoc.EOS.getPreferredName());
350-
if (summedLengths != totalLength || eos == null || eos == false) {
351-
listener.onFailure(ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
371+
if (summedLengths != firstTotalLength || eos == null || eos == false) {
372+
listener.onFailure(
373+
ExceptionsHelper.badRequestException(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))
374+
);
352375
return;
353376
}
354377
listener.onResponse(null);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Locale;
2323
import java.util.Objects;
2424
import java.util.concurrent.ExecutorService;
25+
import java.util.concurrent.atomic.AtomicInteger;
2526

2627
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
2728

@@ -40,7 +41,8 @@ public class PyTorchStateStreamer {
4041
private final ExecutorService executorService;
4142
private final NamedXContentRegistry xContentRegistry;
4243
private volatile boolean isCancelled;
43-
private int modelSize = -1;
44+
private volatile int modelSize = -1;
45+
private final AtomicInteger bytesWritten = new AtomicInteger();
4446

4547
public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
4648
this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN);
@@ -70,6 +72,14 @@ public void writeStateToStream(String modelId, String index, OutputStream restor
7072
restorer.setSearchSize(1);
7173
restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), success -> {
7274
logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index);
75+
if (bytesWritten.get() != modelSize) {
76+
logger.error(
77+
"model [{}] restored state size [{}] does not equal the expected model size [{}]",
78+
modelId,
79+
bytesWritten,
80+
modelSize
81+
);
82+
}
7383
listener.onResponse(success);
7484
}, listener::onFailure);
7585
}
@@ -86,6 +96,7 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr
8696
// The array backing the BytesReference may be bigger than what is
8797
// referred to so write only what is after the offset
8898
outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length());
99+
bytesWritten.addAndGet(doc.getBinaryData().length());
89100
return true;
90101
}
91102

@@ -128,10 +139,12 @@ private int writeModelSize(String modelId, Long modelSizeBytes, OutputStream out
128139
throw new IllegalStateException(message);
129140
}
130141

131-
ByteBuffer lengthBuffer = ByteBuffer.allocate(4);
142+
final int NUM_BYTES = 4;
143+
ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES);
132144
lengthBuffer.putInt(modelSizeBytes.intValue());
133145
outputStream.write(lengthBuffer.array());
134146

147+
bytesWritten.addAndGet(NUM_BYTES);
135148
return modelSizeBytes.intValue();
136149
}
137150
}

0 commit comments

Comments
 (0)