Commit c858b1f

[Pipeline Refactor][Text Generation][Continuous Batching] Integration (#1409)
* update split/join
* use map
* update
* run end-to-end
* clean-up
* fix bug with batch size, introduce SplitRoute dataclass
* update tests to use new inputs/outputs
* use the normal scheduler for internal kv_cache
* add pipeline inputs
* clean-up
* change engine type, update docstrings, update override function to be more generic
* move subgraph functionality to its own function; clean-up cont batching in text gen pipeline
* update linear pathway to also use subgraph execution
* rebase fix
* fix tests
1 parent 1b9238a commit c858b1f

20 files changed (+486 -171)

src/deepsparse/v2/operators/engine_operator.py (+7 -5)

@@ -20,7 +20,7 @@
 from deepsparse import Context as EngineContext
 from deepsparse import Engine, MultiModelEngine, Scheduler
 from deepsparse.benchmark import ORTEngine
-from deepsparse.utils import model_to_path
+from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs
 from deepsparse.v2.operators import Operator


@@ -29,12 +29,12 @@

 SUPPORTED_PIPELINE_ENGINES = [DEEPSPARSE_ENGINE, ORT_ENGINE]

-__all__ = ["EngineOperator"]
+__all__ = ["EngineOperator", "EngineOperatorInputs", "EngineOperatorOutputs"]


 class EngineOperatorInputs(BaseModel):
     engine_inputs: List = Field(description="engine_inputs")
-    engine: Optional[Engine] = Field(
+    engine: Optional[Union[ORTEngine, Engine]] = Field(
         description="override the engine to run forward pass with",
         default=None,
     )
@@ -95,8 +95,8 @@ def __init__(
         engine_kwargs: Dict = None,
     ):
         self.model_path = model_to_path(model_path)
-        self._batch_size = 1
         self.engine_context = engine_context
+        self._batch_size = 1

         if self.engine_context is not None:
             num_cores = num_cores or self.engine_context.num_cores
@@ -131,6 +131,7 @@ def batch_size(self) -> int:
         """
         return self._batch_size

+    # TODO: maybe add a few args to make this less opaque?
     def create_engine(
         self,
         **kwargs,
@@ -142,7 +143,8 @@ def create_engine(
             constructor/compilation
         :return: inference engine
         """
-        onnx_file_path = self.model_path
+
+        onnx_file_path = kwargs.pop("model_path", self.model_path)
         engine_args = deepcopy(self._engine_args)
         engine_args.update(kwargs)
         engine_type = self._engine_type.lower()
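The widened engine field and the new model_path kwarg make per-call engine overrides possible, which the continuous batching integration below relies on. A minimal sketch of how a caller might use them; the ONNX paths are placeholders and not taken from this diff:

from deepsparse.v2.operators.engine_operator import (
    EngineOperator,
    EngineOperatorInputs,
)

# compiles the default engine at batch size 1, per __init__ above
op = EngineOperator(model_path="model.onnx")  # placeholder path

# create_engine() now accepts a "model_path" override (popped from **kwargs),
# e.g. to compile a second engine from a variant of the same model
override_engine = op.create_engine(model_path="model_kv_cache.onnx")

# the override engine (Engine or ORTEngine, per the widened Union) travels
# with the inputs instead of replacing the operator's default engine;
# engine_inputs would be the list of numpy arrays for the forward pass
inputs = EngineOperatorInputs(engine_inputs=[], engine=override_engine)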

src/deepsparse/v2/operators/operator.py (+1 -4)

@@ -17,7 +17,7 @@

 from pydantic import BaseModel

-from deepsparse.v2.utils import InferenceState, PipelineState
+from deepsparse.v2.utils import InferenceState


 __all__ = ["Operator"]
@@ -57,7 +57,6 @@ def __call__(
         self,
         *args,
         inference_state: InferenceState,
-        pipeline_state: PipelineState,
         **kwargs,
     ) -> Any:
         """
@@ -90,13 +89,11 @@ def __call__(
             run_output = self.run(
                 inference_input,
                 inference_state=inference_state,
-                pipeline_state=pipeline_state,
             )
         else:
             run_output = self.run(
                 *args,
                 inference_state=inference_state,
-                pipeline_state=pipeline_state,
                 **kwargs,
             )
         if self.has_output_schema():
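With pipeline_state dropped from Operator.__call__, an operator now receives only its inputs and the per-request inference_state; pipeline-level state stays on the Pipeline object. A self-contained sketch of the same dispatch shape, where EchoOperator and EchoInput are illustrative stand-ins rather than deepsparse classes:

from pydantic import BaseModel


class EchoInput(BaseModel):
    text: str


class EchoOperator:
    # mirrors Operator: kwargs are validated against input_schema, then
    # run() receives only the validated input and the inference state
    input_schema = EchoInput

    def run(self, inference_input, inference_state):
        return {"text": inference_input.text.upper()}

    def __call__(self, *args, inference_state, **kwargs):
        inference_input = self.input_schema(**kwargs)
        return self.run(inference_input, inference_state=inference_state)


print(EchoOperator()(text="hi", inference_state=None))  # {'text': 'HI'}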

src/deepsparse/v2/pipeline.py (+138 -86)

@@ -15,13 +15,18 @@

 import copy
 from concurrent.futures import Future
-from functools import partial
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Dict, List, Union

-from deepsparse.v2.operators import Operator
+from deepsparse.v2.operators import EngineOperator, Operator
 from deepsparse.v2.routers import Router
-from deepsparse.v2.schedulers import OperatorScheduler, SchedulerGroup
+from deepsparse.v2.schedulers import (
+    ContinuousBatchingScheduler,
+    OperatorScheduler,
+    SchedulerGroup,
+)
 from deepsparse.v2.utils import InferenceState, PipelineState
+from deepsparse.v2.utils.data import SubGraph
+from deepsparse.v2.utils.helpers import run_func


 __all__ = ["Pipeline"]
@@ -50,39 +55,100 @@ def __init__(
         ops: Union[Dict[str, Operator], List[Operator]],
         router: Router,
         schedulers: List[OperatorScheduler],
+        continuous_batching_scheduler: ContinuousBatchingScheduler,
         pipeline_state: PipelineState = None,
     ):

         self.ops = ops
         self.router = router
         self.schedulers = schedulers
         self.pipeline_state = pipeline_state
+        self._continuous_batching_scheduler = continuous_batching_scheduler
         self.validate()

         self._scheduler_group = SchedulerGroup(self.schedulers)

-    def _run_sequential(
+    def _run_next(
         self,
         inp: Any,
         inference_state: InferenceState,
-        pipeline_state: PipelineState,
-        start: str,
-        end: str,
+        next_step: str,
     ):
-        next_step = start
-        while next_step != end:
-            outputs = self._run_next_step(
-                func=self.ops[next_step],
-                next_step=next_step,
-                input=inp,
-                pipeline_state=pipeline_state,
-                inference_state=inference_state,
-            )
-            next_step, operator_output, state_update = outputs
-            if state_update:
-                inference_state.update_state(state_update)
-            inp = operator_output
-        return inp
+        if (
+            isinstance(self.ops[next_step], EngineOperator)
+            and self._continuous_batching_scheduler
+        ):
+            func = self._continuous_batching_scheduler.submit
+            inp = self.ops[next_step].input_schema(**inp)
+        else:
+            func = self._scheduler_group.submit
+
+        return run_func(
+            func=func,
+            operator=self.ops[next_step],
+            inp=inp,
+            pipeline_state=self.pipeline_state,
+            inference_state=inference_state,
+        )
+
+    def _run_sub_graphs(
+        self, sub_graph_inputs: List[Any], sub_graphs: List[SubGraph]
+    ) -> List[Any]:
+        """
+        Run a list of sub_graphs asynchronously. Polls to identify any sub graph
+        that is still running but has completed its current step, and schedules
+        its next step. This is repeated until all subgraphs have finished running
+        and have reached their end step (stored in the SubGraph.end attribute).
+
+        :param sub_graph_inputs: A list of inputs that should be passed to each
+            subgraph. Each subgraph is given an element of the list as input to
+            its first node.
+        :param sub_graphs: A list of SubGraph objects. Each stores the relevant
+            execution information for the particular subgraph, such as its
+            current step in the sub graph, inference state, output, and end step.
+
+        :returns: a list of outputs for all the completed SubGraph objects.
+            Returned in the same order that the subgraphs were passed to the
+            function.
+        """
+        for i in range(len(sub_graphs)):
+            sub_graphs[i].output = self._run_next(
+                sub_graph_inputs[i], sub_graphs[i].inf, sub_graphs[i].step
+            )
+
+        # Execute all sub graphs until all graphs have been completed.
+        while True:
+            for sub_graph in sub_graphs:
+                if isinstance(sub_graph.output, Future) and sub_graph.output.done():
+                    # get the result for the completed operator; resolve its output
+                    operator_output = sub_graph.output.result()
+                    operator_output = sub_graph.parse_output(operator_output)
+
+                    # determine the next step for the particular operator, using
+                    # its previous output and previously stored step
+                    next_step = self.router.next(
+                        sub_graph.step, self.ops, operator_output
+                    )
+                    # update the step
+                    sub_graph.step = next_step
+
+                    # store the output for the next step. If the next step is an
+                    # end step, this particular route has completed; simply
+                    # update the output value
+                    if next_step in sub_graph.end:
+                        sub_graph.output = operator_output
+                    else:
+                        sub_graph.output = self._run_next(
+                            inp=operator_output,
+                            inference_state=sub_graph.inf,
+                            next_step=next_step,
+                        )
+                    break
+
+            # keep running until all sub graphs have completed
+            if not any(isinstance(x.output, Future) for x in sub_graphs):
+                break
+
+        return [x.output for x in sub_graphs]
@@ -93,59 +159,29 @@ def _apply_split(self, inp: Any, inference_state: InferenceState):
         """

         batches, orig_batch_size = self.expand_inputs(inp, 1)
-        run_with_state = partial(
-            self._run_sequential,
-            pipeline_state=self.pipeline_state,
-            start=self.router.route[self.router.SPLIT_ROUTE],
-            end=self.router.JOIN_ROUTE,
-        )
-        inference_state_list = [
-            copy.deepcopy(inference_state) for x in range(len(batches))
-        ]
-        futures = self._scheduler_group.map(
-            batches,
-            inference_state_list,
-            func=run_with_state,
-        )
-        return self.condense_inputs([x.result() for x in futures])
-
-    def _run_next_step(
-        self,
-        *args,
-        func: Callable,
-        next_step: Union[str, int],
-        input: Any = None,
-        **kwargs,
-    ):
-        """
-        Generic function to run a given func, process the output and determine the
-        next step.
-        """
-        if input:
-            operator_output = (
-                func(*args, **kwargs, **input)
-                if isinstance(input, dict)
-                else func(input, *args, **kwargs)
-            )
-        else:
-            operator_output = func(*args, **kwargs)
-
-        if isinstance(operator_output, Future):
-            operator_output = operator_output.result()
-
-        state_update = None
-        if isinstance(operator_output, tuple):
-            state_update = operator_output[-1]
-            operator_output = operator_output[0]
-
-        next_step = self.router.next(next_step, self.ops, operator_output)
-        return next_step, operator_output, state_update
+        # Create a list of SubGraphs, one per batch of size 1.
+        # Each SubGraph object holds information about the particular path it
+        # follows. All start at the same step defined by SPLIT_ROUTE and start
+        # with the same inference_state.
+        split_graphs = [
+            SubGraph(
+                inf=copy.deepcopy(inference_state),
+                step=self.router.route[self.router.SPLIT_ROUTE],
+                end=[self.router.JOIN_ROUTE],
+            )
+            for i in range(len(batches))
+        ]
+
+        outputs = self._run_sub_graphs(
+            sub_graph_inputs=batches, sub_graphs=split_graphs
+        )
+        return self.condense_inputs(outputs)

     def run(
         self,
         *args,
         inference_state: InferenceState,
-        pipeline_state: PipelineState,
         **kwargs,
     ):
         """
@@ -158,36 +194,56 @@
         """
         next_step = self.router.START_ROUTE
         operator_output = None
-
         while next_step != self.router.END_ROUTE:
+
+            # Split graph execution (i.e. multiple subgraphs)
             # NOTE: split_route should only appear after the start route node
             if next_step == self.router.SPLIT_ROUTE:
+                if operator_output is None:
+                    raise ValueError(
+                        f"{self.router.SPLIT_ROUTE} should appear after "
+                        f"{self.router.START_ROUTE}"
+                    )
+
                 operator_output = self._apply_split(operator_output, inference_state)
                 next_step = self.router.route[self.router.JOIN_ROUTE]
+                if next_step == self.router.END_ROUTE:
+                    return operator_output

             if next_step == self.router.START_ROUTE:
-                outputs = self._run_next_step(
+                operator_output = run_func(
                     *args,
-                    next_step=next_step,
                     func=self._scheduler_group.submit,
-                    inference_state=inference_state,
                     operator=self.ops[next_step],
-                    pipeline_state=pipeline_state,
+                    inference_state=inference_state,
+                    pipeline_state=self.pipeline_state,
                     **kwargs,
-                )
+                ).result()
+
+                if isinstance(operator_output, tuple):
+                    operator_output, state_update = (
+                        operator_output[0],
+                        operator_output[-1],
+                    )
+                    inference_state.update_state(state_update)
+
+                next_step = self.router.next(next_step, self.ops, operator_output)
+
             else:
-                outputs = self._run_next_step(
-                    func=self._scheduler_group.submit,
-                    input=operator_output,
-                    next_step=next_step,
-                    inference_state=inference_state,
-                    operator=self.ops[next_step],
-                    pipeline_state=pipeline_state,
+                # Single graph execution
+                graph = SubGraph(
+                    inf=copy.deepcopy(inference_state),
+                    step=next_step,
+                    end=[self.router.SPLIT_ROUTE, self.router.END_ROUTE],
                 )

-            next_step, operator_output, state_update = outputs
-            if state_update:
-                inference_state.update_state(state_update)
+                operator_output = self._run_sub_graphs(
+                    sub_graph_inputs=[operator_output], sub_graphs=[graph]
+                )[0]
+
+                inference_state = graph.inf
+                next_step = graph.step
+
         return operator_output
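The start-route branch above shows the state-update contract: an operator may return an (output, state_update) tuple, and the run loop folds the dict into the shared InferenceState. A minimal runnable illustration of that contract; FakeInferenceState and tokenize are stand-ins, not the deepsparse implementation:

class FakeInferenceState:
    """Stand-in exposing the one method the run loop relies on."""

    def __init__(self):
        self.state = {}

    def update_state(self, update: dict):
        self.state.update(update)


def tokenize(text):
    tokens = text.split()
    # the second element is the state update the pipeline folds in
    return tokens, {"num_tokens": len(tokens)}


state = FakeInferenceState()
operator_output = tokenize("a b c")
if isinstance(operator_output, tuple):
    operator_output, state_update = operator_output[0], operator_output[-1]
    state.update_state(state_update)
print(operator_output, state.state)  # ['a', 'b', 'c'] {'num_tokens': 3}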
@@ -204,11 +260,7 @@ def __call__(self, *args, **kwargs):
         inference_state = InferenceState()
         inference_state.create_state({})

-        if "pipeline_state" in kwargs:
-            self.pipeline_state = kwargs.get("pipeline_state")
-
         kwargs["inference_state"] = inference_state
-        kwargs["pipeline_state"] = self.pipeline_state

         return self.run(*args, **kwargs)

src/deepsparse/v2/routers/router.py (-2)

@@ -83,8 +83,6 @@ class LinearRouter(Router):

     def __init__(self, end_route: int, start_route: int = 0):
         super().__init__(end_route=end_route, start_route=start_route)
-        self.SPLIT_ROUTE = None
-        self.JOIN_ROUTE = None
         _LOGGER.warn("SPLIT and JOIN are not yet supported for the LinearRouter.")

     def next(
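A LinearRouter simply advances an integer step from start_route toward end_route, so the SPLIT/JOIN sentinels inherited from Router are now left at their base-class defaults rather than overwritten with None. A hedged usage sketch; the import path follows the file above, and end_route counting operators is an assumption:

from deepsparse.v2.routers.router import LinearRouter

# visits ops[0], ops[1], ops[2] in order; SPLIT and JOIN are not yet
# supported for the LinearRouter (per the warning above)
router = LinearRouter(end_route=3)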
