import os
import tempfile
from dataclasses import dataclass
-from typing import ClassVar, Dict, List, Literal, Optional
+from typing import ClassVar, Dict, List, Literal, Optional, Sequence

import pkg_resources
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass

from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
from executorch.exir._serialize._program import _insert_flatbuffer_header
-from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
+from executorch.exir._serialize.data_serializer import (
+    DataPayload,
+    DataSerializer,
+    TensorEntry,
+)

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

-# Byte order of numbers written to flat tensor headers. Always little-endian
-# regardless of the host system, since all commonly-used modern CPUs are little
-# endian.
-_HEADER_BYTEORDER: Literal["little"] = "little"
-
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
    DataSegment,
    FlatTensor,
    TensorMetadata,
)

+# Byte order of numbers written to flat tensor headers. Always little-endian
+# regardless of the host system, since all commonly-used modern CPUs are little
+# endian.
+_HEADER_BYTEORDER: Literal["little"] = "little"
+


def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
    """Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -209,6 +213,62 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
    return None


+def _extract_tensors(
+    fqn_to_tensor: Dict[str, TensorEntry],
+    buffers: Sequence[bytes],
+    segments: List[Cord],
+    tensor_alignment: int,
+) -> List[TensorMetadata]:
+    """Places tensors into a single segment, aligned to tensor_alignment within
+    the segment.
+
+    Args:
+        fqn_to_tensor: A map from fully qualified names to tensor entries.
+        buffers: A sequence of tensor buffers.
+        segments: A list of segments to append the tensor data to. Modified in-place.
+        tensor_alignment: The alignment of the tensor data.
+
+    Returns:
+        A list of TensorMetadata, which describes the tensors in the segment.
+    """
+    tensor_data: Cord = Cord()
+    tensors: List[TensorMetadata] = []
+    # {idx, offset}
+    saved_offsets: Dict[int, int] = {}
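+    # Buffers referenced by more than one tensor are written once and share an offset.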
+    for fqn, tensor_entry in fqn_to_tensor.items():
+        assert tensor_entry.layout is not None
+        # Check index into the tensor buffers is valid.
+        assert tensor_entry.buffer_index < len(
+            buffers
+        ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(buffers)}."
+
+        # Check if the tensor has already been appended to tensor_data.
+        offset = saved_offsets.get(tensor_entry.buffer_index, -1)
+        if offset == -1:
+            if len(tensor_data) > 0:
+                # Add padding to round off the previous tensor offset.
+                pad_length = padding_required(len(tensor_data), tensor_alignment)
+                tensor_data.append(b"\x00" * pad_length)
+            # Add to saved offsets.
+            offset = len(tensor_data)
+            saved_offsets[tensor_entry.buffer_index] = offset
+            # Append to tensor_data at the offset.
+            tensor_data.append(buffers[tensor_entry.buffer_index])
+
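+        # segment_index points at the Cord that is appended to `segments` after this loop.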
+        tensors.append(
+            TensorMetadata(
+                fully_qualified_name=fqn,
+                scalar_type=tensor_entry.layout.scalar_type,
+                sizes=tensor_entry.layout.sizes,
+                dim_order=tensor_entry.layout.dim_order,
+                segment_index=len(segments),
+                offset=offset,
+            )
+        )
+    segments.append(tensor_data)
+    return tensors
+
+
class FlatTensorSerializer(DataSerializer):
    """A concrete implementation of the DataSerializer interface that
    serializes and deserializes data to/from the FlatTensor format.
@@ -227,61 +287,45 @@ def serialize(
        self,
        data: DataPayload,
    ) -> Cord:
-        """Serializes a list of tensor metadata and tensors into a blob."""
-
-        flat_tensor_metadata: List[TensorMetadata] = []
-        flat_tensor_data: Cord = Cord()
-
-        # {idx, offset}
-        saved_offsets: Dict[int, int] = {}
-
-        for fqn, tensor_entry in data.fqn_to_tensor.items():
-            assert tensor_entry.layout is not None
-            # Check index into the tensor buffers is valid.
-            assert tensor_entry.buffer_index < len(
-                data.buffers
-            ), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(data.buffers)}."
-
-            # Check if the tensor has already been appended to the flat_tensor_data.
-            offset = saved_offsets.get(tensor_entry.buffer_index, -1)
-            if offset == -1:
-                if len(flat_tensor_data) > 0:
-                    # Add padding to round off the previous tensor offset.
-                    pad_length = padding_required(
-                        len(flat_tensor_data), self.config.tensor_alignment
-                    )
-                    flat_tensor_data.append(b"\x00" * pad_length)
-                # Add to saved offsets.
-                offset = len(flat_tensor_data)
-                saved_offsets[tensor_entry.buffer_index] = offset
-                # Append to flat_tensor_data at the offset.
-                flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])
-
-            flat_tensor_metadata.append(
-                TensorMetadata(
-                    fully_qualified_name=fqn,
-                    scalar_type=tensor_entry.layout.scalar_type,
-                    sizes=tensor_entry.layout.sizes,
-                    dim_order=tensor_entry.layout.dim_order,
-                    segment_index=0,
-                    offset=offset,
+        """Serializes a list of tensors and named data into a blob."""
+
+        segments: List[Cord] = []
+        tensors = _extract_tensors(
+            data.fqn_to_tensor,
+            data.buffers,
+            segments,
+            self.config.tensor_alignment,
+        )
+
+        data_segments: List[DataSegment] = []
+        segment_data = Cord()
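+        # Concatenate all segments into segment_data, recording each segment's offset and size.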
+        for segment in segments:
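+            # Each segment starts at the aligned end of the previous segment.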
+            prev_end = (
+                (data_segments[-1].offset + data_segments[-1].size)
+                if data_segments
+                else 0
+            )
+            data_segments.append(
+                DataSegment(
+                    offset=aligned_size(prev_end, self.config.segment_alignment),
+                    size=len(segment),
                )
            )
-
-        # Pad flat_tensor_data to segment alignment.
-        segment_pad_length = padding_required(
-            len(flat_tensor_data), self.config.segment_alignment
-        )
-        if segment_pad_length > 0:
-            flat_tensor_data.append(b"\x00" * segment_pad_length)
+            # Pad segment_data to segment alignment.
+            segment_pad_length = padding_required(
+                len(segment_data), self.config.segment_alignment
+            )
+            if segment_pad_length > 0:
+                segment_data.append(b"\x00" * segment_pad_length)
+            segment_data.append(segment)

        # Create FlatTensor, which describes of the contents of the file and
        # points to all the data segments. It will be serialized to flatbuffer.
        flat_tensor = FlatTensor(
            version=0,  # Keep in sync with c++ version number in serialize.h
            tensor_alignment=self.config.tensor_alignment,
-            tensors=flat_tensor_metadata,
-            segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
+            tensors=tensors,
+            segments=data_segments,
            named_data=[],
        )

@@ -307,7 +351,7 @@ def serialize(
            flatbuffer_offset=padded_header_length,
            flatbuffer_size=len(flatbuffer_payload),
            segment_base_offset=segment_base_offset,
-            segment_data_size=len(flat_tensor_data),
+            segment_data_size=len(segment_data),
        ).to_bytes()

        # Pad header and payload to segment alignment.
@@ -327,15 +371,15 @@ def serialize(
        assert eh.flatbuffer_size == original_flatbuffer_payload_size
        assert eh.segment_base_offset == segment_base_offset
        assert eh.flatbuffer_offset == padded_header_length
-        assert eh.segment_data_size == len(flat_tensor_data)
+        assert eh.segment_data_size == len(segment_data)

        del header_data
        del flatbuffer_payload

        # Place everything into one segment.
        payload = Cord()
        payload.append(injected_flatbuffer_data)
-        payload.append(flat_tensor_data)
+        payload.append(segment_data)

        return payload
