Skip to content

Commit 644b7dd

Browse files
authored
Refactor serialize.py (#9579)
Pull Request resolved: #9124. Refactor `_extract_tensors` out of `serialize.py` and simplify the serializer ahead of serializing named_data. ghstack-source-id: 273798696. Differential Revision: [D70752429](https://our.internmc.facebook.com/intern/diff/D70752429/) ### Summary Moves the tensor-packing logic from `FlatTensorSerializer.serialize` into a standalone `_extract_tensors` helper, and generalizes the single hard-coded data segment into a list of segments, in preparation for serializing named_data. ### Test plan Covered by existing flat_tensor serialization tests; CI.
1 parent 3b2a8ca commit 644b7dd

File tree

1 file changed

+101
-57
lines changed

1 file changed

+101
-57
lines changed

extension/flat_tensor/serialize/serialize.py

+101-57
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,33 @@
1010
import os
1111
import tempfile
1212
from dataclasses import dataclass
13-
from typing import ClassVar, Dict, List, Literal, Optional
13+
from typing import ClassVar, Dict, List, Literal, Optional, Sequence
1414

1515
import pkg_resources
1616
from executorch.exir._serialize._cord import Cord
1717
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
1818

1919
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2020
from executorch.exir._serialize._program import _insert_flatbuffer_header
21-
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
21+
from executorch.exir._serialize.data_serializer import (
22+
DataPayload,
23+
DataSerializer,
24+
TensorEntry,
25+
)
2226

2327
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
2428

25-
# Byte order of numbers written to flat tensor headers. Always little-endian
26-
# regardless of the host system, since all commonly-used modern CPUs are little
27-
# endian.
28-
_HEADER_BYTEORDER: Literal["little"] = "little"
29-
3029
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
3130
DataSegment,
3231
FlatTensor,
3332
TensorMetadata,
3433
)
3534

35+
# Byte order of numbers written to flat tensor headers. Always little-endian
36+
# regardless of the host system, since all commonly-used modern CPUs are little
37+
# endian.
38+
_HEADER_BYTEORDER: Literal["little"] = "little"
39+
3640

3741
def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
3842
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -209,6 +213,62 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
209213
return None
210214

211215

216+
def _extract_tensors(
    fqn_to_tensor: Dict[str, TensorEntry],
    buffers: Sequence[bytes],
    segments: List[Cord],
    tensor_alignment: int,
) -> List[TensorMetadata]:
    """Packs all tensor buffers into a single new segment and describes them.

    Each distinct buffer is appended exactly once to a fresh Cord, padded so
    that it begins on a tensor_alignment boundary; tensors that share a
    buffer_index share the same offset. The filled Cord is appended to
    ``segments``.

    Args:
        fqn_to_tensor: A map from fully qualified names to tensor entries.
        buffers: A sequence of raw tensor buffers, indexed by buffer_index.
        segments: The segment list that receives the new tensor segment.
            Modified in-place.
        tensor_alignment: Byte alignment of each tensor within the segment.

    Returns:
        A list of TensorMetadata (one per fqn) describing the new segment.
    """
    segment_payload = Cord()
    metadata: List[TensorMetadata] = []
    # Offset of each already-placed buffer inside the segment, keyed by
    # buffer index, so duplicate buffer references are stored only once.
    offset_by_buffer: Dict[int, int] = {}
    # The new segment will sit at the current end of `segments`.
    segment_index = len(segments)

    for name, entry in fqn_to_tensor.items():
        assert entry.layout is not None
        idx = entry.buffer_index
        # Validate the index into the tensor buffers.
        assert idx < len(
            buffers
        ), f"Invalid index {idx} is greater than tensor buffer size {len(buffers)}."

        if idx in offset_by_buffer:
            # This buffer was already appended; reuse its offset.
            offset = offset_by_buffer[idx]
        else:
            if len(segment_payload) > 0:
                # Pad out the previous tensor to the alignment boundary.
                pad_length = padding_required(len(segment_payload), tensor_alignment)
                segment_payload.append(b"\x00" * pad_length)
            # Record where this buffer starts, then append its bytes.
            offset = len(segment_payload)
            offset_by_buffer[idx] = offset
            segment_payload.append(buffers[idx])

        metadata.append(
            TensorMetadata(
                fully_qualified_name=name,
                scalar_type=entry.layout.scalar_type,
                sizes=entry.layout.sizes,
                dim_order=entry.layout.dim_order,
                segment_index=segment_index,
                offset=offset,
            )
        )

    segments.append(segment_payload)
    return metadata
