From 7341c6db21ff6003bbe058a89b60cce8c7ca60a6 Mon Sep 17 00:00:00 2001
From: Wei Wei
Date: Sun, 24 Jul 2022 23:05:41 -0700
Subject: [PATCH] Changes done internally at Facebook

6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati Test dynamic shape support for acc_ops.prod
c822345d6d673e1653c2208435e34ab400bada3d Jason Park Add support for generic torch ops to be used in training.
e5758602a0592d6c2b71d6d66a0398c4dd9b5e20 Shreyansh Prajapati Test dynamic shape support for repeat interleave
c13c633f04df162500eed477c0569eb2b81eb070 Shreyansh Prajapati Test dynamic shape support for reduce ops
863476cf43b210922b88585b8f196dd84fbebb56 Shreyansh Prajapati Test dynamic shape support for acc_op.convolution
68dff39793e5c30c20010919a855bb3d984015d7 Ruichao Xiao [fbcode][GPU][DHEN] fuse split squeeze cat as reshape
f8b920769507ebd2ff02419b4aece25451298a95 Ruichao Xiao [fbcode][DHEN][GPU] reorder and merge cats whose input is a sublist of another cat
5b6a8d2d6be979983a52ac96225fefb510c3817c Andrew Or [Quant][fx] Rename convert_to_reference to convert_to_reference_fx
996a0e080b8a8bc0b292a7c2ac92f41f6db33a2e Shreyansh Prajapati Test dynamic shape support for acc_op.expand
084631fe74b304fbb9481ca15fd452a3714fb1b8 Shreyansh Prajapati Test dynamic shape support for acc_op.to_dtype
b3195e76329ccddbb5c4640cfa884d0e457d2d34 Shreyansh Prajapati Test dynamic shape support for std
a5d964e62bdf769cf8c2e67321138b33e1f524a7 Shreyansh Prajapati Test dynamic shape support for acc_op.tile
3d33d45b2fc7f10f25c22946ba474b227e4b6529 Shreyansh Prajapati Test dynamic shape support for squeeze
09085abf63d7e7732e2cd66e600e8afc6d58964f Shreyansh Prajapati Test dynamic shape support for acc_op.topk
65edc7ea12899e9bd2af42c890a64de853d9b7fe Huamin Li temporarily skip gelu tests
d11e521f9b90554ca86912a49920afa4406bb40d Shirong Wu Suppress accuracy check for remove_reshape_with_batch_size_change
6d948298b2327d229e010a34f1c221b11d2eb504 Ankur Singla [GPULowering] Suppress accuracy check for fuse_unsqueeze_cat_sum
e780b647fc9571b77d9f41c963041a6ac3d66f33 Janet Yang Lower xrayvideo2022 to fx2trt
433c7207fef16b1fdff985546ea969c39fa83e7c generatedunixname89002005287564 [Codemod][Remove @noautodeps and @autodeps-skip tags] deeplearning/trt 1/2
66fdb65cffa925660c77b4758388399db3cbfe48 Scott Wolchok [fx2ait] Minor Python cleanup in acc_ops_getitem
188132ecb2c19bcbf83cb2dc381f6e3798629f87 generatedunixname89002005324833 [AutoAccept][Codemod][FBSourceBuckFormatLinter] Daily `arc lint --take BUCKFORMAT`
4536bae4686dd01f2149541ea7fb330e178a4969 Wei Wei [fx2trt] support sub
064602e666f86c110d931cd90a8536112a19b4ad Shreyansh Prajapati Test dynamic shape support for acc_ops.interpolate
9dfd0ee0cecb1975e3f53c44de237d67ca443ec5 Shreyansh Prajapati Test dynamic shape support for unary_ops
39b9efad8d5d82463a2016d135c0cf277de1c3c6 Shreyansh Prajapati Test dynamic shape support for unsqueeze
2bb17667d1dabc95391950426fc1f921eb3d0959 Shreyansh Prajapati Test dynamic shape support for acc_ops.split
64dfb7b096686cb2fd33197340dc72f30d525456 Shirong Wu Group LN trt plugin
438f670e28df59b0734baa092a514fba3d75eb4f Shreyansh Prajapati Test dynamic shape support for acc_ops.avgpool
df0fe32dae4343827bd9b37b72daae761b02f228 Shreyansh Prajapati Test dynamic shape support for acc_ops masked fill
44fe735d3493ea2d05a56b49093e4a23dd63a98e Shreyansh Prajapati Test dynamic shape support for acc_ops.pad
4f931acca706d8ce79045ceafef2ea0486609149 Wei Wei [fx2trt] torch.max dynamic shape test
bf6f6cbe217d26a95ca9122574adf7de3966db9e Shreyansh Prajapati Change the name of the test from full_reduce to dim_reduce
1c5680ed107d9206f3514eff4069a3f6c870ba8c Shreyansh Prajapati Test dynamic shape support for acc_ops.type_as
33e4c175a4f5fec78ac0b1c8eb262ca777c7aaba Shreyansh Prajapati Test dynamic shape support for acc_ops.min
f37be34bcef9716080b8bafbd1f4ad72e412c44c Wei Wei [fx2trt] plugin for grid_sample
57b5cc6a0f4839686ae360361a3a13b424794ee7 generatedunixname89002005367269 [AutoAccept][Codemod][FBSourceBlackLinter] Daily `arc lint --take BLACK`
eb741cc5e5a7babdc94e72d411670905f54da3e0 Shreyansh Prajapati Updated the dynamic shape support for narrow op
521c36b96a14741ae89d7af6cbb658120bcec2ea Shreyansh Prajapati Removing the comment for 4 dims dynamic shape support after analysis
e947343375967fe9efb0a16fdb9f63bff1449328 Shreyansh Prajapati Updated the pad test for dynamic batch for analysis
3d64087014e91bc301a315eae43683b1aa2b66bc Oleg Khabinov [trt_bc] Some improvements
dfd937a56fa01aca88a89b46176befdac4c202c4 Shreyansh Prajapati Updated the test for as_strided op for analysis
11d76d0420dcaa4bb8890dcdeb86b6e534af831c Bangsheng Tang [gpu][infer] replace fx2trt_layer_norm with fbgemm layer_norm
932046ff6ea6dead114c0222b23ca3854690cffa Wei Wei [fx2trt] bridge the dynamic batch and fixed shape
f911463393d8a671cfee6de6d1b5ef4d4f3991a6 Shirong Wu group swish LN plugin
ea65970f23dd7a468e5bc43240f2a9bfa07c9b3b Shirong Wu Create backend specific lower pass
38183e4a724e5514db2be7193cf4897b59759252 Alex Beloi [fx] run acc_linter.lint in acc_tracer.trace
d5e749f9bef8157f33fa36ce59b7e1693fdff942 Wei Wei "(uncommitted/untracked changes)"
292bba27ebe69c1d3e05f6a3130c810035508118 Wei Wei [self] kefei test
9a26bab1bb87a3895613e6de4175537ac1ec1447 Wei Wei [self] test kefei 2
6656b13fccf5ae24a167144896b015e6b8c9137d wwei6 [self] modify mts benchmark
731e93868617ca9521f85a5cc37cdb47fb4ca0bc wwei6 verify on benchmark
---
 py/torch_tensorrt/fx/input_tensor_spec.py          |  7 +--
 py/torch_tensorrt/fx/lower.py                      |  2 +-
 .../fx/passes/lower_pass_manager_builder.py        | 50 +++++++++++++++++--
 .../fx/tracer/acc_tracer/acc_ops.py                |  5 +-
 4 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py
index 910bb8228e..2fd49b9e5d 100644
--- a/py/torch_tensorrt/fx/input_tensor_spec.py
+++ b/py/torch_tensorrt/fx/input_tensor_spec.py
@@ -7,12 +7,7 @@
 
 
 def generate_input_specs(inputs, lower_setting, additional_inputs=None):
-    # AIT lower setting doesn't have explicit_batch_dimension field and
-    # we just return None.
-    if not hasattr(lower_setting, "explicit_batch_dimension"):
-        return None
-
-    # dynamic_batch is TRT only flag. It does not exist in AIT lower setting
+    # dynamic_batch is a TRT-only flag.
     if (
         not lower_setting.explicit_batch_dimension
         or lower_setting.dynamic_batch is False
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index c99004585a..6d052fc34e 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -232,7 +232,7 @@ def __call__(
             x.half() if x is not None and x.dtype == torch.float32 else x
             for x in inputs
         )
-        pm = self.lower_pass_manager_builder.build_lower_pipeline(
+        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
             inputs, additional_inputs
         )
 
diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
index ee09da1ce5..98c6314f18 100644
--- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
+++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -121,7 +121,7 @@ def _split_pass(self) -> PassManager:
         )
         return PassManager.build_from_passlist(passes)
 
-    def _lower_pass(self) -> PassManager:
+    def _trt_lower_pass(self) -> PassManager:
         def lower_func(split_result: SplitResult) -> nn.Module:
             if (
                 hasattr(self.lower_setting, "explicit_batch_dimension")
@@ -169,7 +169,51 @@ def lower_func(split_result: SplitResult) -> nn.Module:
 
         return PassManager.build_from_passlist([lower_func])
 
-    def build_lower_pipeline(
+    def _default_lower_pass(self) -> PassManager:
+        def lower_func(split_result: SplitResult) -> nn.Module:
+
+            for submod_name, submod_inputs in split_result.submodule_inputs.items():
+                submod = getattr(split_result.split_module, submod_name)
+
+                LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
+
+                # Only acc submodules will be lowered.
+                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    print("Now lowering submodule", submod_name)
+                    lowering_start_time = datetime.datetime.now()
+
+                    lowered_module = self._lower_func(
+                        submod, submod_inputs, self.lower_setting, submod_name
+                    )
+                    setattr(split_result.split_module, submod_name, lowered_module)
+                    LOWER_SPLIT_POST_OBSERVER.observe(
+                        submod_name, lowered_module, submod_inputs
+                    )
+                    print(
+                        f"Lowering submodule {submod_name} elapsed time",
+                        datetime.datetime.now() - lowering_start_time,
+                    )
+
+            return split_result.split_module
+
+        return PassManager.build_from_passlist([lower_func])
+
+    def build_trt_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
+        self._input = input
+        self._additional_input = additional_input
+        passes = []
+
+        passes.append(self._const_fold_pass())
+        passes.append(self.graph_optimization_pass())
+        passes.append(self._split_pass())
+        passes.append(self._trt_lower_pass())
+
+        pm = PassManager.build_from_passlist(passes)
+        return pm
+
+    def build_default_lower_pipeline(
         self, input: Input, additional_input: Optional[Input] = None
     ) -> PassManager:
         self._input = input
@@ -179,7 +223,7 @@ def build_lower_pipeline(
         passes.append(self._const_fold_pass())
         passes.append(self.graph_optimization_pass())
         passes.append(self._split_pass())
-        passes.append(self._lower_pass())
+        passes.append(self._default_lower_pass())
 
         pm = PassManager.build_from_passlist(passes)
         return pm
diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
index 8814bf4075..be6b6700a1 100644
--- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
+++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
@@ -528,7 +528,6 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     return cat_node
 
 
-@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
 @register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
 @register_acc_op_mapping(op_and_target=("call_method", "clamp"))
 @register_acc_op
@@ -1743,7 +1742,7 @@ def quantized_conv2d(
     dilation,
     groups,
     padding_mode,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.nn.quantized.functional.conv2d(
@@ -2041,7 +2040,7 @@ def quantized_batch_norm2d(
     weight,
     bias,
     eps,
-    acc_out_ty,
+    acc_out_ty=None,
 ):
     qparams = acc_out_ty.qparams
     return torch.ops.quantized.batch_norm2d(
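
---
Usage sketch: this patch splits the old `build_lower_pipeline` into a
TRT-specific variant (`build_trt_lower_pipeline`, wired to
`_trt_lower_pass`) and a backend-agnostic variant
(`build_default_lower_pipeline`, wired to the new `_default_lower_pass`,
which lowers only the acc-split submodules). A minimal sketch of how a
caller might choose between them, assuming a pre-configured
`LowerPassManagerBuilder`; the names `builder`, `module`, `inputs`,
`additional_inputs`, and `use_trt` are illustrative, not part of the
patch:

    # Sketch only: `builder` is an already-constructed
    # LowerPassManagerBuilder with its lower_setting and pass functions.
    if use_trt:
        # TRT path: _trt_lower_pass() consults TRT-specific settings such
        # as explicit_batch_dimension when lowering submodules.
        pm = builder.build_trt_lower_pipeline(inputs, additional_inputs)
    else:
        # Generic path: _default_lower_pass() lowers every submodule whose
        # name lacks the non-acc prefix from the split result.
        pm = builder.build_default_lower_pipeline(inputs, additional_inputs)

    # Both builders return a PassManager over the same stages
    # (const fold -> graph optimization -> split -> lower); applying it
    # to the traced module yields the lowered module.
    lowered = pm(module)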
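A second sketch, on the `acc_out_ty=None` defaults added to
`quantized_conv2d` and `quantized_batch_norm2d`: the default makes the
parameter syntactically optional, but both bodies still read
`acc_out_ty.qparams` unconditionally, so callers must still supply it.
A self-contained illustration of that pattern (`quantized_op_sketch` and
`TensorMeta` are hypothetical stand-ins, not code from this patch):

    from collections import namedtuple

    # Stand-in for the acc_out_ty metadata carrying quantization params.
    TensorMeta = namedtuple("TensorMeta", ["qparams"])

    def quantized_op_sketch(x, acc_out_ty=None):
        # Mirrors the pattern in this diff: optional in the signature,
        # but dereferenced unconditionally in the body.
        qparams = acc_out_ty.qparams  # AttributeError if acc_out_ty is None
        return x, qparams

    quantized_op_sketch(1.0, TensorMeta(qparams={"scale": 0.1}))  # ok
    # quantized_op_sketch(1.0)  # would raise AttributeError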