@@ -37,12 +37,16 @@ public class PyTorchStateStreamer {
37
37
38
38
private static final Logger logger = LogManager .getLogger (PyTorchStateStreamer .class );
39
39
40
+ /** The size of the data written before the model definition */
41
+ private static final int NUM_BYTES_IN_PRELUDE = 4 ;
42
+
40
43
private final OriginSettingClient client ;
41
44
private final ExecutorService executorService ;
42
45
private final NamedXContentRegistry xContentRegistry ;
43
46
private volatile boolean isCancelled ;
44
47
private volatile int modelSize = -1 ;
45
- private final AtomicInteger bytesWritten = new AtomicInteger ();
48
+ // model bytes only, does not include the prelude
49
+ private final AtomicInteger modelBytesWritten = new AtomicInteger ();
46
50
47
51
public PyTorchStateStreamer (Client client , ExecutorService executorService , NamedXContentRegistry xContentRegistry ) {
48
52
this .client = new OriginSettingClient (Objects .requireNonNull (client ), ML_ORIGIN );
@@ -59,7 +63,7 @@ public void cancel() {
59
63
60
64
/**
61
65
* First writes the size of the model so the native process can
62
- * allocated memory then writes the chunks of binary state.
66
+ * allocate memory then writes the chunks of binary state.
63
67
*
64
68
* @param modelId The model to write
65
69
* @param index The index to search for the model
@@ -72,11 +76,11 @@ public void writeStateToStream(String modelId, String index, OutputStream restor
72
76
restorer .setSearchSize (1 );
73
77
restorer .restoreModelDefinition (doc -> writeChunk (doc , restoreStream ), success -> {
74
78
logger .debug ("model [{}] state restored in [{}] documents from index [{}]" , modelId , restorer .getNumDocsWritten (), index );
75
- if (bytesWritten .get () != modelSize ) {
79
+ if (modelBytesWritten .get () != modelSize ) {
76
80
logger .error (
77
81
"model [{}] restored state size [{}] does not equal the expected model size [{}]" ,
78
82
modelId ,
79
- bytesWritten ,
83
+ modelBytesWritten ,
80
84
modelSize
81
85
);
82
86
}
@@ -96,7 +100,7 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr
96
100
// The array backing the BytesReference may be bigger than what is
97
101
// referred to so write only what is after the offset
98
102
outputStream .write (doc .getBinaryData ().array (), doc .getBinaryData ().arrayOffset (), doc .getBinaryData ().length ());
99
- bytesWritten .addAndGet (doc .getBinaryData ().length ());
103
+ modelBytesWritten .addAndGet (doc .getBinaryData ().length ());
100
104
return true ;
101
105
}
102
106
@@ -139,12 +143,10 @@ private int writeModelSize(String modelId, Long modelSizeBytes, OutputStream out
139
143
throw new IllegalStateException (message );
140
144
}
141
145
142
- final int NUM_BYTES = 4 ;
143
- ByteBuffer lengthBuffer = ByteBuffer .allocate (NUM_BYTES );
146
+ ByteBuffer lengthBuffer = ByteBuffer .allocate (NUM_BYTES_IN_PRELUDE );
144
147
lengthBuffer .putInt (modelSizeBytes .intValue ());
145
148
outputStream .write (lengthBuffer .array ());
146
149
147
- bytesWritten .addAndGet (NUM_BYTES );
148
150
return modelSizeBytes .intValue ();
149
151
}
150
152
}
0 commit comments