@@ -97,9 +97,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -108,6 +108,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(
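The fallback above only fires when FP32 is the sole enabled precision. A minimal standalone sketch of that decision, using an assumed three-value enum instead of the real ConfigPrecisionType import, shows the resulting (overwritten, precision) pairs:

from enum import Enum

class Precision(Enum):
    FP32 = 0
    STATIC_QUANT = 1
    DYNAMIC_QUANT = 2

def overwrite_precision(detected, enabled):
    # Detected precision is not enabled: when only fp32 is enabled, a
    # quantized gemm can still be partitioned as fp32.
    if detected not in enabled and enabled == [Precision.FP32]:
        if detected in (Precision.STATIC_QUANT, Precision.DYNAMIC_QUANT):
            return True, Precision.FP32
    return False, detected

# A dynamically quantized gemm in a graph where only fp32 is enabled is
# overwritten and partitioned as fp32:
assert overwrite_precision(Precision.DYNAMIC_QUANT, [Precision.FP32]) == (True, Precision.FP32)
# A static-quant gemm with static quant enabled keeps its detected precision:
assert overwrite_precision(Precision.STATIC_QUANT, [Precision.STATIC_QUANT]) == (False, Precision.STATIC_QUANT)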
@@ -210,8 +211,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -287,8 +291,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])

@@ -394,9 +401,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

@@ -482,11 +491,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )

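The intersection-vs-union choice above is what lets the flag shrink the partition: with the flag on, only deps that already sit inside the source partition survive, so nodes the earlier _get_weight_deps overrides left out (the weight) stay out; with the flag off, the source partition nodes are greedily added. A tiny sketch with hypothetical node names (plain strings, not torch.fx nodes):

deps = {"input", "weight_dequant"}
src_partition_nodes = {"linear", "weight", "weight_dequant", "input"}

# flag on: get_deps overwrites the source partition, keep only the overlap
restricted = deps & src_partition_nodes   # {"input", "weight_dequant"}

# flag off: be greedy and take the union of both
greedy = deps | src_partition_nodes       # all four nodes

assert "weight" not in restricted and "weight" in greedy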
@@ -512,8 +521,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

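For callers, the visible change is the renamed flag. A hedged usage sketch, assuming XnnpackPartitioner still forwards its keyword arguments down to these gemm configs (as the __init__(self, **kwargs) signatures in the hunk headers suggest):

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# previously: XnnpackPartitioner(force_fp32_dynamic_linear=True)
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)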