270+
271+
212272
class FlatTensorSerializer(DataSerializer):
213273
"""A concrete implementation of the DataSerializer interface that
214274
serializes and deserializes data to/from the FlatTensor format.
@@ -227,61 +287,45 @@ def serialize(
227287
self,
228288
data: DataPayload,
229289
) -> Cord:
230-
"""Serializes a list of tensor metadata and tensors into a blob."""
231-
232-
flat_tensor_metadata: List[TensorMetadata] = []
233-
flat_tensor_data: Cord = Cord()
234-
235-
# {idx, offset}
236-
saved_offsets: Dict[int, int] = {}
237-
238-
for fqn, tensor_entry in data.fqn_to_tensor.items():
239-
assert tensor_entry.layout is not None
240-
# Check index into the tensor buffers is valid.
241-
assert tensor_entry.buffer_index < len(
242-
data.buffers
243-
), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(data.buffers)}."
244-
245-
# Check if the tensor has already been appended to the flat_tensor_data.
246-
offset = saved_offsets.get(tensor_entry.buffer_index, -1)
247-
if offset == -1:
248-
if len(flat_tensor_data) > 0:
249-
# Add padding to round off the previous tensor offset.
250-
pad_length = padding_required(
251-
len(flat_tensor_data), self.config.tensor_alignment
252-
)
253-
flat_tensor_data.append(b"\x00" * pad_length)
254-
# Add to saved offsets.
255-
offset = len(flat_tensor_data)
256-
saved_offsets[tensor_entry.buffer_index] = offset
257-
# Append to flat_tensor_data at the offset.
258-
flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])
259-
260-
flat_tensor_metadata.append(
261-
TensorMetadata(
262-
fully_qualified_name=fqn,
263-
scalar_type=tensor_entry.layout.scalar_type,
264-
sizes=tensor_entry.layout.sizes,
265-
dim_order=tensor_entry.layout.dim_order,
266-
segment_index=0,
267-
offset=offset,
290+
"""Serializes a list of tensors and named data into a blob."""
291+
292+
segments: List[Cord] = []
293+
tensors = _extract_tensors(
294+
data.fqn_to_tensor,
295+
data.buffers,
296+
segments,
297+
self.config.tensor_alignment,
298+
)
299+
300+
data_segments: List[DataSegment] = []
301+
segment_data = Cord()
302+
for segment in segments:
303+
prev_end = (
304+
(data_segments[-1].offset + data_segments[-1].size)
305+
if data_segments
306+
else 0
307+
)
308+
data_segments.append(
309+
DataSegment(
310+
offset=aligned_size(prev_end, self.config.segment_alignment),
311+
size=len(segment),
268312
)
269313
)
270-
271-
# Pad flat_tensor_data to segment alignment.
272-
segment_pad_length = padding_required(
273-
len(flat_tensor_data), self.config.segment_alignment
274-
)
275-
if segment_pad_length > 0:
276-
flat_tensor_data.append(b"\x00" * segment_pad_length)
314+
# Pad segment_data to segment alignment.
315+
segment_pad_length = padding_required(
316+
len(segment_data), self.config.segment_alignment
317+
)
318+
if segment_pad_length > 0:
319+
segment_data.append(b"\x00" * segment_pad_length)
320+
segment_data.append(segment)
277321

278322
# Create FlatTensor, which describes of the contents of the file and
279323
# points to all the data segments. It will be serialized to flatbuffer.
280324
flat_tensor = FlatTensor(
281325
version=0, # Keep in sync with c++ version number in serialize.h
282326
tensor_alignment=self.config.tensor_alignment,
283-
tensors=flat_tensor_metadata,
284-
segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
327+
tensors=tensors,
328+
segments=data_segments,
285329
named_data=[],
286330
)
287331

@@ -307,7 +351,7 @@ def serialize(
307351
flatbuffer_offset=padded_header_length,
308352
flatbuffer_size=len(flatbuffer_payload),
309353
segment_base_offset=segment_base_offset,
310-
segment_data_size=len(flat_tensor_data),
354+
segment_data_size=len(segment_data),
311355
).to_bytes()
312356

313357
# Pad header and payload to segment alignment.
@@ -327,15 +371,15 @@ def serialize(
327371
assert eh.flatbuffer_size == original_flatbuffer_payload_size
328372
assert eh.segment_base_offset == segment_base_offset
329373
assert eh.flatbuffer_offset == padded_header_length
330-
assert eh.segment_data_size == len(flat_tensor_data)
374+
assert eh.segment_data_size == len(segment_data)
331375

332376
del header_data
333377
del flatbuffer_payload
334378

335379
# Place everything into one segment.
336380
payload = Cord()
337381
payload.append(injected_flatbuffer_data)
338-
payload.append(flat_tensor_data)
382+
payload.append(segment_data)
339383

340384
return payload
341385

0 commit comments

Comments
 (0)