Skip to content

Commit e1ff108

Browse files
authored
[Pipeline Refactor] Update routes, text generation initial functionality (#1348)
* initial functionality and working example with image classification * remove testing image * rebase fixes * initial functionality and working example with image classification * text gen * updates func * prompt inference, initial functionality * remove image; update state docstring * Fix typo * add todo for split/join * remove context, clean-up args, remove prefill_preprocess_operaator * fix docstrings
1 parent 58b0758 commit e1ff108

19 files changed

+1203
-55
lines changed

src/deepsparse/v2/operators/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,4 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
1716
from .operator import *

src/deepsparse/v2/operators/engine_operator.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
from pydantic import BaseModel, Field
1919

20-
from deepsparse import Context, Engine, MultiModelEngine, Scheduler
20+
from deepsparse import Context as EngineContext
21+
from deepsparse import Engine, MultiModelEngine, Scheduler
2122
from deepsparse.benchmark import ORTEngine
2223
from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs
2324
from deepsparse.v2.operators import Operator
@@ -54,16 +55,15 @@ def __init__(
5455
self,
5556
model_path: str,
5657
engine_type: str = DEEPSPARSE_ENGINE,
57-
batch_size: Optional[int] = 1,
5858
num_cores: int = None,
5959
num_streams: int = None,
6060
scheduler: Scheduler = None,
6161
input_shapes: List[List[int]] = None,
62-
engine_context: Optional[Context] = None,
62+
engine_context: Optional[EngineContext] = None,
63+
engine_kwargs: Dict = None,
6364
):
64-
65-
self._batch_size = batch_size
6665
self.model_path = model_to_path(model_path)
66+
self._batch_size = 1
6767
self.engine_context = engine_context
6868

6969
if self.engine_context is not None:
@@ -87,7 +87,7 @@ def __init__(
8787
self._engine_args = engine_args
8888
self._engine_type = engine_type
8989

90-
self.engine = self.create_engine()
90+
self.engine = self.create_engine(**engine_kwargs)
9191

9292
@property
9393
def batch_size(self) -> int:
@@ -114,12 +114,12 @@ def create_engine(
114114

115115
if engine_type == DEEPSPARSE_ENGINE:
116116
if self.engine_context is not None and isinstance(
117-
self.engine_context, Context
117+
self.engine_context, EngineContext
118118
):
119119
engine_args.pop("num_cores", None)
120120
engine_args.pop("scheduler", None)
121121
engine_args.pop("num_streams", None)
122-
engine_args["context"] = self.engien_context
122+
engine_args["context"] = self.engine_context
123123
return MultiModelEngine(
124124
model=onnx_file_path,
125125
**engine_args,
@@ -135,7 +135,7 @@ def create_engine(
135135
f"{SUPPORTED_PIPELINE_ENGINES}"
136136
)
137137

138-
def run(self, inp: EngineOperatorInputs) -> Dict:
138+
def run(self, inp: EngineOperatorInputs, **kwargs) -> Dict:
139139
if inp.engine:
140140
# run with custom engine, do not split/join since custom engine
141141
# may run at any batch size, returning here as code below has a

src/deepsparse/v2/operators/operator.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from pydantic import BaseModel
1919

20+
from deepsparse.v2.utils import InferenceState, PipelineState
21+
2022

2123
__all__ = ["Operator"]
2224

@@ -54,14 +56,18 @@ def has_output_schema(cls) -> bool:
5456
def __call__(
5557
self,
5658
*args,
59+
inference_state: InferenceState,
60+
pipeline_state: PipelineState,
5761
**kwargs,
5862
) -> Any:
5963
"""
6064
Parses inputs to this Operator and runs the run() method of this operator
6165
6266
:param args: an unnamed arg may only be provided if it is of the type of the
6367
input_schema
64-
:param context: pipeline context to pass to operator
68+
:param inference_state: inference_state for the pipeline.
69+
:param pipeline_state: pipeline_state for the pipeline. The values in the state
70+
are created during pipeline creation and are read-only during inference.
6571
:param kwargs: kwargs when not initializing from an instantiated schema
6672
:return: operator output
6773
"""
@@ -81,10 +87,18 @@ def __call__(
8187
"in the form of a dictionary or an instance of the input_schema"
8288
"object"
8389
)
84-
85-
run_output = self.run(inference_input)
90+
run_output = self.run(
91+
inference_input,
92+
inference_state=inference_state,
93+
pipeline_state=pipeline_state,
94+
)
8695
else:
87-
run_output = self.run(*args, **kwargs)
96+
run_output = self.run(
97+
*args,
98+
inference_state=inference_state,
99+
pipeline_state=pipeline_state,
100+
**kwargs,
101+
)
88102

89103
if self.has_output_schema():
90104
return self.output_schema(**run_output)
@@ -93,12 +107,16 @@ def __call__(
93107
@abstractmethod
94108
def run(self, *args, **kwargs) -> Any:
95109
"""
96-
:param inp: operator input, as the defined input schema if applicable
97-
:param context: pipeline context of already run operators
98110
:return: result of this operator as the defined output schema if applicable
99111
"""
100112
raise NotImplementedError
101113

114+
def can_operate(self, inp: Any) -> bool:
115+
"""
116+
Whether or not the given operator can run, based on input
117+
"""
118+
return True
119+
102120
def expand_inputs(self, **kwargs):
103121
"""
104122
Generic function to handle expanding values.

src/deepsparse/v2/pipeline.py

+54-16
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from deepsparse.v2.operators import Operator
1919
from deepsparse.v2.routers import Router
2020
from deepsparse.v2.schedulers import OperatorScheduler, SchedulerGroup
21+
from deepsparse.v2.utils import InferenceState, PipelineState
2122

2223

2324
__all__ = ["Pipeline"]
@@ -27,7 +28,7 @@ class Pipeline(Operator):
2728
"""
2829
Pipeline accepts a series of operators, schedulers, and a router. Calling a pipeline
2930
will use the router to run through all the defined operators. The operators should
30-
be implemented using the Operator class and each implemented Operator should be
31+
be implemented using the Operator class and each implemented operator should be
3132
responsible for a functional component of the pipelines. The flow of inputs/outputs
3233
between the operators and the steps in the pipeline should be defined by the router,
3334
(based off of the Router class), which dictates the next operator in the pipeline.
@@ -37,6 +38,7 @@ class Pipeline(Operator):
3738
or dictionary of operators.
3839
:param router: A Router which dictates the next operator to call.
3940
:param schedulers: A list of schedulers to run operators.
41+
:param pipeline_state: pipeline_state created during pipeline initialization
4042
4143
"""
4244

@@ -45,57 +47,93 @@ def __init__(
4547
ops: Union[Dict[str, Operator], List[Operator]],
4648
router: Router,
4749
schedulers: List[OperatorScheduler],
50+
pipeline_state: PipelineState = None,
4851
):
4952

5053
self.ops = ops
5154
self.router = router
5255
self.schedulers = schedulers
56+
self.pipeline_state = pipeline_state
5357
self.validate()
5458

5559
# SchedulerGroup handles running all schedulers in order of priority
5660
self._scheduler_group = SchedulerGroup(self.schedulers)
5761

58-
def run(self, *args, **kwargs):
62+
def run(
63+
self,
64+
*args,
65+
inference_state: InferenceState,
66+
pipeline_state: PipelineState,
67+
**kwargs,
68+
):
5969
"""
60-
Run through the operators using the provided router and scheduler. Update the
61-
context to reflect each step of the router. The input to a given operator is the
62-
output of the previous operator.
63-
64-
:param inp: input to the operator. expected to be of any type that is
65-
expected by the operator.
66-
:param context: context to store the current the inputs, outputs, and operator
67-
for each step of the router.
70+
Run through the operators using the provided router and scheduler.
71+
The input to a given operator is the output of the previous operator.
6872
73+
:param inference_state: inference_state for the pipeline.
74+
:param pipeline_state: pipeline_state for the pipeline. The values in the state
75+
are created during pipeline creation and are read-only during inference.
6976
"""
7077
next_step = self.router.START_ROUTE
7178
operator_output = None
79+
7280
while next_step != self.router.END_ROUTE:
7381
# Either a dictionary key or valid index
7482
operator = self.ops[next_step]
7583
if next_step == self.router.START_ROUTE:
7684
output_future = self._scheduler_group.submit(
77-
*args, operator=operator, **kwargs
85+
*args,
86+
inference_state=inference_state,
87+
operator=operator,
88+
pipeline_state=pipeline_state,
89+
**kwargs,
7890
)
7991
else:
8092
if isinstance(operator_output, dict):
8193
output_future = self._scheduler_group.submit(
82-
operator=operator, **operator_output
94+
inference_state=inference_state,
95+
operator=operator,
96+
pipeline_state=pipeline_state,
97+
**operator_output,
8398
)
8499
else:
85100
output_future = self._scheduler_group.submit(
86-
operator_output, operator=operator
101+
operator_output,
102+
inference_state=inference_state,
103+
pipeline_state=pipeline_state,
104+
operator=operator,
87105
)
88106

89-
# wait for future to resolve
90107
operator_output = output_future.result()
91-
next_step = self.router.next(next_step, self.ops)
108+
if isinstance(operator_output, tuple):
109+
state_update = operator_output[-1]
110+
operator_output = operator_output[0]
111+
inference_state.update_state(state_update)
112+
113+
next_step = self.router.next(next_step, self.ops, operator_output)
114+
92115
return operator_output
93116

94117
def __call__(self, *args, **kwargs):
95118
"""
119+
Consolidate any provided inference_state or pipeline_state objects and pass
120+
any other operator inputs to run().
121+
96122
:return: output of the pipeline operators ran with the router for the given
97-
input
123+
input
98124
"""
125+
if kwargs.get("inference_state"):
126+
inference_state = kwargs.pop("inference_state")
127+
else:
128+
inference_state = InferenceState()
129+
inference_state.create_state({})
130+
131+
if "pipeline_state" in kwargs:
132+
self.pipeline_state = kwargs.get("pipeline_state")
133+
134+
kwargs["inference_state"] = inference_state
135+
kwargs["pipeline_state"] = self.pipeline_state
136+
99137
return self.run(*args, **kwargs)
100138

101139
def validate(self):

src/deepsparse/v2/routers/router.py

+51-6
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515

1616
import logging
1717
from abc import abstractmethod
18-
from typing import Dict, List, Union
18+
from typing import Any, Dict, List, Optional, Union
1919

2020
from deepsparse.v2.operators import Operator
2121

2222

2323
_LOGGER = logging.getLogger(__name__)
2424

25-
__all__ = ["Router", "LinearRouter"]
25+
__all__ = ["Router", "LinearRouter", "GraphRouter"]
2626

2727

2828
class Router:
@@ -32,23 +32,34 @@ class Router:
3232
3333
:param start_route: the start index or key of the router
3434
:param end_route: the end index or key of the router
35+
:param route: the route that the router has to traverse through
3536
3637
"""
3738

38-
def __init__(self, end_route: Union[str, int], start_route: Union[str, int]):
39+
def __init__(
40+
self,
41+
end_route: Union[str, int],
42+
start_route: Union[str, int],
43+
route: Optional[Dict] = None,
44+
):
3945
self.START_ROUTE = start_route
4046
self.END_ROUTE = end_route
47+
self.route = route
4148

4249
@abstractmethod
4350
def next(
44-
self, past: Union[str, int], ops: Union[List[Operator], Dict[str, Operator]]
51+
self,
52+
past: Union[str, int],
53+
ops: Optional[Union[List[Operator], Dict[str, Operator]]],
54+
inp: Optional[Any],
4555
) -> Union[str, int]:
4656
"""
4757
Determines the index or dictionary key for the next operator which should run.
4858
4959
:param past: the previous index or key. This should uniquely determine the next
50-
operator to run
60+
operator to run
5161
:param ops: list or dictionary of operators
62+
:param inp: operator input
5263
:returns: the next index or dictionary key for the next operator to run
5364
"""
5465
raise NotImplementedError
@@ -69,7 +80,9 @@ class LinearRouter(Router):
6980
def __init__(self, end_route: int, start_route: int = 0):
7081
super().__init__(end_route=end_route, start_route=start_route)
7182

72-
def next(self, past: int, ops: List[Operator]) -> int:
83+
def next(
84+
self, past: int, ops: Optional[List[Operator]] = None, inp: Optional[Any] = None
85+
) -> int:
7386
new_index = past + 1
7487
if new_index < self.END_ROUTE:
7588
return new_index
@@ -105,3 +118,35 @@ def validate(operators: List[Operator]) -> bool:
105118
)
106119
return False
107120
return True
121+
122+
123+
class GraphRouter(Router):
124+
"""
125+
Router for a DAG. Expects graphs to be presented in the form of a dictionary, where
126+
keys are the nodes of the graph and the values are the connected nodes. For
127+
nodes with multiple output edges, all the nodes will be visited and the first node
128+
where `can_operate` returns True will run. Paths should be deterministic.
129+
"""
130+
131+
def __init__(self, end_route: str, start_route: str, route: Dict):
132+
super().__init__(end_route=end_route, start_route=start_route, route=route)
133+
134+
def next(
135+
self,
136+
past: str,
137+
ops: Dict[str, Operator],
138+
inp: Any,
139+
) -> int:
140+
node = past
141+
if isinstance(self.route[node], str):
142+
return self.route[node]
143+
else:
144+
for neighbour_node in self.route[node]:
145+
neighbour_node_op = ops[neighbour_node]
146+
if neighbour_node_op.can_operate(inp):
147+
return neighbour_node
148+
raise ValueError("Cannot operate on any of the nodes")
149+
150+
@staticmethod
151+
def validate(ops) -> bool:
152+
pass

0 commit comments

Comments
 (0)