
Commit 03de3a5

move BasePipeline to a new file
1 parent 61bb610 commit 03de3a5

File tree

4 files changed: +386 -361 lines changed


src/deepsparse/__init__.py

+1
@@ -33,6 +33,7 @@
 from .engine import *
 from .tasks import *
 from .pipeline import *
+from .base_pipeline import *
 from .loggers import *
 from .version import __version__, is_release
 from .analytics import deepsparse_analytics as _analytics
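
With the wildcard import above and the `__all__` list in the new module, `BasePipeline` remains importable from the package root, so existing call sites keep working. A minimal check (assuming deepsparse is installed at this commit):

from deepsparse import BasePipeline                 # re-exported at the package root
from deepsparse.base_pipeline import BasePipeline   # or directly from its new home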

src/deepsparse/base_pipeline.py

+382
@@ -0,0 +1,382 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, List, Optional, Type, Union

from pydantic import BaseModel

from deepsparse import Context
from deepsparse.loggers.base_logger import BaseLogger
from deepsparse.loggers.build_logger import logger_from_config
from deepsparse.loggers.constants import validate_identifier
from deepsparse.tasks import SupportedTasks, dynamic_import_task

__all__ = [
    "BasePipeline",
]

_REGISTERED_PIPELINES = {}


class BasePipeline(ABC):
    """
    Generic abstract class meant to wrap inference objects alongside
    model-specific Pipeline objects. Any pipeline inheriting from Pipeline
    should handle all model-specific input/output pre- and post-processing,
    while BasePipeline serves as a generic wrapper. Inputs and outputs of
    BasePipelines should be serialized as pydantic models.

    BasePipelines should not be instantiated through their constructors, but
    rather through the `BasePipeline.create()` method. The task name given to
    `create` is used to load the appropriate pipeline. The pipeline should
    inherit from `BasePipeline` and implement the `__call__`, `input_schema`,
    and `output_schema` abstract methods.

    Finally, the class definition should be decorated by the
    `BasePipeline.register` function. This defines the task name and task
    aliases for the pipeline and ensures that it is accessible by
    `BasePipeline.create`. The implemented `BasePipeline` subclass must be
    imported at runtime to be accessible.

    Example:
        @BasePipeline.register(task="base_example")
        class BasePipelineExample(BasePipeline):
            def __init__(self, base_specific, **kwargs):
                self._base_specific = base_specific
                self.model_pipeline = Pipeline.create(task="..")
                super().__init__(**kwargs)
            # implementation of abstract methods

    :param alias: optional name to give this pipeline instance, useful when
        inferencing with multiple models. Default is None
    :param logger: An optional item that can be either a DeepSparse Logger
        object or an object that can be transformed into one, i.e. a path to
        a logging config file or a YAML string representation of the logging
        config. If a logger is provided (in any form), the pipeline will log
        inference metrics to it. Default is None
    """

    def __init__(
        self,
        alias: Optional[str] = None,
        logger: Optional[Union[BaseLogger, str]] = None,
    ):
        self._alias = alias
        self.logger = (
            logger
            if isinstance(logger, BaseLogger)
            else (
                logger_from_config(
                    config=logger, pipeline_identifier=self._identifier()
                )
                if isinstance(logger, str)
                else None
            )
        )

    @abstractmethod
    def __call__(self, *args, **kwargs) -> BaseModel:
        """
        Runner function needed to stitch together any parsing, preprocessing,
        engine, and post-processing steps.

        :return: pydantic model instance containing the outputs of this
            pipeline
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def input_schema(self) -> Type[BaseModel]:
        """
        :return: pydantic model class that inputs to this pipeline must
            comply with
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def output_schema(self) -> Type[BaseModel]:
        """
        :return: pydantic model class that outputs of this pipeline must
            comply with
        """
        raise NotImplementedError()

    @staticmethod
    def _get_task_constructor(task: str) -> Type["BasePipeline"]:
        """
        This function retrieves the class previously registered via
        `BasePipeline.register` or `Pipeline.register` for `task`.

        If `task` starts with "import:", it is treated as a module to be
        imported, and the task is retrieved via the `TASK` attribute of the
        imported module.

        If `task` starts with "custom", then it is mapped to the "custom" task.

        :param task: The task name to get the constructor for
        :return: The class registered to `task`
        :raises ValueError: if `task` was not registered via `Pipeline.register`
        """
        if task.startswith("import:"):
            # dynamically import the task from a file
            task = dynamic_import_task(module_or_path=task.replace("import:", ""))
        elif task.startswith("custom"):
            # support any task that has "custom" at the beginning via the "custom" task
            task = "custom"
        else:
            task = task.lower().replace("-", "_")

        # extra step to register pipelines for a given task domain
        # for cases where imports should only happen once a user specifies
        # that domain is to be used. (ie deepsparse.transformers will auto
        # install extra packages so should only import and register once a
        # transformers task is specified)
        SupportedTasks.check_register_task(task, _REGISTERED_PIPELINES.keys())

        if task not in _REGISTERED_PIPELINES:
            raise ValueError(
                f"Unknown Pipeline task {task}. Pipeline tasks must be "
                "declared with the Pipeline.register decorator. Currently "
                f"registered pipelines: {list(_REGISTERED_PIPELINES.keys())}"
            )

        return _REGISTERED_PIPELINES[task]

    @staticmethod
    def create(
        task: str,
        **kwargs,
    ) -> "BasePipeline":
        """
        :param task: name of task to create a pipeline for. Use "custom" for
            custom tasks (see `CustomTaskPipeline`).
        :param kwargs: extra task specific kwargs to be passed to task Pipeline
            implementation
        :return: pipeline object initialized for the given task
        """
        from deepsparse.pipeline import Bucketable, BucketingPipeline, Pipeline

        pipeline_constructor = BasePipeline._get_task_constructor(task)
        model_path = kwargs.get("model_path", None)

        if issubclass(pipeline_constructor, Pipeline):
            if (
                (model_path is None or model_path == "default")
                and hasattr(pipeline_constructor, "default_model_path")
                and pipeline_constructor.default_model_path
            ):
                model_path = pipeline_constructor.default_model_path

            if model_path is None:
                raise ValueError(
                    f"No model_path provided for pipeline {pipeline_constructor}. "
                    "Must provide a model path for pipelines that do not have a "
                    "default defined"
                )

            kwargs["model_path"] = model_path

        if issubclass(
            pipeline_constructor, Bucketable
        ) and pipeline_constructor.should_bucket(**kwargs):
            if kwargs.get("input_shape", None):
                raise ValueError(
                    "Overriding input shapes not supported with Bucketing enabled"
                )
            if not kwargs.get("context", None):
                context = Context(num_cores=kwargs["num_cores"])
                kwargs["context"] = context
            buckets = pipeline_constructor.create_pipeline_buckets(
                task=task,
                **kwargs,
            )
            return BucketingPipeline(pipelines=buckets)

        return pipeline_constructor(**kwargs)

    @classmethod
    def register(
        cls,
        task: str,
        task_aliases: Optional[List[str]] = None,
        default_model_path: Optional[str] = None,
    ):
        """
        Pipeline implementer class decorator that registers the pipeline
        task name and its aliases as valid tasks that can be used to load
        the pipeline through `BasePipeline.create()` or `Pipeline.create()`

        Multiple pipelines may not have the same task name. An error will
        be raised if two different pipelines attempt to register the same
        task name

        :param task: main task name of this pipeline
        :param task_aliases: list of extra task names that may be used to
            reference this pipeline. Default is None
        :param default_model_path: path (ie zoo stub) to use as default for
            this task if None is provided
        """
        task_names = [task]
        if task_aliases:
            task_names.extend(task_aliases)

        task_names = [task_name.lower().replace("-", "_") for task_name in task_names]

        def _register_task(task_name, pipeline_class):
            if task_name in _REGISTERED_PIPELINES and (
                pipeline_class is not _REGISTERED_PIPELINES[task_name]
            ):
                raise RuntimeError(
                    f"task {task_name} already registered by BasePipeline.register. "
                    f"attempting to register pipeline: {pipeline_class}, but "
                    f"pipeline: {_REGISTERED_PIPELINES[task_name]} is already "
                    "registered"
                )
            _REGISTERED_PIPELINES[task_name] = pipeline_class

        def _register_pipeline_tasks_decorator(pipeline_class: BasePipeline):
            if not issubclass(pipeline_class, cls):
                raise RuntimeError(
                    f"Attempting to register pipeline {pipeline_class}. "
                    f"Registered pipelines must inherit from {cls}"
                )
            for task_name in task_names:
                _register_task(task_name, pipeline_class)

            # set task and task_aliases as class level properties
            pipeline_class.task = task
            pipeline_class.task_aliases = task_aliases
            pipeline_class.default_model_path = default_model_path

            return pipeline_class

        return _register_pipeline_tasks_decorator

    @classmethod
    def from_config(
        cls,
        config: Union["PipelineConfig", str, Path],  # noqa: F821
        logger: Optional[BaseLogger] = None,
    ) -> "BasePipeline":
        """
        :param config: PipelineConfig object, filepath to a json serialized
            PipelineConfig, or raw string of a json serialized PipelineConfig
        :param logger: An optional DeepSparse Logger object for inference
            logging. Default is None
        :return: loaded Pipeline object from the config
        """
        from deepsparse.pipeline import PipelineConfig

        if isinstance(config, Path) or (
            isinstance(config, str) and os.path.exists(config)
        ):
            if isinstance(config, str):
                config = Path(config)
            config = PipelineConfig.parse_file(config)
        if isinstance(config, str):
            config = PipelineConfig.parse_raw(config)

        return cls.create(
            task=config.task,
            alias=config.alias,
            logger=logger,
            **config.kwargs,
        )

    @property
    def alias(self) -> str:
        """
        :return: optional name to give this pipeline instance, useful when
            inferencing with multiple models
        """
        return self._alias

    def to_config(self) -> "PipelineConfig":  # noqa: F821
        """
        :return: PipelineConfig that can be used to reload this object
        """
        from deepsparse.pipeline import PipelineConfig

        if not hasattr(self, "task"):
            raise RuntimeError(
                f"{self.__class__} instance has no attribute task. Pipeline "
                "objects must have a task to be serialized to a config. Pipeline "
                "objects must be declared with the Pipeline.register decorator "
                "to be assigned a task"
            )

        # parse any additional properties as kwargs
        kwargs = {}
        for attr_name, attr in self.__class__.__dict__.items():
            if isinstance(attr, property) and attr_name not in dir(PipelineConfig):
                kwargs[attr_name] = getattr(self, attr_name)

        return PipelineConfig(
            task=self.task,
            alias=self.alias,
            kwargs=kwargs,
        )

    def log(
        self,
        identifier: str,
        value: Any,
        category: str,
    ):
        """
        Pass the logged data to the DeepSparse logger object (if present).

        :param identifier: The string name assigned to the logged value
        :param value: The logged data structure
        :param category: The metric category that the log belongs to
        """
        if not self.logger:
            return

        identifier = f"{self._identifier()}/{identifier}"
        validate_identifier(identifier)
        self.logger.log(
            identifier=identifier,
            value=value,
            category=category,
            pipeline_name=self._identifier(),
        )
        return

    def parse_inputs(self, *args, **kwargs) -> BaseModel:
        """
        :param args: ordered arguments to pipeline, only an input_schema object
            is supported as an arg for this function
        :param kwargs: keyword arguments to pipeline
        :return: pipeline arguments parsed into the given `input_schema`
            schema if necessary. If an instance of the `input_schema` is
            provided, it will be returned
        """
        # passed input_schema schema directly
        if len(args) == 1 and isinstance(args[0], self.input_schema) and not kwargs:
            return args[0]

        if args:
            raise ValueError(
                f"pipeline {self.__class__} only supports either a "
                f"{self.input_schema} object or keyword arguments used to "
                f"construct one. Found {len(args)} args and {len(kwargs)} kwargs"
            )

        return self.input_schema(**kwargs)

    def _identifier(self):
        # get pipeline identifier; used in the context of logging
        if not hasattr(self, "task"):
            self.task = None
        return f"{self.alias or self.task or 'unknown_pipeline'}"
