diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py
index 910bb8228e..2fd49b9e5d 100644
--- a/py/torch_tensorrt/fx/input_tensor_spec.py
+++ b/py/torch_tensorrt/fx/input_tensor_spec.py
@@ -7,12 +7,7 @@
 
 
 def generate_input_specs(inputs, lower_setting, additional_inputs=None):
-    # AIT lower setting doesn't have explicit_batch_dimension field and
-    # we just return None.
-    if not hasattr(lower_setting, "explicit_batch_dimension"):
-        return None
-
-    # dynamic_batch is TRT only flag. It does not exist in AIT lower setting
+    # dynamic_batch is a TRT-only flag.
     if (
         not lower_setting.explicit_batch_dimension
         or lower_setting.dynamic_batch is False
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index c99004585a..6d052fc34e 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -232,7 +232,7 @@ def __call__(
                 x.half() if x is not None and x.dtype == torch.float32 else x
                 for x in inputs
             )
-        pm = self.lower_pass_manager_builder.build_lower_pipeline(
+        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
             inputs, additional_inputs
         )
diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
index ee09da1ce5..98c6314f18 100644
--- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
+++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -121,7 +121,7 @@ def _split_pass(self) -> PassManager:
         )
         return PassManager.build_from_passlist(passes)
 
-    def _lower_pass(self) -> PassManager:
+    def _trt_lower_pass(self) -> PassManager:
         def lower_func(split_result: SplitResult) -> nn.Module:
             if (
                 hasattr(self.lower_setting, "explicit_batch_dimension")
@@ -169,7 +169,51 @@ def lower_func(split_result: SplitResult) -> nn.Module:
 
         return PassManager.build_from_passlist([lower_func])
 
-    def build_lower_pipeline(
+    def _default_lower_pass(self) -> PassManager:
+        def lower_func(split_result: SplitResult) -> nn.Module:
+
+            for submod_name, submod_inputs in split_result.submodule_inputs.items():
+                submod = getattr(split_result.split_module, submod_name)
+
+                LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
+
+                # Only acc submodules will be lowered.
+                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    print("Now lowering submodule", submod_name)
+                    lowering_start_time = datetime.datetime.now()
+
+                    lowered_module = self._lower_func(
+                        submod, submod_inputs, self.lower_setting, submod_name
+                    )
+                    setattr(split_result.split_module, submod_name, lowered_module)
+                    LOWER_SPLIT_POST_OBSERVER.observe(
+                        submod_name, lowered_module, submod_inputs
+                    )
+                    print(
+                        f"Lowering submodule {submod_name} elapsed time",
+                        datetime.datetime.now() - lowering_start_time,
+                    )
+
+            return split_result.split_module
+
+        return PassManager.build_from_passlist([lower_func])
+
+    def build_trt_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
+        self._input = input
+        self._additional_input = additional_input
+        passes = []
+
+        passes.append(self._const_fold_pass())
+        passes.append(self.graph_optimization_pass())
+        passes.append(self._split_pass())
+        passes.append(self._trt_lower_pass())
+
+        pm = PassManager.build_from_passlist(passes)
+        return pm
+
+    def build_default_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
         self._input = input
@@ -179,7 +223,7 @@ def build_lower_pipeline(
         passes.append(self._const_fold_pass())
         passes.append(self.graph_optimization_pass())
         passes.append(self._split_pass())
-        passes.append(self._lower_pass())
+        passes.append(self._default_lower_pass())
 
         pm = PassManager.build_from_passlist(passes)
         return pm
diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
index 8814bf4075..be6b6700a1 100644
--- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
+++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
@@ -528,7 +528,6 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     return cat_node
 
 
-@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
 @register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
 @register_acc_op_mapping(op_and_target=("call_method", "clamp"))
 @register_acc_op
@@ -1743,7 +1742,7 @@ def quantized_conv2d(
     dilation,
     groups,
     padding_mode,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.nn.quantized.functional.conv2d(
@@ -2041,7 +2040,7 @@ def quantized_batch_norm2d(
     weight,
     bias,
     eps,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.ops.quantized.batch_norm2d(
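Taken together, these changes split the old build_lower_pipeline entry point in two: build_trt_lower_pipeline keeps the TRT-specific _trt_lower_pass, while build_default_lower_pipeline runs the same const-fold/optimize/split sequence but finishes with the backend-agnostic _default_lower_pass. A minimal sketch of how a caller might pick between the two pipelines; the build_pipeline helper and the use_trt flag are hypothetical and not part of this diff, only the two build_* methods come from the patch:

    # Hypothetical helper; "builder" is assumed to be a LowerPassManagerBuilder
    # instance as modified in this diff.
    def build_pipeline(builder, inputs, additional_inputs=None, use_trt=True):
        if use_trt:
            # TRT path: const fold, graph optimization, split, then _trt_lower_pass
            return builder.build_trt_lower_pipeline(inputs, additional_inputs)
        # Default path: identical passes, but the final stage is _default_lower_pass
        return builder.build_default_lower_pipeline(inputs, additional_inputs)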