
Commit e4b8365

Merge pull request #2166 from pytorch/opset_coverage
feat(torch_tensorrt.dynamo.tools): Tool to calculate coverage of PyTorch
2 parents 0527edd + e8966d7 commit e4b8365

File tree

4 files changed: +220 -5 lines changed


py/torch_tensorrt/dynamo/conversion/converter_registry.py (+14 -5)
@@ -1,6 +1,6 @@
 import logging
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Union, List
 from enum import Enum, auto

 from torch.fx.node import Target, Node, _get_qualified_name
@@ -305,23 +305,32 @@ def unique_targets(self):
         """Returns the set of unique converter targets stored across all registries"""
         return set.union(*[set(registry.keys()) for registry in self.registries])

+    # TODO: Make this a static method since it does not need state
     def qualified_name_or_str(self, target: Target) -> str:
         """Returns string representation of an FX Node target"""
         if isinstance(target, str):
             return target
         else:
             return _get_qualified_name(target)

-    def display_all_available_converters(self) -> str:
-        """Returns a string with all converters and their source, separated by newlines"""
-        available_converters = "Available converters in ATen registries with counts:\n"
-
+    def get_converter_support_info(self) -> Dict[str, Dict[str, int]]:
+        """Returns a dictionary of targets backed by at least one converter"""
+        available_converters = {}
         for target in sorted(
             self.unique_targets(), key=lambda target: self.qualified_name_or_str(target)
         ):
             _, registry_data = self.get_all_converters_with_target(
                 target, return_registry_info=True
             )
+            available_converters[self.qualified_name_or_str(target)] = registry_data
+        return available_converters
+
+    def display_all_available_converters(self) -> str:
+        """Returns a string with all converters and their source, separated by newlines"""
+        available_converters = "Available converters in ATen registries with counts:\n"
+
+        support_info = self.get_converter_support_info()
+        for target, registry_data in support_info.items():
             available_converters += f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}\n"

         return available_converters
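
For context, a minimal usage sketch of the new get_converter_support_info API (not part of this diff). It uses the DYNAMO_CONVERTERS singleton that the coverage tool below also imports; the registry-name key is taken from that tool, and the example target in the comment is only illustrative.

from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS

# Maps qualified target name -> per-registry converter counts, e.g.
# {"torch.ops.aten.add.Tensor": {"Dynamo ATen Converters Registry": 1, ...}}
support_info = DYNAMO_CONVERTERS.get_converter_support_info()

# Count targets backed by at least one converter in the Dynamo ATen registry
dynamo_backed = sum(
    1
    for counts in support_info.values()
    if counts.get("Dynamo ATen Converters Registry", 0) >= 1
)
print(f"{dynamo_backed} of {len(support_info)} targets have a Dynamo converter")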

py/torch_tensorrt/dynamo/tools/__init__.py

Whitespace-only changes.

New file under py/torch_tensorrt/dynamo/tools (+204 lines):
@@ -0,0 +1,204 @@
import dataclasses
import json
import os
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch._prims as prims
import torchgen
from torch._ops import OpOverload
from torch._dynamo.variables import BuiltinVariable
from torch_tensorrt.dynamo.conversion.converter_registry import (
    DYNAMO_CONVERTERS,
    ConverterRegistry,
)
from torch_tensorrt.dynamo.lowering import get_decompositions
from torchgen.gen import parse_native_yaml


class SupportStatus(Enum):
    CONVERTED = auto()
    LEGACY_CONVERTED = auto()
    LOWERED = auto()
    FALLBACK = auto()

    def __str__(self) -> str:
        return self.name


@dataclass
class OpsetCoverage:
    support_status: Dict[str, Dict[str, str]]
    dynamo_coverage: float
    legacy_coverage: float
    decomposition_coverage: float
    fallback_coverage: float


NATIVE_FUNCTION_YAML_PATH = (
    Path(os.path.dirname(torchgen.__file__))
    / "packaged/ATen/native/native_functions.yaml"
)
TAGS_YAML_PATH = (
    Path(os.path.dirname(torchgen.__file__)) / "packaged/ATen/native/tags.yaml"
)


def get_aten_ops() -> List[Tuple[str, str]]:
    parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)
    native_functions = parsed_yaml.native_functions

    aten_ops = OrderedDict()
    for function in native_functions:
        if "core" in function.tags:
            op_name = str(function.func.name)
            aten_ops[op_name] = function

    op_schema_pairs = []
    for key, op in sorted(aten_ops.items()):
        op_name = f"aten.{key}"
        schema = str(op.func).replace("*", r"\*")

        op_schema_pairs.append((op_name, schema))

    return op_schema_pairs


ATEN_OPS = get_aten_ops()


def get_prims_ops() -> List[Tuple[str, str]]:
    op_schema_pairs = []
    for op_name in prims.__all__:
        op_overload = getattr(prims, op_name, None)

        if not isinstance(op_overload, torch._ops.OpOverload):
            continue

        op_overloadpacket = op_overload.overloadpacket

        op_name = str(op_overload).replace(".default", "")
        schema = op_overloadpacket.schema.replace("*", r"\*")

        op_schema_pairs.append((op_name, schema))

    return op_schema_pairs


PRIM_OPS = get_prims_ops()


def get_overloaded_py_ops() -> List[Tuple[str, str]]:
    python_ops = BuiltinVariable._fx_graph_functions()
    op_schema_pairs = []
    for op in python_ops:
        name = op.__name__
        op_schema_pairs.append((f"_operator.{name}", ""))

    return op_schema_pairs


OVERLOADED_PY_OPS = get_overloaded_py_ops()


def opset_coverage(
    opset: List[Tuple[str, str]],
    converter_registry: Optional[ConverterRegistry] = None,
    decomposition_registry: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
) -> OpsetCoverage:

    opset_schemas = dict(opset)
    opset_targets = set(opset_schemas.keys())

    support_status = {}

    # TODO: Could be way less complicated if there is a way to convert from
    # strings to OpOverload
    c_registry = (
        converter_registry if converter_registry is not None else DYNAMO_CONVERTERS
    )
    converter_registry_targets = {
        c_registry.qualified_name_or_str(target).removeprefix("torch.ops.")
        for target in c_registry.keys()
    }
    supported_converted_targets = opset_targets.intersection(converter_registry_targets)
    support_count = 0
    legacy_count = 0
    for target in c_registry.keys():
        target_str = c_registry.qualified_name_or_str(target).removeprefix("torch.ops.")
        if target_str in opset_targets:
            _, registry_data = c_registry.get_all_converters_with_target(
                target, return_registry_info=True
            )
            if registry_data["Dynamo ATen Converters Registry"] >= 1:
                status = SupportStatus.CONVERTED
                support_count += 1
            elif registry_data["FX ATen Converters Registry"] >= 1:
                status = SupportStatus.LEGACY_CONVERTED
                legacy_count += 1

            support_status[target_str] = {
                "schema": f"{target_str.split('.')[0]}.{opset_schemas[target_str]}",
                "status": str(status),
            }

    l_registry = (
        decomposition_registry
        if decomposition_registry is not None
        else get_decompositions()
    )
    decomp_registry_targets = {
        c_registry.qualified_name_or_str(target).removeprefix("torch.ops.")
        for target in l_registry.keys()
    }
    supported_decomp_targets = opset_targets.intersection(decomp_registry_targets)
    decomposition_count = len(supported_decomp_targets)
    for target in supported_decomp_targets:
        support_status[target] = {
            "schema": f"{target.split('.')[0]}.{opset_schemas[target]}",
            "status": str(SupportStatus.LOWERED),
        }

    unsupported_targets = opset_targets.difference(
        supported_converted_targets.union(supported_decomp_targets)
    )
    unsupported_count = len(unsupported_targets)
    for target in unsupported_targets:
        support_status[target] = {
            "schema": f"{target.split('.')[0]}.{opset_schemas[target]}",
            "status": str(SupportStatus.FALLBACK),
        }

    return OpsetCoverage(
        support_status,
        dynamo_coverage=support_count / len(opset),
        legacy_coverage=legacy_count / len(opset),
        decomposition_coverage=decomposition_count / len(opset),
        fallback_coverage=unsupported_count / len(opset),
    )


if __name__ == "__main__":

    def find_coverage_status(opset: List[Tuple[str, str]], name: str) -> None:
        coverage = opset_coverage(opset)
        print(f"{name}:")
        print(f" - Dynamo converters: {coverage.dynamo_coverage:.2%}")
        print(f" - Decomposed: {coverage.decomposition_coverage:.2%}")
        print(f" - Legacy FX converters: {coverage.legacy_coverage:.2%}")
        print(f" - Ops to fallback to Torch: {coverage.fallback_coverage:.2%}")
        print(
            f"Per op coverage status saved to /tmp/{name.lower()}_coverage_status.json"
        )

        with open(f"/tmp/{name.lower()}_coverage_status.json", "w") as f:
            json.dump(dataclasses.asdict(coverage), f)

    print("-------- OPERATOR SET COVERAGE --------")
    find_coverage_status(ATEN_OPS, "ATen")
    find_coverage_status(PRIM_OPS, "prim")
    find_coverage_status(OVERLOADED_PY_OPS, "py_overload")
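
For reference, a small sketch (not part of the commit) of consuming the JSON report the script writes to /tmp when executed directly; the file name and field names follow the find_coverage_status helper and dataclasses.asdict(OpsetCoverage) above.

import json

# Report produced by find_coverage_status(ATEN_OPS, "ATen") above
with open("/tmp/aten_coverage_status.json") as f:
    report = json.load(f)

# Top-level ratios mirror the OpsetCoverage dataclass fields
print(f"Dynamo converter coverage: {report['dynamo_coverage']:.2%}")

# Each per-op entry carries a schema string and a status of CONVERTED,
# LEGACY_CONVERTED, LOWERED, or FALLBACK
fallback_ops = [
    op
    for op, info in report["support_status"].items()
    if info["status"] == "FALLBACK"
]
print(f"{len(fallback_ops)} ATen core ops would fall back to Torch")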

setup.py (+2)
@@ -350,6 +350,7 @@ def run(self):
         "torch_tensorrt.dynamo.lowering",
         "torch_tensorrt.dynamo.lowering.substitutions",
         "torch_tensorrt.dynamo.runtime",
+        "torch_tensorrt.dynamo.tools",
         "torch_tensorrt.fx",
         "torch_tensorrt.fx.converters",
         "torch_tensorrt.fx.converters.impl",
@@ -374,6 +375,7 @@ def run(self):
         "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
         "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions",
         "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime",
+        "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools",
         "torch_tensorrt.fx": "py/torch_tensorrt/fx",
         "torch_tensorrt.fx.converters": "py/torch_tensorrt/fx/converters",
         "torch_tensorrt.fx.converters.impl": "py/torch_tensorrt/fx/converters/impl",
