Skip to content

Commit ad6fa22

Browse files
committed
fix: Error with aten.view across Tensor memory (#2464)
1 parent e79ca21 commit ad6fa22

File tree

3 files changed

+108
-0
lines changed

3 files changed

+108
-0
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1111
from .repair_input_as_output import repair_input_as_output
1212
from .replace_max_pool_with_indices import replace_max_pool_with_indices
13+
from .view_to_reshape import view_to_reshape
1314

1415
ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
1516
[
@@ -19,6 +20,7 @@
1920
lower_efficient_attention,
2021
fuse_prims_broadcast,
2122
replace_max_pool_with_indices,
23+
view_to_reshape,
2224
]
2325
)
2426

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import logging
2+
from typing import Callable, List, Sequence, Tuple
3+
4+
import torch
5+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
6+
clean_up_graph_after_modifications,
7+
)
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def view_to_reshape(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Rewrite ``aten.view`` calls in the graph as ``aten.reshape`` calls.

    The pattern/replacement pair comes from ``view_replacement`` and is applied
    with ``torch.fx.subgraph_rewriter.replace_pattern``. The intent, per the
    originating fix, is to avoid Tensor memory errors raised by ``aten.view``.

    Args:
        gm: Graph module to rewrite.
        sample_inputs: Not read by this pass; presumably accepted so the
            signature matches the other lowering passes (confirm against the
            pass manager).

    Returns:
        ``gm`` itself when nothing matched, otherwise the result of
        ``clean_up_graph_after_modifications(gm)``.
    """
    pattern, substitute = view_replacement()
    matches = torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, substitute)

    # A falsy result means no view node was rewritten, so there is nothing to
    # clean up or log.
    if not matches:
        return gm

    gm = clean_up_graph_after_modifications(gm)
    logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
    return gm
23+
24+
25+
def view_replacement() -> (
    Tuple[
        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
    ]
):
    """Construct the original and replacement functions for ``aten.view``.

    Both returned objects are plain Python callables with the signature
    ``(input, shape)``; neither is a ``torch.fx.GraphModule``. They are meant
    to be handed to ``torch.fx.subgraph_rewriter.replace_pattern`` as the
    pattern and the replacement respectively.

    Returns:
        A ``(orig, replacement)`` tuple where ``orig`` calls
        ``torch.ops.aten.view.default`` and ``replacement`` calls
        ``torch.ops.aten.reshape.default`` with the same arguments.
    """

    # Original graph: the op to be matched
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)

    # Replacement graph: same arguments, routed through reshape instead
    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.reshape.default(input, shape)

    return orig, replacement

tests/py/dynamo/lowering/test_aten_lowering_passes.py

+65
Original file line numberDiff line numberDiff line change
@@ -267,5 +267,70 @@ def forward(self, q, k, v):
267267
torch._dynamo.reset()
268268

269269

270+
class TestLowerViewToReshape(TestCase):
    """Checks that lowering replaces ``aten.view`` with ``aten.reshape``."""

    def test_view_to_reshape(self):
        """Lower a module made of one ``aten.view`` call and validate the result.

        First asserts on the ops reported by ``lower_graph_testing``: no
        ``aten.view.default`` seen and ``aten.reshape.default`` not missing.
        Then compiles the module with ``torch_tensorrt.compile`` and compares
        its output against the traced module on the same CUDA input.
        """

        class ViewToReshape(torch.nn.Module):
            def forward(self, input):
                out = torch.ops.aten.view.default(input, (1, 1, -1))
                return out

        inputs = [
            torch.rand((3, 4, 5, 32)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
        expected_ops = {torch.ops.aten.reshape.default}
        unexpected_ops = {
            torch.ops.aten.view.default,
        }

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        # assertEqual, not the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg="ViewToReshape TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()
333+
334+
270335
if __name__ == "__main__":
271336
run_tests()

0 commit comments

Comments
 (0)