Skip to content

Commit 39b6818

Browse files
authored
tool: Opset coverage notebook (#2831)
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent e122901 commit 39b6818

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

.gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,8 @@ bazel-tensorrt
6969
bazel-project
7070
build/
7171
wheelhouse/
72+
*_status.json
73+
tests/py/dynamo/models/*.ts
74+
tests/py/dynamo/models/*.ep
75+
*.deb
76+
*.tar.xz

py/torch_tensorrt/dynamo/tools/opset_coverage.py

+9
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ def opset_coverage(
206206
)
207207

208208

209+
def get_coverage_status(opset: List[Tuple[str, str]], name: str) -> OpsetCoverage:
210+
coverage = opset_coverage(opset)
211+
return coverage
212+
213+
214+
ATEN_COVERAGE = get_coverage_status(ATEN_OPS, "ATen")
215+
PRIMS_COVERAGE = get_coverage_status(PRIM_OPS, "prim")
216+
PY_OVERLOAD_COVERAGE = get_coverage_status(OVERLOADED_PY_OPS, "py_overload")
217+
209218
if __name__ == "__main__":
210219

211220
def find_coverage_status(opset: List[Tuple[str, str]], name: str) -> None:

tools/opset_coverage.ipynb

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import torch_tensorrt\n",
10+
"from torch_tensorrt.dynamo.tools.opset_coverage import ATEN_COVERAGE, PRIMS_COVERAGE, PY_OVERLOAD_COVERAGE, SupportStatus, OpsetCoverage"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"unsupported_ops = {}\n",
20+
"backwards_ops = {}\n",
21+
"\n",
22+
"for target, info in ATEN_COVERAGE.support_status.items():\n",
23+
" if info[\"status\"] == \"FALLBACK\":\n",
24+
" if \"backward\" not in target:\n",
25+
" unsupported_ops.update({target : info[\"schema\"]})\n",
26+
" else:\n",
27+
" backwards_ops.update({target : info[\"schema\"]})\n",
28+
"\n",
29+
"print(\"Unsupported Ops:\")\n",
30+
"for _, schema in unsupported_ops.items():\n",
31+
" print(schema)\n",
32+
"\n",
33+
"print(\"\\nBackwards Ops:\")\n",
34+
"for _, schema in backwards_ops.items():\n",
35+
" print(schema)\n"
36+
]
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": null,
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"unsupported_ops = {}\n",
45+
"backwards_ops = {}\n",
46+
"\n",
47+
"for target, info in PRIMS_COVERAGE.support_status.items():\n",
48+
" if info[\"status\"] == \"FALLBACK\":\n",
49+
" if \"backward\" not in target:\n",
50+
" unsupported_ops.update({target : info[\"schema\"]})\n",
51+
" else:\n",
52+
" backwards_ops.update({target : info[\"schema\"]})\n",
53+
"\n",
54+
"print(\"Unsupported Ops:\")\n",
55+
"for _, schema in unsupported_ops.items():\n",
56+
" print(schema)\n",
57+
"\n",
58+
"print(\"\\nBackwards Ops:\")\n",
59+
"for _, schema in backwards_ops.items():\n",
60+
" print(schema)"
61+
]
62+
},
63+
{
64+
"cell_type": "code",
65+
"execution_count": null,
66+
"metadata": {},
67+
"outputs": [],
68+
"source": []
69+
}
70+
],
71+
"metadata": {
72+
"kernelspec": {
73+
"display_name": "torch230cu121py311",
74+
"language": "python",
75+
"name": "python3"
76+
},
77+
"language_info": {
78+
"codemirror_mode": {
79+
"name": "ipython",
80+
"version": 3
81+
},
82+
"file_extension": ".py",
83+
"mimetype": "text/x-python",
84+
"name": "python",
85+
"nbconvert_exporter": "python",
86+
"pygments_lexer": "ipython3",
87+
"version": "3.11.7"
88+
}
89+
},
90+
"nbformat": 4,
91+
"nbformat_minor": 2
92+
}

0 commit comments

Comments
 (0)