@@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
    def enc_hook(self, obj: Any) -> Any:
        if isinstance(obj, torch.Tensor):
-            return self._encode_ndarray(obj.numpy())
+            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
@@ -133,9 +133,27 @@ def _encode_ndarray(
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

+    def _encode_tensor(
+        self, obj: torch.Tensor
+    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
+        assert self.aux_buffers is not None
+        # this creates a copy of the tensor if it's not already contiguous
+        obj = obj.contiguous()
+        # view the tensor as a 1D array of bytes
+        arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
+        if obj.nbytes < self.size_threshold:
+            # Smaller tensors are encoded inline, just like ndarrays.
+            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
+        else:
+            # Otherwise encode index of backing buffer to avoid copy.
+            data = len(self.aux_buffers)
+            self.aux_buffers.append(arr.data)
+        dtype = str(obj.dtype)[6:]  # remove 'torch.' prefix
+        return dtype, obj.shape, data
+
    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        if isinstance(nt, torch.Tensor):
-            return self._encode_ndarray(nt.numpy())
+            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
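For reference, here is a minimal standalone sketch of the byte-view trick that the new _encode_tensor helper above relies on. It is illustrative only, not part of the patch; the tensor t and the other names are made up.

import torch

# Hypothetical example tensor; bfloat16 is chosen deliberately because it has
# no numpy dtype, which is why the storage is viewed as uint8 before .numpy().
t = torch.randn(2, 3, dtype=torch.bfloat16)

# Flatten, reinterpret the storage as raw bytes, and grab the underlying
# buffer without copying (the tensor here is already contiguous).
raw = t.contiguous().view((t.numel(), )).view(torch.uint8).numpy().data

dtype_name = str(t.dtype)[6:]  # strip the 'torch.' prefix -> "bfloat16"
assert len(raw) == t.numel() * t.element_size()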
@@ -186,7 +204,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
        if issubclass(t, np.ndarray):
            return self._decode_ndarray(obj)
        if issubclass(t, torch.Tensor):
-            return torch.from_numpy(self._decode_ndarray(obj))
+            return self._decode_tensor(obj)
        if issubclass(t, MultiModalKwargs):
            if isinstance(obj, list):
                return MultiModalKwargs.from_items(
@@ -199,11 +217,24 @@ def dec_hook(self, t: type, obj: Any) -> Any:
    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        dtype, shape, data = arr
-        # Copy from inline representation, otherwise Torch is unhappy since
-        # the returned memory is non-writeable.
+        # zero-copy decode. We assume the ndarray will not be kept around,
+        # as it now locks the whole received message buffer in memory.
+        buffer = self.aux_buffers[data] if isinstance(data, int) else data
+        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
+
+    def _decode_tensor(self, arr: Any) -> torch.Tensor:
+        dtype, shape, data = arr
+        # Copy from the inline representation to decouple the tensor's
+        # storage from the original message buffer; this also keeps Torch
+        # from complaining about a read-only memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) \
            else bytearray(data)
-        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
+        # Create numpy wrapper around the bytes
+        arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
+        torch_dtype = getattr(torch, dtype)
+        assert isinstance(torch_dtype, torch.dtype)
+        # Convert back to proper shape & type
+        return torch.from_numpy(arr).view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
        decoded_items = []
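Likewise, a small sketch of the reconstruction performed by the new _decode_tensor helper above, again illustrative only; payload, dtype_name, and shape are made-up stand-ins for the fields of the decoded tuple.

import numpy as np
import torch

# Stand-in for the received bytes: a float16 tensor serialized as raw uint8.
src = torch.arange(6, dtype=torch.float16)
payload = bytearray(src.view(torch.uint8).numpy().data)
dtype_name, shape = "float16", (2, 3)

# Wrap the bytes in a 1D uint8 ndarray, then reinterpret dtype and shape.
wrapper = np.ndarray(buffer=payload, dtype=np.uint8, shape=(len(payload), ))
restored = torch.from_numpy(wrapper).view(getattr(torch, dtype_name)).view(shape)

assert restored.dtype == torch.float16 and restored.shape == (2, 3)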
@@ -228,7 +259,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        if obj and isinstance(obj[0], str):
-            return torch.from_numpy(self._decode_ndarray(obj))
+            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any: