Skip to content

Commit 7ce47fc

Browse files
authored
add backend is available
Differential Revision: D69810445 Pull Request resolved: #8738
1 parent a5750fb commit 7ce47fc

File tree

5 files changed

+40
-0
lines changed

5 files changed

+40
-0
lines changed

extension/pybindings/portable_lib.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
_dump_profile_results, # noqa: F401
4040
_get_operator_names, # noqa: F401
4141
_get_registered_backend_names, # noqa: F401
42+
_is_available, # noqa: F401
4243
_load_bundled_program_from_buffer, # noqa: F401
4344
_load_for_executorch, # noqa: F401
4445
_load_for_executorch_from_buffer, # noqa: F401

extension/pybindings/pybindings.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,12 @@ using ::executorch::extension::BufferDataLoader;
8888
using ::executorch::extension::MallocMemoryAllocator;
8989
using ::executorch::extension::MmapDataLoader;
9090
using ::executorch::runtime::ArrayRef;
91+
using ::executorch::runtime::BackendInterface;
9192
using ::executorch::runtime::DataLoader;
9293
using ::executorch::runtime::Error;
9394
using ::executorch::runtime::EValue;
9495
using ::executorch::runtime::EventTracerDebugLogLevel;
96+
using ::executorch::runtime::get_backend_class;
9597
using ::executorch::runtime::get_backend_name;
9698
using ::executorch::runtime::get_num_registered_backends;
9799
using ::executorch::runtime::get_registered_kernels;
@@ -990,6 +992,14 @@ py::list get_registered_backend_names() {
990992
return res;
991993
}
992994

995+
py::bool_ is_available(const std::string& backend_name) {
996+
BackendInterface* backend = get_backend_class(backend_name.c_str());
997+
if (backend == nullptr) {
998+
return false;
999+
}
1000+
return backend->is_available();
1001+
}
1002+
9931003
} // namespace
9941004

9951005
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1048,6 +1058,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10481058
&get_registered_backend_names,
10491059
call_guard);
10501060
m.def("_get_operator_names", &get_operator_names);
1061+
m.def("_is_available", &is_available, py::arg("backend_name"), call_guard);
10511062
m.def("_create_profile_block", &create_profile_block, call_guard);
10521063
m.def(
10531064
"_reset_profile_results",

extension/pybindings/pybindings.pyi

+9
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,15 @@ def _load_bundled_program_from_buffer(
211211
"""
212212
...
213213

214+
@experimental("This API is experimental and subject to change without notice.")
215+
def _is_available(backend_name: str) -> bool:
216+
"""
217+
.. warning::
218+
219+
This API is experimental and subject to change without notice.
220+
"""
221+
...
222+
214223
@experimental("This API is experimental and subject to change without notice.")
215224
def _get_operator_names() -> List[str]:
216225
"""

extension/pybindings/test/test_backend_pybinding.py

+13
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,16 @@ def test_backend_name_list(
1212
registered_backend_names = runtime.backend_registry.registered_backend_names
1313
self.assertGreaterEqual(len(registered_backend_names), 1)
1414
self.assertIn("XnnpackBackend", registered_backend_names)
15+
16+
def test_backend_is_available(
17+
self,
18+
) -> None:
19+
# XnnpackBackend is available
20+
runtime = Runtime.get()
21+
self.assertTrue(
22+
runtime.backend_registry.is_available(backend_name="XnnpackBackend")
23+
)
24+
# NonExistBackend doesn't exist and not available
25+
self.assertFalse(
26+
runtime.backend_registry.is_available(backend_name="NonExistBackend")
27+
)

runtime/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def registered_backend_names(self) -> List[str]:
139139
"""
140140
return self._legacy_module._get_registered_backend_names()
141141

142+
def is_available(self, backend_name: str) -> bool:
143+
"""
144+
Returns the names of all registered backends as a list of strings.
145+
"""
146+
return self._legacy_module._is_available(backend_name)
147+
142148

143149
class OperatorRegistry:
144150
"""The registry of operators that are available to the runtime."""

0 commit comments

Comments
 (0)