Skip to content

Commit 3e8db3d

Browse files
committed
use sparsezoo mixin
1 parent bb6135d commit 3e8db3d

File tree

2 files changed

+29
-61
lines changed

2 files changed

+29
-61
lines changed

src/deepsparse/v2/operators/registry.py

+28-60
Original file line numberDiff line numberDiff line change
@@ -12,67 +12,40 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Optional, Type
15+
from typing import Type
1616

1717
from deepsparse.v2.task import SupportedTasks, dynamic_import_task
18+
from sparsezoo.utils.registry import (
19+
RegistryMixin,
20+
get_from_registry,
21+
register,
22+
registered_names,
23+
)
1824

1925

20-
_REGISTERED_OPERATORS = {}
21-
2226
__all__ = ["OperatorRegistry"]
2327

2428

25-
class OperatorRegistry:
26-
def register(task: str, task_aliases: Optional[List[str]] = None):
29+
class OperatorRegistry(RegistryMixin):
30+
"""
31+
Register operators with given task name(s). Leverages the RegistryMixin
32+
functionality.
33+
"""
34+
35+
@classmethod
36+
def register_value(cls, operator, name):
2737
from deepsparse.v2.operators import Operator
2838

29-
"""
30-
Decorator to register an operator with its task name and its aliases. The
31-
registered names can be used to load the operator through `Operator.create()`
39+
if not isinstance(name, list):
40+
name = [name]
3241

33-
Multiple operators may not have the same task name. An error will
34-
be raised if two different operators attempt to register the same task name
42+
for task_name in name:
43+
register(Operator, operator, task_name, require_subclass=True)
3544

36-
:param task: main task name of this operator
37-
:param task_aliases: list of extra task names that may be used to reference
38-
this operator. Default is None
39-
"""
40-
task_names = [task]
41-
if task_aliases:
42-
task_names.extend(task_aliases)
43-
44-
task_names = [task_name.lower().replace("-", "_") for task_name in task_names]
45-
46-
def _register_task(task_name, operator):
47-
if task_name in _REGISTERED_OPERATORS and (
48-
operator is not _REGISTERED_OPERATORS[task_name]
49-
):
50-
raise RuntimeError(
51-
f"task {task_name} already registered by OperatorRegistry.register "
52-
f"attempting to register operator: {operator}, but"
53-
f"operator: {_REGISTERED_OPERATORS[task_name]}, already registered"
54-
)
55-
_REGISTERED_OPERATORS[task_name] = operator
56-
57-
def _register_operator(operator: Operator):
58-
if not issubclass(operator, Operator):
59-
raise RuntimeError(
60-
f"Attempting to register operator {operator}. "
61-
f"Registered operators must inherit from {Operator}"
62-
)
63-
for task_name in task_names:
64-
_register_task(task_name, operator)
65-
66-
# set task and task_aliases as class level property
67-
operator.task = task
68-
operator.task_aliases = task_aliases
69-
70-
return operator
71-
72-
return _register_operator
73-
74-
@staticmethod
75-
def get_task_constructor(task: str) -> Type["Operator"]:
45+
return operator
46+
47+
@classmethod
48+
def get_task_constructor(cls, task: str) -> Type["Operator"]: # noqa: F821
7649
"""
7750
This function retrieves the class previously registered via
7851
`OperatorRegistry.register` for `task`.
@@ -86,6 +59,8 @@ def get_task_constructor(task: str) -> Type["Operator"]:
8659
:return: The class registered to `task`
8760
:raises ValueError: if `task` was not registered via `OperatorRegistry.register`
8861
"""
62+
from deepsparse.v2.operators import Operator
63+
8964
if task.startswith("import:"):
9065
# dynamically import the task from a file
9166
task = dynamic_import_task(module_or_path=task.replace("import:", ""))
@@ -95,14 +70,7 @@ def get_task_constructor(task: str) -> Type["Operator"]:
9570
else:
9671
task = task.lower().replace("-", "_")
9772

73+
tasks = registered_names(Operator)
9874
# step needed to import relevant files required to load the operator
99-
SupportedTasks.check_register_task(task, _REGISTERED_OPERATORS.keys())
100-
101-
if task not in _REGISTERED_OPERATORS:
102-
raise ValueError(
103-
f"Unknown Operator task {task}. Operators tasks should be "
104-
"must be declared with the OperatorRegistry.register decorator. "
105-
f"Currently registered operators: {list(_REGISTERED_OPERATORS.keys())}"
106-
)
107-
108-
return _REGISTERED_OPERATORS[task]
75+
SupportedTasks.check_register_task(task, tasks)
76+
return get_from_registry(Operator, task)

src/deepsparse/v2/text_generation/pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
_LOGGER = logging.getLogger(__name__)
4646

4747

48-
@OperatorRegistry.register(task="text_generation")
48+
@OperatorRegistry.register(name="text_generation")
4949
class TextGenerationPipeline(Pipeline):
5050
def __init__(
5151
self,

0 commit comments

Comments
 (0)