Skip to content

Commit dc41596

Browse files
authored
migrate utils from jarvis to cadence
Differential Revision: D65458848 Pull Request resolved: #6720
1 parent 4947e27 commit dc41596

File tree

3 files changed

+177
-4
lines changed

3 files changed

+177
-4
lines changed

backends/cadence/aot/TARGETS

+13
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ load(
1111
"CXX",
1212
)
1313
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
14+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
1415

1516
oncall("odai_jarvis")
1617

@@ -103,3 +104,15 @@ executorch_generated_lib(
103104
"//executorch/kernels/portable:operators",
104105
],
105106
)
107+
108+
python_unittest(
109+
name = "test_pass_filter",
110+
srcs = [
111+
"tests/test_pass_filter.py",
112+
],
113+
typing = True,
114+
deps = [
115+
":pass_utils",
116+
"//executorch/exir:pass_base",
117+
],
118+
)

backends/cadence/aot/pass_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,26 @@ class CadencePassAttribute:
2828

2929

3030
# A dictionary that maps an ExportPass to its attributes.
31-
_ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
31+
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
3232

3333

3434
def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
35-
return _ALL_CADENCE_PASSES[p]
35+
return ALL_CADENCE_PASSES[p]
3636

3737

3838
# A decorator that registers a pass.
3939
def register_cadence_pass(
4040
pass_attribute: CadencePassAttribute,
4141
) -> Callable[[ExportPass], ExportPass]:
4242
def wrapper(cls: ExportPass) -> ExportPass:
43-
_ALL_CADENCE_PASSES[cls] = pass_attribute
43+
ALL_CADENCE_PASSES[cls] = pass_attribute
4444
return cls
4545

4646
return wrapper
4747

4848

4949
def get_all_available_cadence_passes() -> Set[ExportPass]:
50-
return set(_ALL_CADENCE_PASSES.keys())
50+
return set(ALL_CADENCE_PASSES.keys())
5151

5252

