Skip to content

Commit 3101118

Browse files
authored
[ML] Fix incorrect logging of unexpected model size error (#81089)
1 parent: 92b6b6f; commit: 3101118

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,12 +37,16 @@ public class PyTorchStateStreamer {
3737

3838
private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class);
3939

40+
/** The size of the data written before the model definition */
41+
private static final int NUM_BYTES_IN_PRELUDE = 4;
42+
4043
private final OriginSettingClient client;
4144
private final ExecutorService executorService;
4245
private final NamedXContentRegistry xContentRegistry;
4346
private volatile boolean isCancelled;
4447
private volatile int modelSize = -1;
45-
private final AtomicInteger bytesWritten = new AtomicInteger();
48+
// model bytes only, does not include the prelude
49+
private final AtomicInteger modelBytesWritten = new AtomicInteger();
4650

4751
public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
4852
this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN);
@@ -59,7 +63,7 @@ public void cancel() {
5963

6064
/**
6165
* First writes the size of the model so the native process can
62-
* allocated memory then writes the chunks of binary state.
66+
* allocate memory then writes the chunks of binary state.
6367
*
6468
* @param modelId The model to write
6569
* @param index The index to search for the model
@@ -72,11 +76,11 @@ public void writeStateToStream(String modelId, String index, OutputStream restor
7276
restorer.setSearchSize(1);
7377
restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), success -> {
7478
logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index);
75-
if (bytesWritten.get() != modelSize) {
79+
if (modelBytesWritten.get() != modelSize) {
7680
logger.error(
7781
"model [{}] restored state size [{}] does not equal the expected model size [{}]",
7882
modelId,
79-
bytesWritten,
83+
modelBytesWritten,
8084
modelSize
8185
);
8286
}
@@ -96,7 +100,7 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr
96100
// The array backing the BytesReference may be bigger than what is
97101
// referred to so write only what is after the offset
98102
outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length());
99-
bytesWritten.addAndGet(doc.getBinaryData().length());
103+
modelBytesWritten.addAndGet(doc.getBinaryData().length());
100104
return true;
101105
}
102106

@@ -139,12 +143,10 @@ private int writeModelSize(String modelId, Long modelSizeBytes, OutputStream out
139143
throw new IllegalStateException(message);
140144
}
141145

142-
final int NUM_BYTES = 4;
143-
ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES);
146+
ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES_IN_PRELUDE);
144147
lengthBuffer.putInt(modelSizeBytes.intValue());
145148
outputStream.write(lengthBuffer.array());
146149

147-
bytesWritten.addAndGet(NUM_BYTES);
148150
return modelSizeBytes.intValue();
149151
}
150152
}

0 commit comments

Comments
 (0)