Commit 1359e74
chore: address flaky test failures related to global partitioning (#3369)
1 parent 43831dc commit 1359e74

27 files changed (+155, -113 lines)

docs/_downloads/0e30a6276601af7e5fc4d5166e2e3d37/torch_compile_advanced_usage.py (+2 -1)

@@ -4,7 +4,8 @@
 Torch Compile Advanced Usage
 ======================================================

-This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/2a9ac10f2667047a7f398d1593b7ca33/torch_export_gpt2.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling GPT2 using the dynamo backend
 ==========================================================

-This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/3d4d74f6636d986f33167154f6553961/torch_export_cudagraphs.py (+2 -1)

@@ -4,7 +4,8 @@
 Torch Export with Cudagraphs
 ======================================================

-This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well."""
+This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling ResNet using the Torch-TensorRT Dyanmo Frontend
 ==========================================================

-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/7b7004dc2ea6f839be532665e16e0426/torch_export_llama2.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling Llama2 using the dynamo backend
 ==========================================================

-This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/d6e1bb6ec5f884994554d9d12e37a0f6/torch_compile_resnet_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling ResNet with dynamic shapes using the `torch.compile` backend
 ==========================================================

-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/dfa60e8f9850fd7761f3e7da81304d32/torch_compile_transformers_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling BERT using the `torch.compile` backend
 ==============================================================

-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py (+2 -1)

@@ -4,7 +4,8 @@
 Dynamo Compile Advanced Usage
 ======================================================

-This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API.
+"""

 # %%
 # Imports and Model Definition

docs/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling a Transformer using torch.compile and TensorRT
 ==============================================================

-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model.
+"""

 # %%
 # Imports and Model Definition

docs/v1.4.0/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling ResNet using the Torch-TensorRT Dyanmo Frontend
 ==========================================================

-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model.
+"""

 # %%
 # Imports and Model Definition

docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py (+2 -1)

@@ -4,7 +4,8 @@
 Dynamo Compile Advanced Usage
 ======================================================

-This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API.
+"""

 # %%
 # Imports and Model Definition

docs/v1.4.0/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling a Transformer using torch.compile and TensorRT
 ==============================================================

-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model.
+"""

 # %%
 # Imports and Model Definition

examples/dynamo/torch_compile_advanced_usage.py (+2 -1)

@@ -4,7 +4,8 @@
 Torch Compile Advanced Usage
 ======================================================

-This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API.
+"""

 # %%
 # Imports and Model Definition

examples/dynamo/torch_compile_resnet_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling ResNet with dynamic shapes using the `torch.compile` backend
 ==========================================================

-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model.
+"""

 # %%
 # Imports and Model Definition

examples/dynamo/torch_compile_transformers_example.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling BERT using the `torch.compile` backend
 ==============================================================

-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model.
+"""

 # %%
 # Imports and Model Definition

examples/dynamo/torch_export_cudagraphs.py (+2 -1)

@@ -4,7 +4,8 @@
 Torch Export with Cudagraphs
 ======================================================

-This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well."""
+This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well.
+"""

 # %%
 # Imports and Model Definition

examples/dynamo/torch_export_gpt2.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling GPT2 using the dynamo backend
 ==========================================================

-This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model.
+"""

 # %%
 # Imports and Model Definition

examples/dynamo/torch_export_llama2.py (+2 -1)

@@ -4,7 +4,8 @@
 Compiling Llama2 using the dynamo backend
 ==========================================================

-This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model.
+"""

 # %%
 # Imports and Model Definition

py/torch_tensorrt/_Input.py (+1 -1)

@@ -261,7 +261,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
 
     @staticmethod
     def _parse_tensor_domain(
-        domain: Optional[Tuple[float, float]]
+        domain: Optional[Tuple[float, float]],
     ) -> Tuple[float, float]:
         """
         Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)

py/torch_tensorrt/_enums.py (+1 -1)

@@ -1200,7 +1200,7 @@ def _from(
 
     @classmethod
     def try_from(
-        c: Union[trt.EngineCapability, EngineCapability]
+        c: Union[trt.EngineCapability, EngineCapability],
     ) -> Optional[EngineCapability]:
         """Create a Torch-TensorRT engine capability enum from a TensorRT engine capability enum.

py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py (+3 -3)

@@ -53,13 +53,13 @@ def _redraw(self, *, blank_lines: int = 0) -> None:
         if self._render:
 
             def clear_line() -> None:
-                print("\x1B[2K", end="")
+                print("\x1b[2K", end="")
 
             def move_to_start_of_line() -> None:
-                print("\x1B[0G", end="")
+                print("\x1b[0G", end="")
 
             def move_cursor_up(lines: int) -> None:
-                print("\x1B[{}A".format(lines), end="")
+                print("\x1b[{}A".format(lines), end="")
 
             def progress_bar(steps: int, num_steps: int) -> str:
                 INNER_WIDTH = 10

py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py (+2 -2)

@@ -247,7 +247,7 @@ def hard_sigmoid(
     operation_type = trt.ActivationType.HARD_SIGMOID
 
     def hard_sigmoid_dyn_range_fn(
-        dyn_range: Tuple[float, float]
+        dyn_range: Tuple[float, float],
     ) -> Tuple[float, float]:
         def hard_sigmoid_fn(x: float) -> float:
             return max(0, min(1, alpha * x + beta))
@@ -310,7 +310,7 @@ def thresholded_relu(
     operation_type = trt.ActivationType.THRESHOLDED_RELU
 
     def thresholded_relu_dyn_range_fn(
-        dyn_range: Tuple[float, float]
+        dyn_range: Tuple[float, float],
     ) -> Tuple[float, float]:
         def thresholded_relu_fn(x: float) -> float:
             return x if x > alpha else 0

py/torch_tensorrt/dynamo/utils.py (+1 -1)

@@ -465,7 +465,7 @@ def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch
 
 
 def to_torch_tensorrt_device(
-    device: Optional[Union[Device, torch.device, str]]
+    device: Optional[Union[Device, torch.device, str]],
 ) -> Device:
     """Cast a device-type to torch_tensorrt.Device

py/torch_tensorrt/fx/test/converters/acc_op/test_where.py (+1 -1)

@@ -101,7 +101,7 @@ def __init__(self, x_shape, y_shape):
             def forward(self, condition):
                 return torch.where(condition, self.x, self.y)
 
-        inputs = [(torch.randn(condition_shape) > 0)]
+        inputs = [torch.randn(condition_shape) > 0]
         self.run_test(
             Where(x_shape, y_shape),
             inputs,

py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py (+2 -3)

@@ -10,7 +10,6 @@
 from typing import (
     Any,
     Callable,
-    cast,
     Dict,
     Iterable,
     Optional,
@@ -19,6 +18,7 @@
     Tuple,
     Type,
     Union,
+    cast,
 )
 
 import torch
@@ -32,7 +32,6 @@
 
 from . import acc_normalizer, acc_ops, acc_shape_prop, acc_utils  # noqa: F401
 
-
 _LOGGER = logging.getLogger(__name__)
 
 
@@ -517,7 +516,7 @@ def _replace_transpose_last_dims_impl(
     changed = False
 
     def _calculate_dim(
-        transpose_dim: Union[torch.fx.Node, int]
+        transpose_dim: Union[torch.fx.Node, int],
     ) -> Union[torch.fx.Node, int]:
         nonlocal transpose_input_node
         nonlocal changed
test_flaky_global_partitioning.py (new file, +108)

@@ -0,0 +1,108 @@
from copy import deepcopy

import numpy as np
import pytest
import torch
import torch.nn.functional as F
import torch_tensorrt
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import partitioning

from ..testing_utilities import lower_graph_testing

# Note: the following tests were a part of test_global_partitioning.py and were flaky when
# we ran all the tests. So, the following test cases were separated out in this test_flaky_global_partitioning.py
# The partitioned graphs were different when you ran the graph as a part of test_global_partitioning.py vs when you
# run these tests independently. pytest by default doesn't use parallel execution, so we are not sure why this behavior occurs
# currently. When you run these tests independently, the partitioned graph is structurally correct and is similar to fast partitioning.


class TestGlobalPartitioning(TestCase):
    def test_partition_partially_supported_multi_op(self):
        class PartiallySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                sum_1 = torch.ops.aten.add.Tensor(x, y)
                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
                sum_ = np.sum(sum_1) + np.sum(sum_2)
                relu_ = torch.ops.aten.relu.default(sum_)
                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
                return pow_

        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
        partitioned_graph, _ = partitioning.global_partition(
            deepcopy(fx_graph), min_block_size=2
        )
        # breakpoint()
        self.assertEqual(
            len(list(partitioned_graph.named_children())),
            2,
            "Unsupported operators interleave supported ones, expected 2 segments",
        )

    def test_partition_partially_supported_with_torch_executed_ops(self):
        class PartiallySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                sum_1 = torch.ops.aten.add.Tensor(x, y)
                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
                sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2)
                relu_ = torch.ops.aten.relu.default(sum_)
                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
                return pow_

        unexpected_ops = {torch.ops.aten.add.Tensor}

        inputs = [
            torch.randint(
                1,
                10,
                (5,),
            ),
            torch.randint(
                1,
                10,
                (5,),
            ),
        ]

        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
        (
            unexpected_ops_seen,
            _,
            partitioned_graphs,
        ) = lower_graph_testing(
            fx_graph,
            inputs,
            unexpected_ops=unexpected_ops,
            min_block_size=2,
            torch_executed_ops={"torch.ops.aten.add.Tensor"},
            testing_partitioning=True,
            use_fast_partitioner=False,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(partitioned_graphs),
            1,
            "Without control flow breaks, there should only be a single graph",
        )
        self.assertEqual(
            len(list(partitioned_graphs[0].named_children())),
            1,
            "Certain operators are set to run in Torch, expected 1 segment",
        )


if __name__ == "__main__":
    run_tests()
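The note at the top of the new module says these cases only produce the expected global partitioning when they run on their own, so the simplest way to reproduce the non-flaky behavior is to invoke pytest on just this file. A minimal sketch follows; the file path is an assumption (the commit does not show where the module lives in the test tree), and the driver script itself is hypothetical, not part of the commit:

# Hypothetical driver: run only the separated-out global partitioning tests
# in isolation, mirroring the "run these tests independently" note above.
import sys

import pytest

if __name__ == "__main__":
    # Assumed location of the new module; adjust to the actual path in the checkout.
    test_file = "tests/py/dynamo/partitioning/test_flaky_global_partitioning.py"
    sys.exit(pytest.main([test_file, "-v"]))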
