Skip to content

Commit bb6135d

Browse files
committed
initial registry functionality
1 parent c858b1f commit bb6135d

File tree

6 files changed

+331
-0
lines changed

6 files changed

+331
-0
lines changed

src/deepsparse/v2/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from .pipeline import *
1919
from .routers import *
2020
from .schedulers import *
21+
from .task import *
2122
from .utils import *

src/deepsparse/v2/operators/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
# limitations under the License.
1717
from .operator import *
1818
from .engine_operator import *
19+
from .registry import *

src/deepsparse/v2/operators/operator.py

+15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pydantic import BaseModel
1919

2020
from deepsparse.v2.utils import InferenceState
21+
from deepsparse.v2.operators.registry import OperatorRegistry
2122

2223

2324
__all__ = ["Operator"]
@@ -100,6 +101,20 @@ def __call__(
100101
return self.output_schema(**run_output)
101102
return run_output
102103

104+
@staticmethod
def create(
    task: str,
    **kwargs,
) -> "Operator":
    """
    Build the operator that is registered for the given task.

    :param task: name of the task whose operator should be constructed; it is
        resolved through `OperatorRegistry.get_task_constructor`
    :param kwargs: task specific keyword arguments, forwarded unchanged to the
        constructor of the resolved operator class
    :return: a new instance of the operator class resolved for `task`
    """
    return OperatorRegistry.get_task_constructor(task)(**kwargs)
117+
103118
@abstractmethod
104119
def run(self, *args, **kwargs) -> Any:
105120
"""
src/deepsparse/v2/operators/registry.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Optional, Type
16+
17+
from deepsparse.v2.task import SupportedTasks, dynamic_import_task
18+
19+
20+
_REGISTERED_OPERATORS = {}
21+
22+
__all__ = ["OperatorRegistry"]
23+
24+
25+
class OperatorRegistry:
    """
    Registry that maps task names (and their aliases) to Operator classes.

    Classes are added with the `OperatorRegistry.register` decorator and looked
    up with `OperatorRegistry.get_task_constructor`. Registered names are stored
    lower-cased with "-" replaced by "_".
    """

    @staticmethod
    def register(task: str, task_aliases: Optional[List[str]] = None):
        """
        Decorator to register an operator with its task name and its aliases. The
        registered names can be used to load the operator through `Operator.create()`

        Multiple operators may not have the same task name. An error will
        be raised if two different operators attempt to register the same task name

        :param task: main task name of this operator
        :param task_aliases: list of extra task names that may be used to reference
            this operator. Default is None
        :return: class decorator that registers the decorated class and returns it
        :raises RuntimeError: (from the returned decorator) if the decorated class
            does not inherit from Operator, or if one of the names is already
            registered to a different class
        """
        # imported lazily: operator.py imports this module, so importing the
        # operators package at module import time would be circular
        from deepsparse.v2.operators import Operator

        task_names = [task]
        if task_aliases:
            task_names.extend(task_aliases)

        # names are stored in the same normalized form that
        # get_task_constructor applies to lookups
        task_names = [
            task_name.lower().replace("-", "_") for task_name in task_names
        ]

        def _register_task(task_name, operator):
            # re-registering the very same class is allowed; a different class
            # under an existing name is an error
            if task_name in _REGISTERED_OPERATORS and (
                operator is not _REGISTERED_OPERATORS[task_name]
            ):
                raise RuntimeError(
                    f"task {task_name} already registered by "
                    "OperatorRegistry.register "
                    f"attempting to register operator: {operator}, but "
                    f"operator: {_REGISTERED_OPERATORS[task_name]}, "
                    "already registered"
                )
            _REGISTERED_OPERATORS[task_name] = operator

        def _register_operator(operator: Type[Operator]):
            if not issubclass(operator, Operator):
                raise RuntimeError(
                    f"Attempting to register operator {operator}. "
                    f"Registered operators must inherit from {Operator}"
                )
            for task_name in task_names:
                _register_task(task_name, operator)

            # set task and task_aliases as class level property
            operator.task = task
            operator.task_aliases = task_aliases

            return operator

        return _register_operator

    @staticmethod
    def get_task_constructor(task: str) -> Type["Operator"]:
        """
        This function retrieves the class previously registered via
        `OperatorRegistry.register` for `task`.

        If `task` starts with "import:", it is treated as a module to be imported,
        and retrieves the task via the `TASK` attribute of the imported module.

        If `task` starts with "custom", then it is mapped to the "custom" task.

        :param task: The task name to get the constructor for
        :return: The class registered to `task`
        :raises ValueError: if `task` was not registered via `OperatorRegistry.register`
        """
        import_prefix = "import:"
        if task.startswith(import_prefix):
            # dynamically import the task from a file; strip only the leading
            # prefix so an "import:" substring later in the value is preserved
            task = dynamic_import_task(module_or_path=task[len(import_prefix) :])
        elif task.startswith("custom"):
            # support any task that has "custom" at the beginning via the "custom" task
            task = "custom"
        else:
            task = task.lower().replace("-", "_")

        # step needed to import relevant files required to load the operator
        SupportedTasks.check_register_task(task, _REGISTERED_OPERATORS.keys())

        # a name can pass the check above and still have no registered operator
        if task not in _REGISTERED_OPERATORS:
            raise ValueError(
                f"Unknown Operator task {task}. Operator tasks "
                "must be declared with the OperatorRegistry.register decorator. "
                f"Currently registered operators: {list(_REGISTERED_OPERATORS)}"
            )

        return _REGISTERED_OPERATORS[task]

src/deepsparse/v2/task.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Classes and implementations for supported tasks in the DeepSparse pipeline and system
17+
"""
18+
19+
import importlib
20+
import logging
21+
import os
22+
import sys
23+
from collections import namedtuple
24+
from typing import Iterable, List, Optional, Tuple
25+
26+
27+
_LOGGER = logging.getLogger(__name__)
28+
29+
__all__ = ["SupportedTasks", "AliasedTask"]
30+
31+
32+
class AliasedTask:
    """
    A task that can have multiple aliases to match to.
    For example, question_answering which can alias to qa as well

    :param name: the name of the task such as question_answering or text_classification
    :param aliases: the aliases the task can go by in addition to the name such as
        qa, glue, sentiment_analysis, etc
    """

    def __init__(self, name: str, aliases: List[str]):
        self._name = name
        self._aliases = aliases

    @property
    def name(self) -> str:
        """
        :return: the name of the task such as question_answering
        """
        return self._name

    @property
    def aliases(self) -> List[str]:
        """
        :return: the aliases the task can go by such as qa, glue, sentiment_analysis
        """
        return self._aliases

    @staticmethod
    def _normalize(task: str) -> str:
        """
        :param task: a task name or alias in any casing / separator style
        :return: the name lower-cased, with "-" and runs of whitespace replaced
            by "_"
        """
        # str.split() without arguments splits on any whitespace run and drops
        # leading/trailing whitespace, so the join yields single underscores
        return "_".join(task.lower().replace("-", "_").split())

    def matches(self, task: str) -> bool:
        """
        :param task: the name of the task to check whether the given instance matches.
            Checks the current name as well as any aliases.
            Everything is compared at lower case and "-" and whitespace
            are replaced with "_".
        :return: True if task does match the current instance, False otherwise
        """
        task = self._normalize(task)

        # normalize the stored name and aliases as well, so a task declared as
        # e.g. "Text-Generation" still matches "text_generation"
        return any(
            task == self._normalize(candidate)
            for candidate in (self._name, *self._aliases)
        )
74+
75+
76+
class SupportedTasks:
    """
    Catalogue of the tasks supported by the DeepSparse pipeline and system
    """

    text_generation = namedtuple(
        "text_generation",
        ["text_generation", "opt", "bloom"],
    )(
        text_generation=AliasedTask("text_generation", []),
        opt=AliasedTask("opt", []),
        bloom=AliasedTask("bloom", []),
    )

    all_task_categories = [text_generation]

    @classmethod
    def check_register_task(
        cls, task: str, extra_tasks: Optional[Iterable[str]] = None
    ):
        """
        Import the modules a task depends on, then validate the task name.

        :param task: task name to validate and import dependencies for
        :param extra_tasks: valid task names that are not included in supported
            tasks. i.e. tasks registered at runtime
        :raises ValueError: if `task` is neither a supported task name nor one
            of `extra_tasks`
        """
        if cls.is_text_generation(task):
            # imported for its side effects only; must happen before
            # extra_tasks is read below
            import deepsparse.v2.text_generation.pipeline  # noqa: F401

        known_tasks = set(cls.task_names() + list(extra_tasks or []))
        if task in known_tasks:
            return

        raise ValueError(
            f"Unknown Pipeline task {task}. Currently supported tasks are "
            f"{list(known_tasks)}"
        )

    @classmethod
    def is_text_generation(cls, task: str) -> bool:
        """
        :param task: the name of the task to test against the text generation
            tasks declared on this class
        :return: True if any text generation task matches `task`, False otherwise
        """
        for candidate in cls.text_generation:
            if candidate.matches(task):
                return True
        return False

    @classmethod
    def task_names(cls):
        """
        :return: list starting with "custom", followed by the name of every
            task in every category and those of its aliases that differ from
            the name
        """
        names = ["custom"]
        for category in cls.all_task_categories:
            for aliased_task in category:
                names.append(aliased_task.name)
                names.extend(
                    alias
                    for alias in aliased_task.aliases
                    if alias != aliased_task.name
                )
        return names
132+
133+
134+
def dynamic_import_task(module_or_path: str) -> str:
    """
    Dynamically imports `module_or_path` with importlib, and returns the `TASK`
    attribute on the module (something like `importlib.import_module(module).TASK`).

    Example contents of `module`:
    ```python
    from deepsparse.pipeline import Pipeline
    from deepsparse.transformers.pipelines.question_answering import (
        QuestionAnsweringPipeline,
    )

    TASK = "my_qa_task"
    Pipeline.register(TASK)(QuestionAnsweringPipeline)
    ```

    NOTE: this modifies `sys.path`. The current directory and the parent
    directory of the module are appended when they are not already present.

    :param module_or_path: dotted module name (e.g. `a.b.c`) or path to a
        `.py` file
    :raises FileNotFoundError: if path does not exist
    :raises RuntimeError: if the imported module does not contain `TASK`
    :return: The task from the imported module.
    """
    parent_dir, module_name = _split_dir_and_name(module_or_path)
    if not os.path.exists(os.path.join(parent_dir, module_name + ".py")):
        raise FileNotFoundError(
            f"Unable to find file for {module_or_path}. "
            f"Looked for {module_name}.py under {parent_dir if parent_dir else '.'}"
        )

    # add the search directories to sys.path so we can import the file as a
    # module; skip entries already present so repeated calls do not keep
    # growing sys.path with duplicates
    if os.curdir not in sys.path:
        sys.path.append(os.curdir)
    if parent_dir and parent_dir not in sys.path:
        _LOGGER.info(f"Adding {parent_dir} to sys.path")
        sys.path.append(parent_dir)

    # do the import
    _LOGGER.info(f"Importing '{module_name}'")
    module = importlib.import_module(module_name)

    if not hasattr(module, "TASK"):
        raise RuntimeError(
            "When using --task import:<module>, "
            "module must set the `TASK` attribute."
        )

    task = module.TASK
    _LOGGER.info(f"Using task={repr(task)}")

    return task
184+
185+
186+
def _split_dir_and_name(module_or_path: str) -> Tuple[str, str]:
    """
    Split a dotted module name, or a path ending in `.py`, into the directory
    that contains the module and the bare module name.

    Examples (shown with `/` as `os.sep`):
    - `a` -> `("", "a")`
    - `a.b` -> `("a", "b")`
    - `a.b.c` -> `("a/b", "c")`
    - `dir/task.py` -> `("dir", "task")`
    - `/task.py` -> `("/", "task")`

    :param module_or_path: dotted module name or path to a `.py` file
    :return: module split into directory & name
    """
    suffix = ".py"
    if module_or_path.endswith(suffix):
        # assume path
        split_char = os.sep
        # drop only the trailing extension; str.replace would also remove any
        # ".py" that occurs earlier in the path (e.g. a "my.pyfiles" directory)
        module_or_path = module_or_path[: -len(suffix)]
    else:
        # assume module
        split_char = "."
    *dirs, module_name = module_or_path.split(split_char)
    # a lone empty component means the file sits directly under the root
    parent_dir = os.sep if dirs == [""] else os.sep.join(dirs)
    return parent_dir, module_name

src/deepsparse/v2/text_generation/pipeline.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from deepsparse.transformers.utils.helpers import process_generation_config
2020
from deepsparse.utils import split_engine_inputs
2121
from deepsparse.v2.operators import EngineOperator
22+
from deepsparse.v2.operators.registry import OperatorRegistry
2223
from deepsparse.v2.pipeline import Pipeline
2324
from deepsparse.v2.routers import GraphRouter
2425
from deepsparse.v2.schedulers import ContinuousBatchingScheduler, OperatorScheduler
@@ -44,6 +45,7 @@
4445
_LOGGER = logging.getLogger(__name__)
4546

4647

48+
@OperatorRegistry.register(task="text_generation")
4749
class TextGenerationPipeline(Pipeline):
4850
def __init__(
4951
self,

0 commit comments

Comments
 (0)