
Commit b9db0a3

pytorchbot and lucylq authored
[executorch][emit] Refactor _tensor_spec_to_evalue
Adding more logic to _tensor_spec_to_evalue in the next diff; simplifying it now, since otherwise the linter errors on complexity.

Differential Revision: D66847875 (https://our.internmc.facebook.com/intern/diff/D66847875/)
ghstack-source-id: 256981105
Pull Request resolved: #7233

Co-authored-by: lucylq <[email protected]>
1 parent 06e85a8 commit b9db0a3

1 file changed: +61 -42 lines changed

exir/emit/_emitter.py (+61 -42)
@@ -48,6 +48,7 @@
 from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op
 from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node
 from executorch.exir.schema import (
+    AllocationDetails,
     BackendDelegate,
     BackendDelegateDataReference,
     BackendDelegateInlineData,
@@ -328,6 +329,59 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
                 ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}"
             )

+    def _get_allocation_info(self, spec: TensorSpec) -> AllocationDetails:
+        """Returns the allocation info for a given TensorSpec."""
+        self._internal_assert_emitter(
+            isinstance(spec.mem_id, int) and spec.mem_id >= 0,
+            self.node,
+            f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}",
+        )
+
+        self._internal_assert_emitter(
+            isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
+            self.node,
+            f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}",
+        )
+        try:
+            allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset)
+        except AddressSpaceOverflowException as e:
+            raise InternalError(
+                self._emit_node_specific_error(
+                    self.node,
+                    (
+                        f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, "
+                        f"such as ConstraintBasedSymShapeEvalPass, the cause may be an "
+                        f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) "
+                        "during torch.export()."
+                    ),
+                )
+            )
+        return allocation_info
+
+    def _save_new_const_tensor(
+        self,
+        spec: TensorSpec,
+        buffer_data: bytes,
+        hashed: str,
+        allocation_info: Optional[AllocationDetails],
+    ) -> int:
+        """Saves a new constant tensor to the constant buffer and returns the buffer idx."""
+        self.program_state.allocated_specs.append(spec)
+
+        # Update buffer_idx to point to the end of the list where we are adding the new
+        # buffer. The first buffer location is reserved, so indices here start at 1.
+        buffer = Buffer(storage=buffer_data)
+        if allocation_info:
+            buffer_idx = len(self.program_state.mutable_buffer)
+            self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
+            self.program_state.mutable_buffer.append(buffer)
+        else:
+            buffer_idx = len(self.program_state.constant_buffer)
+            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
+            self.program_state.constant_buffer.append(buffer)
+        return buffer_idx
+
     def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
         """Constructs an EValue from the given TensorSpec."""

@@ -339,35 +393,12 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
         # default algos to set offsets, so need to check both.
         if spec.mem_id is not None and spec.mem_offset is not None:
             # Tensor is an activation.
-            self._internal_assert_emitter(
-                isinstance(spec.mem_id, int) and spec.mem_id >= 0,
-                self.node,
-                f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}",
-            )
-
-            self._internal_assert_emitter(
-                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
-                self.node,
-                f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}",
-            )
-            try:
-                allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset)
-            except AddressSpaceOverflowException as e:
-                raise InternalError(
-                    self._emit_node_specific_error(
-                        self.node,
-                        (
-                            f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, "
-                            f"such as ConstraintBasedSymShapeEvalPass, this may be the cause of an "
-                            f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) "
-                            "during torch.export()."
-                        ),
-                    )
-                )
+            allocation_info = self._get_allocation_info(spec)

+        # Tensor is either a constant tensor, or a mutable tensor with an initial state.
         if spec.const:
             # Tensor with a blob we need to serialize. May not actually be constant at runtime
-            # if it's a weight with an associated gradient
+            # if it's a weight with an associated gradient.
             spec_array_type = (
                 ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
             )
@@ -392,23 +423,11 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
         else:
             buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)

-        # Haven't seen this constant before
+        # Haven't seen this constant before.
         if buffer_idx == -1:
-            # Update buffer_idx to point to the end of the list where we are adding the new buffer.
-            buffer = Buffer(storage=buffer_data)
-            self.program_state.allocated_specs.append(spec)
-            # +1 because the first buffer location is reserved
-
-            if allocation_info:
-                buffer_idx = len(self.program_state.mutable_buffer)
-                self.program_state.cached_spec_mutable_hash_values[hashed] = (
-                    buffer_idx
-                )
-                self.program_state.mutable_buffer.append(buffer)
-            else:
-                buffer_idx = len(self.program_state.constant_buffer)
-                self.program_state.cached_spec_hash_values[hashed] = buffer_idx
-                self.program_state.constant_buffer.append(buffer)
+            buffer_idx = self._save_new_const_tensor(
+                spec, buffer_data, hashed, allocation_info
+            )

         if spec.const and spec.nbytes() != len(buffer_data):
             raise InternalError(
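
The caches touched by _save_new_const_tensor implement a content-addressed store: serialize the tensor, hash the bytes, reuse the cached buffer index on a hit, and append a new Buffer on a miss. Below is a minimal standalone sketch of that pattern, with hypothetical names (BufferPool is not an ExecuTorch class, and the real emitter keeps separate mutable and constant pools with separate hash caches):

    import hashlib
    from typing import Dict, List


    class BufferPool:
        """Content-addressed buffer store: identical payloads share one slot.

        Slot 0 is reserved, mirroring the emitter's reserved first buffer
        location, so real payloads land at index 1 and up.
        """

        def __init__(self) -> None:
            self._buffers: List[bytes] = [b""]  # index 0 reserved
            self._cache: Dict[str, int] = {}  # content hash -> buffer index

        def add(self, data: bytes) -> int:
            hashed = hashlib.sha256(data).hexdigest()
            idx = self._cache.get(hashed, -1)
            if idx == -1:  # haven't seen this payload before
                idx = len(self._buffers)
                self._cache[hashed] = idx
                self._buffers.append(data)
            return idx


    pool = BufferPool()
    first = pool.add(b"\x00" * 16)
    second = pool.add(b"\x00" * 16)  # duplicate payload reuses the same slot
    assert first == second == 1

Keeping mutable and constant tensors in separate pools, as the refactored helper does, means two tensors only ever share a slot within the same pool, so a mutable tensor's initial state is never aliased to a read-only constant.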

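As for the overflow path in _get_allocation_info: the hint in the error message suggests that the serialized allocation metadata cannot encode arbitrarily large offsets, so planned offsets inflated by a lazily-bounded SymInt (upper bound near 2^64-1) can overflow. The sketch below illustrates the idea under the assumption that the 64-bit offset is stored as two uint32 halves; the field names are assumptions, not the real AllocationDetails definition in exir/schema.py:

    UINT32_MAX = (1 << 32) - 1


    class AddressSpaceOverflowException(Exception):
        """Raised when a memory offset cannot be encoded in the program schema."""


    def make_allocation_info_sketch(mem_id: int, mem_offset: int) -> dict:
        # Assumption: offsets are split into low/high uint32 halves, so anything
        # that needs more than 64 bits cannot be represented. Summed upper-bound
        # sizes from an unbacked SymInt can land here during torch.export().
        high = mem_offset >> 32
        if high > UINT32_MAX:
            raise AddressSpaceOverflowException(
                f"mem_offset {mem_offset} does not fit in two uint32 fields"
            )
        return {
            "memory_id": mem_id,
            "memory_offset_low": mem_offset & UINT32_MAX,
            "memory_offset_high": high,
        }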