
Commit 7f88494

fix: Error with aten.view across Tensor memory
- Address an error where `aten.view` is called on TRT output Tensors, which can be in a different memory format than Torch expects
- Specifically, TRT can rearrange tensor memory to optimize certain layers, but Torch's view operator depends on the tensor's specific stride and contiguity configuration, which can be violated at runtime (though not at compile time, since Torch itself would handle these configurations correctly)
- Add a custom lowering pass to replace `view` with `reshape`, avoiding this issue; reshape makes a copy of the underlying Tensor when necessary
- Torch-TRT's `aten.view` implementation is the same as that for `aten.reshape`, and the two ops share a schema, so no changes are needed on the converter side
- Add a test case to validate the new lowering pass
1 parent 4985c70 commit 7f88494
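
The view/reshape distinction described in the commit message can be reproduced in eager mode: `view` requires the tensor's strides to be compatible with the requested shape, while `reshape` falls back to copying the data when they are not. A minimal illustration (not part of this commit):

import torch

x = torch.rand(3, 4).t()  # transposing yields a non-contiguous tensor
try:
    x.view(-1)  # raises: view is incompatible with these strides
except RuntimeError as err:
    print(f"view failed: {err}")

y = x.reshape(-1)  # succeeds by copying the underlying data
print(y.shape)  # torch.Size([12])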

File tree

3 files changed (+110, -1 lines)

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py (+2)

@@ -11,6 +11,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape

 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -21,6 +22,7 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )
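
Each entry in ATEN_LOWERING_PASSES is applied in list order, and every pass shares the signature used by `view_to_reshape` in the new file below: it takes a `torch.fx.GraphModule` plus sample inputs and returns the (possibly modified) module. A minimal sketch of a conforming pass, hypothetical and for illustration only:

from typing import Sequence

import torch

def print_graph_pass(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    # Hypothetical no-op pass: inspect each node without modifying the graph
    for node in gm.graph.nodes:
        print(node.op, node.target)
    return gm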

py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py (new file, +41)

@@ -0,0 +1,41 @@
+import logging
+from typing import Callable, List, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def view_to_reshape(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
+    orig, replacement = view_replacement()
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    return gm
+
+
+def view_replacement() -> (
+    Tuple[
+        torch.fx.GraphModule,
+        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+    ]
+):
+    """Constructs the original and replacement functions for view"""
+
+    # Original graph
+    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.view.default(input, shape)
+
+    # Replacement graph
+    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.reshape.default(input, shape)
+
+    return orig, replacement
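
The pass is built on `torch.fx.subgraph_rewriter.replace_pattern`, which traces both callables, finds occurrences of the pattern graph inside `gm`, and splices in the replacement; placeholders in the pattern (here `input` and `shape`) act as wildcards. A standalone sketch of the same mechanism on a toy module (not part of this commit):

import torch
from torch.fx import subgraph_rewriter, symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.view.default(x, (1, -1))

def pattern(input, shape):
    return torch.ops.aten.view.default(input, shape)

def replacement(input, shape):
    return torch.ops.aten.reshape.default(input, shape)

gm = symbolic_trace(Toy())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(len(matches))  # 1: the single view call was rewritten
print(gm.code)       # now calls torch.ops.aten.reshape.default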

tests/py/dynamo/lowering/test_aten_lowering_passes.py (+67, -1)

@@ -1,7 +1,8 @@
 import torch
-import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests

+import torch_tensorrt
+
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


@@ -375,5 +376,70 @@ def forward(self, input, weight, bias):
     torch._dynamo.reset()


+class TestLowerViewToReshape(TestCase):
+    def test_view_to_reshape(self):
+        class ViewToReshape(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.view.default(input, (1, 1, -1))
+                return out
+
+        inputs = [
+            torch.rand((3, 4, 5, 32)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
+        expected_ops = {torch.ops.aten.reshape.default}
+        unexpected_ops = {
+            torch.ops.aten.view.default,
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg="ViewToReshape TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
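
The numeric check at the end of the test reduces both models' outputs to a single worst-case scalar and relies on `unittest`'s `assertAlmostEqual` rounding semantics: the assertion passes when `round(max_diff, places) == 0`. A sketch of the idiom with stand-in tensors (the names and tolerance below are illustrative, not from the commit):

import torch

decimals_of_agreement = 4  # stand-in for the suite's DECIMALS_OF_AGREEMENT constant
reference = torch.rand(3, 4)
candidate = reference + 1e-6  # simulate small numeric drift between backends

max_diff = float(torch.max(torch.abs(reference - candidate)))
# Equivalent to assertAlmostEqual(max_diff, 0, decimals_of_agreement)
assert round(max_diff, decimals_of_agreement) == 0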
