[FX] Changes done internally at Facebook #1178

Merged: 1 commit, merged on Jul 13, 2022
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/__init__.py
@@ -6,5 +6,6 @@
     tensorrt_converter,
 )
 from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
-from .input_tensor_spec import InputTensorSpec  # noqa
+from .input_tensor_spec import generate_input_specs, InputTensorSpec  # noqa
 from .lower_setting import LowerSetting  # noqa
 from .trt_module import TRTModule  # noqa
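
With this re-export, generate_input_specs can be imported straight from the package root, as the updated tests below do; a one-line sketch:

```python
from torch_tensorrt.fx import generate_input_specs, InputTensorSpec
```
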
76 changes: 71 additions & 5 deletions py/torch_tensorrt/fx/input_tensor_spec.py
@@ -1,11 +1,66 @@
-from typing import Iterable, List, NamedTuple, Sequence, Tuple
+from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple

 import torch

 from .types import Shape, ShapeRange
 from .utils import get_dynamic_dims


+def generate_input_specs(
+    inputs, lower_setting, additional_inputs=None, fixed_shape=False
+):
+    # The AIT lower setting doesn't have an explicit_batch_dimension field,
+    # so we just return None.
+    if not hasattr(lower_setting, "explicit_batch_dimension"):
+        return None
+
+    if not lower_setting.explicit_batch_dimension or fixed_shape:
+        return InputTensorSpec.from_tensors(inputs)
+
+    # If we don't have additional inputs, we assume the first dimension
+    # is the dynamic batch dimension. Otherwise, we use the additional
+    # inputs to determine the batch dimension.
+    if additional_inputs is None:
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+        )
+    else:
+        batch_dims = []
+
+        for i, j in zip(inputs, additional_inputs):
+            found_batch_dim = False
+
+            for idx, values in enumerate(zip(i.shape, j.shape)):
+                if values[0] != values[1]:
+                    assert (
+                        found_batch_dim is False
+                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    batch_dims.append(idx)
+                    found_batch_dim = True
+
+            if not found_batch_dim:
+                raise RuntimeError(
+                    f"Failed to find batch dimension because shapes are the same, {i.shape}"
+                )
+
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+            batch_dims,
+        )
+
+
 class InputTensorSpec(NamedTuple):
     """
     This class contains the information of an input tensor.
@@ -70,6 +125,7 @@ def from_tensors_with_dynamic_batch_size(
         tensors: Sequence[torch.Tensor],
         batch_size_range: Tuple[int, int, int],
         opt_profile_replica: int = 1,
+        batch_dims: Optional[List[int]] = None,
     ) -> List["InputTensorSpec"]:
         """
         Produce a list of InputTensorSpec named tuples which would contain
@@ -83,20 +139,30 @@ def from_tensors_with_dynamic_batch_size(
             the smallest batch size allowed. The second integer indicates
             the batch size that we'll optimize for. The third integer indicates
             the largest batch size allowed.
+            opt_profile_replica (int): If dynamic shape is enabled, each execution
+                context requires a different optimization profile. This arg determines
+                how many optimization profile replicas we want to produce.
+            batch_dims (Optional[List[int]]): The batch dim might not be the leading
+                dim, so this arg lets the user specify the batch dims. By default,
+                we treat dim 0 as the batch dim.

         Returns:
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
+        if batch_dims is None:
+            batch_dims = [0] * len(tensors)
+
         input_specs = []
-        batch_size = tensors[0].size(0)
+        batch_size = tensors[0].size(batch_dims[0])

         for i, tensor in enumerate(tensors):
+            batch_dim = batch_dims[i]
             assert batch_size == tensor.size(
-                0
+                batch_dim
             ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
             shape = list(tensor.shape)
-            shape[0] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+            shape[batch_dim] = -1
+            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
             input_specs.append(
                 cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
             )
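
The batch-dim inference in generate_input_specs above works by comparing each input's shape against the corresponding additional input: the single dimension where the two sizes differ is taken as the batch dim, and fully identical shapes are an error. A minimal standalone sketch of that logic (illustrative only; find_batch_dims is not part of this PR):

```python
import torch

def find_batch_dims(inputs, additional_inputs):
    # For each input pair, the single dim whose size differs is the batch dim.
    batch_dims = []
    for a, b in zip(inputs, additional_inputs):
        diffs = [idx for idx, (x, y) in enumerate(zip(a.shape, b.shape)) if x != y]
        if len(diffs) != 1:
            raise RuntimeError(f"Expected exactly one differing dim, got {a.shape} vs {b.shape}")
        batch_dims.append(diffs[0])
    return batch_dims

# Shapes differ only in dim 1, so dim 1 is the inferred batch dim.
assert find_batch_dims([torch.randn(2, 4, 8)], [torch.randn(2, 2, 8)]) == [1]
```
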
38 changes: 9 additions & 29 deletions py/torch_tensorrt/fx/lower.py
@@ -1,6 +1,6 @@
 import dataclasses as dc
 import logging
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
@@ -10,15 +10,9 @@
 from torch.fx.passes.splitter_base import SplitResult

 from .fx2trt import TRTInterpreter, TRTInterpreterResult
-from .input_tensor_spec import InputTensorSpec
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
-from .passes.pass_utils import (
-    chain_passes,
-    decorate_method,
-    PassFunc,
-    validate_inference,
-)
+from .passes.pass_utils import decorate_method, PassFunc, validate_inference
 from .tools.timing_cache_utils import TimingCacheManager
 from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting

@@ -91,25 +85,8 @@ def create(cls, lower_setting):
         return LowerTrtInterpreter(lower_setting, timing_cache_manager)

     def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
-        input_specs_val = (
-            self.lower_setting.input_specs
-            if self.lower_setting.input_specs
-            else (
-                InputTensorSpec.from_tensors_with_dynamic_batch_size(
-                    input,
-                    (
-                        0,
-                        self.lower_setting.max_batch_size,
-                        self.lower_setting.max_batch_size,
-                    ),
-                    self.lower_setting.opt_profile_replica,
-                )
-                if self.lower_setting.explicit_batch_dimension
-                and self.lower_setting.dynamic_batch
-                else InputTensorSpec.from_tensors(input)
-            )
-        )
-        logger.info(f"{split_name=} {input_specs_val=}")
+        assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
+        logger.info(f"{split_name=} {self.lower_setting.input_specs=}")

         # Prepare algorithm selector and timing_cache for TRTInterpreter
         algo_selector = None
@@ -125,7 +102,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:

         interpreter = TRTInterpreter(
             mod,
-            input_specs=input_specs_val,
+            input_specs=self.lower_setting.input_specs,
             explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
             explicit_precision=self.lower_setting.explicit_precision,
             logger_level=trt.Logger.VERBOSE
@@ -242,6 +219,7 @@ def __call__(
         self,
         module: nn.Module,
         inputs: Input,
+        additional_inputs: Optional[Input] = None,
     ) -> nn.Module:
         module.eval()

@@ -254,7 +232,9 @@
             x.half() if x is not None and x.dtype == torch.float32 else x
             for x in inputs
         )
-        pm = self.lower_pass_manager_builder.build_lower_pipeline(inputs)
+        pm = self.lower_pass_manager_builder.build_lower_pipeline(
+            inputs, additional_inputs
+        )

         lower_result = pm(module)
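
The behavioral change in this file: LowerTrtInterpreter no longer derives input specs itself; it asserts that lower_setting.input_specs was populated upstream (by the lower pass manager, see the next file). Callers opt into dynamic batch handling by passing a second set of sample inputs whose shapes differ only along the batch dim. A hedged usage sketch (the entry point is the __call__ shown above; `lowerer` and `MyModel` are placeholders, not names from this diff):

```python
import torch

model = MyModel().cuda().eval()                       # hypothetical module
inputs = [torch.randn(4, 3, 224, 224, device="cuda")]
# Same shapes except the batch dim, so dim 0 is inferred as dynamic.
additional_inputs = [torch.randn(8, 3, 224, 224, device="cuda")]

lowered = lowerer(model, inputs, additional_inputs)   # lowerer: instance of the class above
```
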
31 changes: 28 additions & 3 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -1,11 +1,13 @@
 from functools import partial, wraps
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence

 import torch
 from torch import nn
 from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
 from torch.fx.passes.shape_prop import ShapeProp
-from torch.fx.passes.splitter_base import SplitResult
+from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
+
+from ..input_tensor_spec import generate_input_specs

 from ..lower_setting import LowerSetting
 from ..observer import Observer
@@ -120,13 +122,33 @@ def _split_pass(self) -> PassManager:

     def _lower_pass(self) -> PassManager:
         def lower_func(split_result: SplitResult) -> nn.Module:
+            if (
+                hasattr(self.lower_setting, "explicit_batch_dimension")
+                and self.lower_setting.explicit_batch_dimension
+                and self._additional_input
+            ):
+                additional_submodule_inputs = generate_inputs_for_submodules(
+                    split_result.split_module,
+                    self._additional_input,
+                    list(split_result.submodule_inputs.keys()),
+                )
+            else:
+                additional_submodule_inputs = None
+
             for submod_name, submod_inputs in split_result.submodule_inputs.items():
                 submod = getattr(split_result.split_module, submod_name)

                 LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)

                 # Only acc submodules will be lowered.
                 if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    self.lower_setting.input_specs = generate_input_specs(
+                        submod_inputs,
+                        self.lower_setting,
+                        additional_submodule_inputs[submod_name]
+                        if additional_submodule_inputs
+                        else None,
+                    )
                     lowered_module = self._lower_func(
                         submod, submod_inputs, self.lower_setting, submod_name
                     )
@@ -139,8 +161,11 @@ def lower_func(split_result: SplitResult) -> nn.Module:

         return PassManager.build_from_passlist([lower_func])

-    def build_lower_pipeline(self, input: Input) -> PassManager:
+    def build_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
         self._input = input
+        self._additional_input = additional_input
         passes = []

         passes.append(self._const_fold_pass())
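
In short: build_lower_pipeline now stashes the optional second set of sample inputs, and _lower_pass maps it onto the split submodules with generate_inputs_for_submodules so that generate_input_specs can infer per-submodule batch dims. A hedged sketch of the resulting call sequence (builder construction elided; `builder` and `module` are placeholders):

```python
# builder: a LowerPassManagerBuilder wired up with a LowerSetting, as above.
pm = builder.build_lower_pipeline(inputs, additional_inputs)
lowered_module = pm(module)  # each acc submodule gets its own generated input specs
```
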
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/passes/pass_utils.py
@@ -41,10 +41,13 @@ def _validate_inference(pass_: PassFunc) -> PassFunc:

     @wraps(pass_)
     def pass_with_validation(
-        module: fx.GraphModule, input: Input
+        module: fx.GraphModule,
+        input: Input,
+        *args,
+        **kwargs,
     ) -> fx.GraphModule:
         res0 = module(*input)
-        processed_module = pass_(module, input)
+        processed_module = pass_(module, input, *args, **kwargs)
         res1 = processed_module(*input)

         tensor_res_0 = _collect_tensors(res0)
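
The *args/**kwargs forwarding means a validated pass is no longer restricted to the (module, input) signature; extra arguments flow through the wrapper, which still runs inference before and after the pass and compares the results. A hedged sketch (my_pass and its extra argument are hypothetical):

```python
import torch.fx as fx

def my_pass(module: fx.GraphModule, input, extra_setting=None) -> fx.GraphModule:
    # ... transform the graph, possibly consulting extra_setting ...
    return module

checked_pass = _validate_inference(my_pass)  # the inner helper shown above
new_module = checked_pass(module, sample_inputs, extra_setting="fp16")
```
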
22 changes: 21 additions & 1 deletion py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec

 # NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples

@@ -93,6 +93,26 @@ def forward(self, x):
             test_implicit_batch_dim=False,
         )

+    def test_prod_all_dims_with_dynamic_shape(
+        self,
+        op=torch.prod,
+    ):
+        class Prod(torch.nn.Module):
+            def forward(self, x):
+                return op(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Prod(), input_specs, expected_ops={acc_ops.prod}
+        )
+

 if __name__ == "__main__":
     run_tests()
43 changes: 42 additions & 1 deletion py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py
@@ -4,7 +4,7 @@

 import torch
 from torch.testing._internal.common_utils import run_tests, TestCase
-from torch_tensorrt.fx import InputTensorSpec
+from torch_tensorrt.fx import generate_input_specs, InputTensorSpec, LowerSetting


 class TestTRTModule(TestCase):
@@ -47,6 +47,47 @@ def test_from_tensors_with_dynamic_batch_size(self):
             self.assertEqual(batch_size, shape[0])
             self.assertSequenceEqual(tensor.shape[1:], shape[1:])

+    def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
+        tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
+        batch_size_range = [2, 3, 4]
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            tensors, batch_size_range, batch_dims=[0, 1]
+        )
+        for i, spec_and_tensor in enumerate(zip(specs, tensors)):
+            spec, tensor = spec_and_tensor
+            self._validate_spec(spec, tensor, dynamic_dims=[i])
+
+            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
+                self.assertEqual(batch_size, shape[i])
+                tensor_shape = list(tensor.shape)
+                tensor_shape[i] = batch_size
+                self.assertSequenceEqual(tensor_shape, shape)
+
+    def test_generate_input_specs(self):
+        lower_setting = LowerSetting(
+            explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
+        )
+
+        # Implicit batch dim.
+        inputs = [torch.randn(1, 2, 3)]
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor)
+
+        # Explicit batch dim without additional inputs.
+        lower_setting.explicit_batch_dimension = True
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[0])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+
+        # Explicit batch dim with additional inputs.
+        additional_inputs = [torch.randn(1, 1, 3)]
+        specs = generate_input_specs(inputs, lower_setting, additional_inputs)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[1])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+

 if __name__ == "__main__":
     run_tests()
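
Both test files run their own suite under __main__, so they can be replayed directly from a source checkout; for example:

```sh
python py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py
```
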