Commit 715b280

digantdesai authored and kirklandsign committed
[ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity
Pull Request resolved: #8892
Differential Revision: [D70372220](https://our.internmc.facebook.com/intern/diff/D70372220/)
ghstack-source-id: 269599293
1 parent 09ad20a commit 715b280
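
For callers of the XNNPACK partitioner the rename is a one-keyword change: the behavior is unchanged, the new name just says what the flag does (keep fp32 linear weights non-static, i.e. supplied at runtime rather than owned by the delegate). A minimal migration sketch follows; the import path is an assumption based on typical ExecuTorch usage and is not part of this diff.

# Hypothetical call-site migration. Only the keyword rename comes from this
# commit; the import path is an assumption.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Before this commit:
#   partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=True)

# After this commit (same behavior, clearer name):
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)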

File tree

4 files changed, +35 -19 lines changed


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 25 additions & 13 deletions
@@ -97,9 +97,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -108,6 +108,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision
 
     def get_deps(
@@ -210,8 +211,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)
 
@@ -287,8 +291,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])
 
@@ -394,9 +401,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
 
@@ -482,11 +491,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users
 
-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )
 
@@ -512,8 +521,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
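All of the _get_weight_deps / _get_bias_deps hunks above reduce to the same guard. A standalone sketch of that pattern, with a stub enum so it runs outside ExecuTorch (the real methods live on GEMMConfig subclasses and do more validation than shown here):

# Self-contained sketch of the guard pattern shown in the hunks above.
# ConfigPrecisionType here is a stub standing in for the ExecuTorch enum.
from enum import Enum, auto
from typing import List, Tuple


class ConfigPrecisionType(Enum):
    FP32 = auto()
    STATIC_QUANT = auto()


def weight_deps(
    weight_node: object,
    precision: ConfigPrecisionType,
    force_non_static_weights_for_f32_linear: bool,
) -> Tuple[bool, List[object]]:
    if (
        precision == ConfigPrecisionType.FP32
        and force_non_static_weights_for_f32_linear
    ):
        # Flag set and the gemm is fp32: report no weight dependency, so the
        # weight node is not pulled into the delegate as a static constant.
        return (True, [])
    # Otherwise the weight node is a dependency the delegate should own.
    return (True, [weight_node])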
backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get(
+            "force_non_static_weights_for_f32_linear", False
+        )
 
     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
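
One practical note on the hunk above: the flag is read with kwargs.get(..., False), so only the new keyword is consulted here; a caller that still passes force_fp32_dynamic_linear falls back to the default through this code path (whether the partitioner warns about or rejects unknown kwargs elsewhere is not visible in this diff). A small sketch of that behavior:

# Simplified sketch of the kwargs handling shown above (not the real class).
class ConfigSketch:
    def __init__(self, **kwargs):
        # Same pattern as the diff: only the new keyword is consulted.
        self.force_non_static_weights_for_f32_linear = kwargs.get(
            "force_non_static_weights_for_f32_linear", False
        )


assert ConfigSketch(
    force_non_static_weights_for_f32_linear=True
).force_non_static_weights_for_f32_linear
# The old keyword no longer flips the flag through this code path.
assert not ConfigSketch(
    force_fp32_dynamic_linear=True
).force_non_static_weights_for_f32_linear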

backends/xnnpack/test/ops/test_linear.py

Lines changed: 2 additions & 2 deletions
@@ -874,7 +874,7 @@ def test_linear_qd8_as_fp32(self):
             },
         )
 
-    def test_linear_fp32_with_force_as_mm(self):
+    def test_linear_with_force_non_static_weights_for_f32_linear(self):
         def check_signature(
             signature: ExportGraphSignature,
             force_flag: bool,
@@ -907,7 +907,7 @@ def check_signature(
             inputs = module.get_inputs()
             tester = Tester(module, inputs).export()
             partitioner = XnnpackPartitioner(
-                force_fp32_dynamic_linear=force_flag
+                force_non_static_weights_for_f32_linear=force_flag
             )
             if legacy_mode:
                 tester.to_edge()

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 5 additions & 3 deletions
@@ -43,18 +43,20 @@ def test_fp32_lstm(self):
             .run_method_and_compare_outputs()
         )
 
-    def test_fp32_lstm_force_dynamic_linear(self):
+    def test_lstm_with_force_non_static_weights_for_f32_linear(self):
         (
             Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                    partitioners=[
+                        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
+                    ]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            # Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
             .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
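
The updated tests also double as end-to-end usage: export, lower with the flag set, then check that the linear weights stay outside the delegate. Below is a trimmed sketch of that flow for a plain fp32 linear module; the import paths and checks are modeled on the LSTM test above and are assumptions, not guarantees from this diff.

# Sketch modeled on the LSTM test above; import paths are assumptions.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.test.tester import Tester, ToEdgeTransformAndLower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x):
        return self.linear(x)


(
    Tester(TinyLinear(), (torch.rand(1, 32),))
    .export()
    .to_edge_transform_and_lower(
        ToEdgeTransformAndLower(
            partitioners=[
                XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
            ]
        )
    )
    # As in the LSTM test: no fused static-weight addmm op should remain once
    # the flag forces non-static fp32 linear weights.
    .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
    .to_executorch()
    .serialize()
    .run_method_and_compare_outputs()
)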

0 commit comments