12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import List , Optional , Type
15
+ from typing import Type
16
16
17
17
from deepsparse .v2 .task import SupportedTasks , dynamic_import_task
18
+ from sparsezoo .utils .registry import (
19
+ RegistryMixin ,
20
+ get_from_registry ,
21
+ register ,
22
+ registered_names ,
23
+ )
18
24
19
25
20
- _REGISTERED_OPERATORS = {}
21
-
22
26
__all__ = ["OperatorRegistry" ]
23
27
24
28
25
- class OperatorRegistry :
26
- def register (task : str , task_aliases : Optional [List [str ]] = None ):
29
+ class OperatorRegistry (RegistryMixin ):
30
+ """
31
+ Register operators with given task name(s). Leverages the RegistryMixin
32
+ functionality.
33
+ """
34
+
35
+ @classmethod
36
+ def register_value (cls , operator , name ):
27
37
from deepsparse .v2 .operators import Operator
28
38
29
- """
30
- Decorator to register an operator with its task name and its aliases. The
31
- registered names can be used to load the operator through `Operator.create()`
39
+ if not isinstance (name , list ):
40
+ name = [name ]
32
41
33
- Multiple operators may not have the same task name. An error will
34
- be raised if two different operators attempt to register the same task name
42
+ for task_name in name :
43
+ register ( Operator , operator , task_name , require_subclass = True )
35
44
36
- :param task: main task name of this operator
37
- :param task_aliases: list of extra task names that may be used to reference
38
- this operator. Default is None
39
- """
40
- task_names = [task ]
41
- if task_aliases :
42
- task_names .extend (task_aliases )
43
-
44
- task_names = [task_name .lower ().replace ("-" , "_" ) for task_name in task_names ]
45
-
46
- def _register_task (task_name , operator ):
47
- if task_name in _REGISTERED_OPERATORS and (
48
- operator is not _REGISTERED_OPERATORS [task_name ]
49
- ):
50
- raise RuntimeError (
51
- f"task { task_name } already registered by OperatorRegistry.register "
52
- f"attempting to register operator: { operator } , but"
53
- f"operator: { _REGISTERED_OPERATORS [task_name ]} , already registered"
54
- )
55
- _REGISTERED_OPERATORS [task_name ] = operator
56
-
57
- def _register_operator (operator : Operator ):
58
- if not issubclass (operator , Operator ):
59
- raise RuntimeError (
60
- f"Attempting to register operator { operator } . "
61
- f"Registered operators must inherit from { Operator } "
62
- )
63
- for task_name in task_names :
64
- _register_task (task_name , operator )
65
-
66
- # set task and task_aliases as class level property
67
- operator .task = task
68
- operator .task_aliases = task_aliases
69
-
70
- return operator
71
-
72
- return _register_operator
73
-
74
- @staticmethod
75
- def get_task_constructor (task : str ) -> Type ["Operator" ]:
45
+ return operator
46
+
47
+ @classmethod
48
+ def get_task_constructor (cls , task : str ) -> Type ["Operator" ]: # noqa: F821
76
49
"""
77
50
This function retrieves the class previously registered via
78
51
`OperatorRegistry.register` for `task`.
@@ -86,6 +59,8 @@ def get_task_constructor(task: str) -> Type["Operator"]:
86
59
:return: The class registered to `task`
87
60
:raises ValueError: if `task` was not registered via `OperatorRegistry.register`
88
61
"""
62
+ from deepsparse .v2 .operators import Operator
63
+
89
64
if task .startswith ("import:" ):
90
65
# dynamically import the task from a file
91
66
task = dynamic_import_task (module_or_path = task .replace ("import:" , "" ))
@@ -95,14 +70,7 @@ def get_task_constructor(task: str) -> Type["Operator"]:
95
70
else :
96
71
task = task .lower ().replace ("-" , "_" )
97
72
73
+ tasks = registered_names (Operator )
98
74
# step needed to import relevant files required to load the operator
99
- SupportedTasks .check_register_task (task , _REGISTERED_OPERATORS .keys ())
100
-
101
- if task not in _REGISTERED_OPERATORS :
102
- raise ValueError (
103
- f"Unknown Operator task { task } . Operators tasks should be "
104
- "must be declared with the OperatorRegistry.register decorator. "
105
- f"Currently registered operators: { list (_REGISTERED_OPERATORS .keys ())} "
106
- )
107
-
108
- return _REGISTERED_OPERATORS [task ]
75
+ SupportedTasks .check_register_task (task , tasks )
76
+ return get_from_registry (Operator , task )
0 commit comments