Skip to content

Commit bb3ff41

Browse files
authored
[Pipeline Refactor] Operator Registry (#1420)
* initial registry functionality * use sparsezoo mixin
1 parent c858b1f commit bb3ff41

File tree

6 files changed

+299
-0
lines changed

6 files changed

+299
-0
lines changed

src/deepsparse/v2/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from .pipeline import *
1919
from .routers import *
2020
from .schedulers import *
21+
from .task import *
2122
from .utils import *

src/deepsparse/v2/operators/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
# limitations under the License.
1717
from .operator import *
1818
from .engine_operator import *
19+
from .registry import *

src/deepsparse/v2/operators/operator.py

+15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pydantic import BaseModel
1919

2020
from deepsparse.v2.utils import InferenceState
21+
from deepsparse.v2.operators.registry import OperatorRegistry
2122

2223

2324
__all__ = ["Operator"]
@@ -100,6 +101,20 @@ def __call__(
100101
return self.output_schema(**run_output)
101102
return run_output
102103

104+
@staticmethod
def create(
    task: str,
    **kwargs,
) -> "Operator":
    """
    Build the operator registered for a task.

    The constructor is looked up through
    `OperatorRegistry.get_task_constructor` and then called with `kwargs`.

    :param task: Operator task
    :param kwargs: extra task specific kwargs to be passed to task Operator
        implementation
    :return: operator object initialized for the given task
    """
    return OperatorRegistry.get_task_constructor(task)(**kwargs)
117+
103118
@abstractmethod
104119
def run(self, *args, **kwargs) -> Any:
105120
"""
+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Type
16+
17+
from deepsparse.v2.task import SupportedTasks, dynamic_import_task
18+
from sparsezoo.utils.registry import (
19+
RegistryMixin,
20+
get_from_registry,
21+
register,
22+
registered_names,
23+
)
24+
25+
26+
__all__ = ["OperatorRegistry"]
27+
28+
29+
class OperatorRegistry(RegistryMixin):
    """
    Register operators with given task name(s). Leverages the RegistryMixin
    functionality.
    """

    @classmethod
    def register_value(cls, operator, name):
        """
        Register `operator` under one or several task names.

        Each name is passed to the sparsezoo `register` helper together with
        `Operator` and `require_subclass=True`.

        :param operator: the class to register
        :param name: a single task name, or a list of task names
        :return: `operator`, unchanged
        """
        # imported lazily: the operator module imports this registry module
        from deepsparse.v2.operators import Operator

        if not isinstance(name, list):
            name = [name]

        for task_name in name:
            register(Operator, operator, task_name, require_subclass=True)

        return operator

    @classmethod
    def get_task_constructor(cls, task: str) -> Type["Operator"]:  # noqa: F821
        """
        This function retrieves the class previously registered via
        `OperatorRegistry.register` for `task`.

        If `task` starts with "import:", it is treated as a module to be imported,
        and retrieves the task via the `TASK` attribute of the imported module.

        If `task` starts with "custom", then it is mapped to the "custom" task.

        Any other task name is lower-cased and has "-" replaced with "_".

        :param task: The task name to get the constructor for
        :return: The class registered to `task`
        :raises ValueError: if `task` was not registered via `OperatorRegistry.register`
        """
        # imported lazily: the operator module imports this registry module
        from deepsparse.v2.operators import Operator

        import_prefix = "import:"
        if task.startswith(import_prefix):
            # dynamically import the task from a file; strip only the leading
            # prefix so a module or path that contains "import:" further on
            # is passed through intact
            task = dynamic_import_task(module_or_path=task[len(import_prefix) :])
        elif task.startswith("custom"):
            # support any task that has "custom" at the beginning via the "custom" task
            task = "custom"
        else:
            task = task.lower().replace("-", "_")

        tasks = registered_names(Operator)
        # step needed to import relevant files required to load the operator
        SupportedTasks.check_register_task(task, tasks)
        return get_from_registry(Operator, task)

src/deepsparse/v2/task.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Classes and implementations for supported tasks in the DeepSparse pipeline and system
17+
"""
18+
19+
import importlib
20+
import logging
21+
import os
22+
import sys
23+
from collections import namedtuple
24+
from typing import Iterable, List, Optional, Tuple
25+
26+
27+
_LOGGER = logging.getLogger(__name__)
28+
29+
__all__ = ["SupportedTasks", "AliasedTask"]
30+
31+
32+
class AliasedTask:
    """
    A task together with the alternative names it can be referred to by.
    For example, question_answering which can alias to qa as well

    :param name: the name of the task such as question_answering or text_classification
    :param aliases: the aliases the task can go by in addition to the name such as
        qa, glue, sentiment_analysis, etc
    """

    def __init__(self, name: str, aliases: List[str]):
        self._name, self._aliases = name, aliases

    @property
    def name(self) -> str:
        """
        :return: the name of the task such as question_answering
        """
        return self._name

    @property
    def aliases(self) -> List[str]:
        """
        :return: the aliases the task can go by such as qa, glue, sentiment_analysis
        """
        return self._aliases

    def matches(self, task: str) -> bool:
        """
        Check whether `task` refers to this instance, by name or by alias.

        :param task: the task name to compare. It is lower-cased, "-" is
            replaced with "_", and runs of whitespace are replaced with "_"
            before the comparison.
        :return: True if task does match the current instance, False otherwise
        """
        # str.split() with no argument also drops leading/trailing whitespace
        normalized = "_".join(task.lower().replace("-", "_").split())

        if normalized == self.name:
            return True
        return normalized in self.aliases
74+
75+
76+
class SupportedTasks:
77+
"""
78+
The supported tasks in the DeepSparse pipeline and system
79+
"""
80+
81+
text_generation = namedtuple(
82+
"text_generation", ["text_generation", "opt", "bloom"]
83+
)(
84+
text_generation=AliasedTask("text_generation", []),
85+
opt=AliasedTask("opt", []),
86+
bloom=AliasedTask("bloom", []),
87+
)
88+
89+
all_task_categories = [text_generation]
90+
91+
@classmethod
92+
def check_register_task(
93+
cls, task: str, extra_tasks: Optional[Iterable[str]] = None
94+
):
95+
"""
96+
:param task: task name to validate and import dependencies for
97+
:param extra_tasks: valid task names that are not included in supported tasks.
98+
i.e. tasks registered to Pipeline at runtime
99+
"""
100+
if cls.is_text_generation(task):
101+
import deepsparse.v2.text_generation.pipeline # noqa: F401
102+
103+
all_tasks = set(cls.task_names() + (list(extra_tasks or [])))
104+
if task not in all_tasks:
105+
raise ValueError(
106+
f"Unknown Pipeline task {task}. Currently supported tasks are "
107+
f"{list(all_tasks)}"
108+
)
109+
110+
@classmethod
111+
def is_text_generation(cls, task: str) -> bool:
112+
"""
113+
:param task: the name of the task to check whether it is a text generation task
114+
such as codegen
115+
:return: True if it is a text generation task, False otherwise
116+
"""
117+
return any(
118+
text_generation_task.matches(task)
119+
for text_generation_task in cls.text_generation
120+
)
121+
122+
@classmethod
123+
def task_names(cls):
124+
task_names = ["custom"]
125+
for task_category in cls.all_task_categories:
126+
for task in task_category:
127+
unique_aliases = (
128+
alias for alias in task._aliases if alias != task._name
129+
)
130+
task_names += (task._name, *unique_aliases)
131+
return task_names
132+
133+
134+
def dynamic_import_task(module_or_path: str) -> str:
    """
    Dynamically imports `module` with importlib, and returns the `TASK`
    attribute on the module (something like `importlib.import_module(module).TASK`).

    Example contents of `module`:
    ```python
    from deepsparse.pipeline import Pipeline
    from deepsparse.transformers.pipelines.question_answering import (
        QuestionAnsweringPipeline,
    )

    TASK = "my_qa_task"
    Pipeline.register(TASK)(QuestionAnsweringPipeline)
    ```

    NOTE: this modifies `sys.path`. The current directory and the module's
    parent directory are appended only when not already present, so repeated
    calls do not keep growing `sys.path`.

    NOTE: this function does not verify that the imported module registered
    the task; it only reads the `TASK` attribute.

    :param module_or_path: dotted module name, or path to a `.py` file
    :raises FileNotFoundError: if path does not exist
    :raises RuntimeError: if the imported module does not contain `TASK`
    :return: The task from the imported module.
    """
    parent_dir, module_name = _split_dir_and_name(module_or_path)
    if not os.path.exists(os.path.join(parent_dir, module_name + ".py")):
        raise FileNotFoundError(
            f"Unable to find file for {module_or_path}. "
            f"Looked for {module_name}.py under {parent_dir if parent_dir else '.'}"
        )

    # add parent_dir to sys.path so we can import the file as a module;
    # existing entries are left alone to avoid accumulating duplicates
    if os.curdir not in sys.path:
        sys.path.append(os.curdir)
    if parent_dir and parent_dir not in sys.path:
        _LOGGER.info(f"Adding {parent_dir} to sys.path")
        sys.path.append(parent_dir)

    # do the import
    _LOGGER.info(f"Importing '{module_name}'")
    module = importlib.import_module(module_name)

    if not hasattr(module, "TASK"):
        raise RuntimeError(
            "When using --task import:<module>, "
            "module must set the `TASK` attribute."
        )

    task = getattr(module, "TASK")
    _LOGGER.info(f"Using task={repr(task)}")

    return task
184+
185+
186+
def _split_dir_and_name(module_or_path: str) -> Tuple[str, str]:
187+
"""
188+
Examples:
189+
- `a` -> `("", "a")`
190+
- `a.b` -> `("a", "b")`
191+
- `a.b.c` -> `("a/b", "c")`
192+
193+
:return: module split into directory & name
194+
"""
195+
if module_or_path.endswith(".py"):
196+
# assume path
197+
split_char = os.sep
198+
module_or_path = module_or_path.replace(".py", "")
199+
else:
200+
# assume module
201+
split_char = "."
202+
*dirs, module_name = module_or_path.split(split_char)
203+
parent_dir = os.sep if dirs == [""] else os.sep.join(dirs)
204+
return parent_dir, module_name

src/deepsparse/v2/text_generation/pipeline.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from deepsparse.transformers.utils.helpers import process_generation_config
2020
from deepsparse.utils import split_engine_inputs
2121
from deepsparse.v2.operators import EngineOperator
22+
from deepsparse.v2.operators.registry import OperatorRegistry
2223
from deepsparse.v2.pipeline import Pipeline
2324
from deepsparse.v2.routers import GraphRouter
2425
from deepsparse.v2.schedulers import ContinuousBatchingScheduler, OperatorScheduler
@@ -44,6 +45,7 @@
4445
_LOGGER = logging.getLogger(__name__)
4546

4647

48+
@OperatorRegistry.register(name="text_generation")
4749
class TextGenerationPipeline(Pipeline):
4850
def __init__(
4951
self,

0 commit comments

Comments
 (0)