5353
# Create a new filter to filter out relevant passes from all Jarvis passes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-unsafe
4+
5+
6+
import unittest
7+
8+
from copy import deepcopy
9+
10+
from executorch.backends.cadence.aot import pass_utils
11+
from executorch.backends.cadence.aot.pass_utils import (
12+
ALL_CADENCE_PASSES,
13+
CadencePassAttribute,
14+
create_cadence_pass_filter,
15+
register_cadence_pass,
16+
)
17+
18+
from executorch.exir.pass_base import ExportPass
19+
20+
21+
class TestBase(unittest.TestCase):
22+
def setUp(self):
23+
# Before running each test, create a copy of _all_passes to later restore it after test.
24+
# This avoids messing up the original _all_passes when running tests.
25+
self._all_passes_original = deepcopy(ALL_CADENCE_PASSES)
26+
# Clear _all_passes to do a clean test. It'll be restored after each test in tearDown().
27+
pass_utils.ALL_CADENCE_PASSES.clear()
28+
29+
def tearDown(self):
30+
# Restore _all_passes to original state before test.
31+
pass_utils.ALL_CADENCE_PASSES = self._all_passes_original
32+
33+
def get_filtered_passes(self, filter_):
34+
return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)}
35+
36+
37+
# Test pass registration
38+
class TestPassRegistration(TestBase):
39+
def test_register_cadence_pass(self):
40+
pass_attr_O0 = CadencePassAttribute(opt_level=0)
41+
pass_attr_debug = CadencePassAttribute(opt_level=None, debug_pass=True)
42+
pass_attr_O1_all_backends = CadencePassAttribute(
43+
opt_level=1,
44+
)
45+
46+
# Register 1st pass with opt_level=0
47+
@register_cadence_pass(pass_attr_O0)
48+
class DummyPass_O0(ExportPass):
49+
pass
50+
51+
# Register 2nd pass with opt_level=1, all backends.
52+
@register_cadence_pass(pass_attr_O1_all_backends)
53+
class DummyPass_O1_All_Backends(ExportPass):
54+
pass
55+
56+
# Register 3rd pass with opt_level=None, debug=True
57+
@register_cadence_pass(pass_attr_debug)
58+
class DummyPass_Debug(ExportPass):
59+
pass
60+
61+
# Check if the three passes are indeed added into _all_passes
62+
expected_all_passes = {
63+
DummyPass_O0: pass_attr_O0,
64+
DummyPass_Debug: pass_attr_debug,
65+
DummyPass_O1_All_Backends: pass_attr_O1_all_backends,
66+
}
67+
self.assertEqual(pass_utils.ALL_CADENCE_PASSES, expected_all_passes)
68+
69+
70+
# Test pass filtering
71+
class TestPassFiltering(TestBase):
72+
def test_filter_none(self):
73+
pass_attr_O0 = CadencePassAttribute(opt_level=0)
74+
pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
75+
pass_attr_O1_all_backends = CadencePassAttribute(
76+
opt_level=1,
77+
)
78+
79+
@register_cadence_pass(pass_attr_O0)
80+
class DummyPass_O0(ExportPass):
81+
pass
82+
83+
@register_cadence_pass(pass_attr_O1_debug)
84+
class DummyPass_O1_Debug(ExportPass):
85+
pass
86+
87+
@register_cadence_pass(pass_attr_O1_all_backends)
88+
class DummyPass_O1_All_Backends(ExportPass):
89+
pass
90+
91+
O1_filter = create_cadence_pass_filter(opt_level=1, debug=True)
92+
O1_filter_passes = self.get_filtered_passes(O1_filter)
93+
94+
# Assert that no passes are filtered out.
95+
expected_passes = {
96+
DummyPass_O0: pass_attr_O0,
97+
DummyPass_O1_Debug: pass_attr_O1_debug,
98+
DummyPass_O1_All_Backends: pass_attr_O1_all_backends,
99+
}
100+
self.assertEqual(O1_filter_passes, expected_passes)
101+
102+
def test_filter_debug(self):
103+
pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
104+
pass_attr_O2 = CadencePassAttribute(opt_level=2)
105+
106+
@register_cadence_pass(pass_attr_O1_debug)
107+
class DummyPass_O1_Debug(ExportPass):
108+
pass
109+
110+
@register_cadence_pass(pass_attr_O2)
111+
class DummyPass_O2(ExportPass):
112+
pass
113+
114+
debug_filter = create_cadence_pass_filter(opt_level=2, debug=False)
115+
debug_filter_passes = self.get_filtered_passes(debug_filter)
116+
117+
# Assert that debug passees are filtered out, since the filter explicitly
118+
# chooses debug=False.
119+
self.assertEqual(debug_filter_passes, {DummyPass_O2: pass_attr_O2})
120+
121+
def test_filter_all(self):
122+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
123+
class DummyPass_O1(ExportPass):
124+
pass
125+
126+
@register_cadence_pass(CadencePassAttribute(opt_level=2))
127+
class DummyPass_O2(ExportPass):
128+
pass
129+
130+
debug_filter = create_cadence_pass_filter(opt_level=0)
131+
debug_filter_passes = self.get_filtered_passes(debug_filter)
132+
133+
# Assert that all the passes are filtered out, since the filter only selects
134+
# passes with opt_level <= 0
135+
self.assertEqual(debug_filter_passes, {})
136+
137+
def test_filter_opt_level_None(self):
138+
pass_attr_O1 = CadencePassAttribute(opt_level=1)
139+
pass_attr_O2_debug = CadencePassAttribute(opt_level=2, debug_pass=True)
140+
141+
@register_cadence_pass(CadencePassAttribute(opt_level=None))
142+
class DummyPass_None(ExportPass):
143+
pass
144+
145+
@register_cadence_pass(pass_attr_O1)
146+
class DummyPass_O1(ExportPass):
147+
pass
148+
149+
@register_cadence_pass(pass_attr_O2_debug)
150+
class DummyPass_O2_Debug(ExportPass):
151+
pass
152+
153+
O2_filter = create_cadence_pass_filter(opt_level=2, debug=True)
154+
filtered_passes = self.get_filtered_passes(O2_filter)
155+
# Passes with opt_level=None should never be retained.
156+
expected_passes = {
157+
DummyPass_O1: pass_attr_O1,
158+
DummyPass_O2_Debug: pass_attr_O2_debug,
159+
}
160+
self.assertEqual(filtered_passes, expected_passes)

0 commit comments

Comments
 (0)