 import copy
 from concurrent.futures import Future
-from functools import partial
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Dict, List, Union

-from deepsparse.v2.operators import Operator
+from deepsparse.v2.operators import EngineOperator, Operator
 from deepsparse.v2.routers import Router
-from deepsparse.v2.schedulers import OperatorScheduler, SchedulerGroup
+from deepsparse.v2.schedulers import (
+    ContinuousBatchingScheduler,
+    OperatorScheduler,
+    SchedulerGroup,
+)
 from deepsparse.v2.utils import InferenceState, PipelineState
+from deepsparse.v2.utils.data import SubGraph
+from deepsparse.v2.utils.helpers import run_func


 __all__ = ["Pipeline"]
@@ -50,39 +55,100 @@ def __init__(
         ops: Union[Dict[str, Operator], List[Operator]],
         router: Router,
         schedulers: List[OperatorScheduler],
+        continuous_batching_scheduler: ContinuousBatchingScheduler,
         pipeline_state: PipelineState = None,
     ):

         self.ops = ops
         self.router = router
         self.schedulers = schedulers
         self.pipeline_state = pipeline_state
+        self._continuous_batching_scheduler = continuous_batching_scheduler
         self.validate()

         self._scheduler_group = SchedulerGroup(self.schedulers)

-    def _run_sequential(
+    def _run_next(
         self,
         inp: Any,
         inference_state: InferenceState,
-        pipeline_state: PipelineState,
-        start: str,
-        end: str,
+        next_step: str,
     ):
-        next_step = start
-        while next_step != end:
-            outputs = self._run_next_step(
-                func=self.ops[next_step],
-                next_step=next_step,
-                input=inp,
-                pipeline_state=pipeline_state,
-                inference_state=inference_state,
+        if (
+            isinstance(self.ops[next_step], EngineOperator)
+            and self._continuous_batching_scheduler
+        ):
+            func = self._continuous_batching_scheduler.submit
+            inp = self.ops[next_step].input_schema(**inp)
+        else:
+            func = self._scheduler_group.submit
+
+        return run_func(
+            func=func,
+            operator=self.ops[next_step],
+            inp=inp,
+            pipeline_state=self.pipeline_state,
+            inference_state=inference_state,
+        )
+
+    def _run_sub_graphs(
+        self, sub_graph_inputs: List[Any], sub_graphs: List[SubGraph]
+    ) -> List[Any]:
+        """
+        Run a list of sub_graphs asynchronously. Polls to identify the sub graph
+        that is still running but has completed its current step. Schedules the
+        next subgraph step. This is repeated until all subgraphs have finished
+        running and have reached their end step (stored in the SubGraph.end
+        attribute).
+
+        :param sub_graph_inputs: A list of inputs that should be passed to each
+            subgraph. Each subgraph is given an element of the list as input to
+            its first node.
+        :param sub_graphs: A list of SubGraph objects. Each stores the relevant
+            execution information for the particular subgraph, such as its
+            current step in the sub graph, inference state, output, and end step.
+
+        :returns: a list of outputs for all the completed SubGraph objects.
+            Returned in the same order that the subgraphs were passed to the
+            function.
+        """
+        for i in range(len(sub_graphs)):
+            sub_graphs[i].output = self._run_next(
+                sub_graph_inputs[i], sub_graphs[i].inf, sub_graphs[i].step
             )
-            next_step, operator_output, state_update = outputs
-            if state_update:
-                inference_state.update_state(state_update)
-            inp = operator_output
-        return inp
+
+        # Execute all sub graphs until all graphs have been completed.
+        while True:
+            for sub_graph in sub_graphs:
+                if isinstance(sub_graph.output, Future) and sub_graph.output.done():
+                    # get the result for the completed operator; resolve its output
+                    operator_output = sub_graph.output.result()
+                    operator_output = sub_graph.parse_output(operator_output)
+
+                    # determine the next step for the particular operator, using
+                    # its previous output and previously stored step
+                    next_step = self.router.next(
+                        sub_graph.step, self.ops, operator_output
+                    )
+                    # update the step
+                    sub_graph.step = next_step
+
+                    # store the output for the next step. If the next step is
+                    # the end step, this particular route has completed. Simply
+                    # update the output value.
+                    if next_step in sub_graph.end:
+                        sub_graph.output = operator_output
+                    else:
+                        sub_graph.output = self._run_next(
+                            inp=operator_output,
+                            inference_state=sub_graph.inf,
+                            next_step=next_step,
+                        )
+                    break
+
+            # keep running until all sub graphs have completed.
+            if not any(isinstance(x.output, Future) for x in sub_graphs):
+                break
+
+        return [x.output for x in sub_graphs]

     def _apply_split(self, inp: Any, inference_state: InferenceState):
         """
@@ -93,59 +159,29 @@ def _apply_split(self, inp: Any, inference_state: InferenceState):
         """

         batches, orig_batch_size = self.expand_inputs(inp, 1)
-        run_with_state = partial(
-            self._run_sequential,
-            pipeline_state=self.pipeline_state,
-            start=self.router.route[self.router.SPLIT_ROUTE],
-            end=self.router.JOIN_ROUTE,
-        )
-        inference_state_list = [
-            copy.deepcopy(inference_state) for x in range(len(batches))
-        ]
-        futures = self._scheduler_group.map(
-            batches,
-            inference_state_list,
-            func=run_with_state,
-        )
-        return self.condense_inputs([x.result() for x in futures])

-    def _run_next_step(
-        self,
-        *args,
-        func: Callable,
-        next_step: Union[str, int],
-        input: Any = None,
-        **kwargs,
-    ):
-        """
-        Generic function to run a given func, process the output and determine the next
-        step.
-        """
-        if input:
-            operator_output = (
-                func(*args, **kwargs, **input)
-                if isinstance(input, dict)
-                else func(input, *args, **kwargs)
+        # Create a list of SubGraphs, one per batch of size 1.
+        # Each SubGraph object holds information about the particular path it
+        # follows. All start at the same step, defined by SPLIT_ROUTE, and start
+        # with the same inference_state.
+        split_graphs = [
+            SubGraph(
+                inf=copy.deepcopy(inference_state),
+                step=self.router.route[self.router.SPLIT_ROUTE],
+                end=[self.router.JOIN_ROUTE],
             )
-        else:
-            operator_output = func(*args, **kwargs)
-
-        if isinstance(operator_output, Future):
-            operator_output = operator_output.result()
-
-        state_update = None
-        if isinstance(operator_output, tuple):
-            state_update = operator_output[-1]
-            operator_output = operator_output[0]
+            for i in range(len(batches))
+        ]

-        next_step = self.router.next(next_step, self.ops, operator_output)
-        return next_step, operator_output, state_update
+        outputs = self._run_sub_graphs(
+            sub_graph_inputs=batches, sub_graphs=split_graphs
+        )
+        return self.condense_inputs(outputs)

     def run(
         self,
         *args,
         inference_state: InferenceState,
-        pipeline_state: PipelineState,
         **kwargs,
     ):
         """
@@ -158,36 +194,56 @@ def run(
         """
         next_step = self.router.START_ROUTE
         operator_output = None
-
         while next_step != self.router.END_ROUTE:
+
+            # Split graph execution (i.e. multiple subgraphs)
             # NOTE: split_route should only appear after the start route node
             if next_step == self.router.SPLIT_ROUTE:
+                if operator_output is None:
+                    raise ValueError(
+                        f"{self.router.SPLIT_ROUTE} should appear after "
+                        f"{self.router.START_ROUTE}"
+                    )
+
                 operator_output = self._apply_split(operator_output, inference_state)
                 next_step = self.router.route[self.router.JOIN_ROUTE]
+                if next_step == self.router.END_ROUTE:
+                    return operator_output

             if next_step == self.router.START_ROUTE:
-                outputs = self._run_next_step(
+                operator_output = run_func(
                     *args,
-                    next_step=next_step,
                     func=self._scheduler_group.submit,
-                    inference_state=inference_state,
                     operator=self.ops[next_step],
-                    pipeline_state=pipeline_state,
+                    inference_state=inference_state,
+                    pipeline_state=self.pipeline_state,
                     **kwargs,
-                )
+                ).result()
+
+                if isinstance(operator_output, tuple):
+                    operator_output, state_update = (
+                        operator_output[0],
+                        operator_output[-1],
+                    )
+                    inference_state.update_state(state_update)
+
+                next_step = self.router.next(next_step, self.ops, operator_output)
+
             else:
-                outputs = self._run_next_step(
-                    func=self._scheduler_group.submit,
-                    input=operator_output,
-                    next_step=next_step,
-                    inference_state=inference_state,
-                    operator=self.ops[next_step],
-                    pipeline_state=pipeline_state,
+                # Single graph execution
+                graph = SubGraph(
+                    inf=copy.deepcopy(inference_state),
+                    step=next_step,
+                    end=[self.router.SPLIT_ROUTE, self.router.END_ROUTE],
                 )

-            next_step, operator_output, state_update = outputs
-            if state_update:
-                inference_state.update_state(state_update)
+                operator_output = self._run_sub_graphs(
+                    sub_graph_inputs=[operator_output], sub_graphs=[graph]
+                )[0]
+
+                inference_state = graph.inf
+                next_step = graph.step
+
         return operator_output

     def __call__(self, *args, **kwargs):
@@ -204,11 +260,7 @@ def __call__(self, *args, **kwargs):
         inference_state = InferenceState()
         inference_state.create_state({})

-        if "pipeline_state" in kwargs:
-            self.pipeline_state = kwargs.get("pipeline_state")
-
         kwargs["inference_state"] = inference_state
-        kwargs["pipeline_state"] = self.pipeline_state

         return self.run(*args, **kwargs)
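
Taken together, the new surface means a pipeline is constructed with an explicit continuous batching scheduler and invoked without a `pipeline_state` kwarg. An illustrative sketch of the wiring; the operators, router, and input below are placeholders, not a working model setup:

```python
# Hypothetical wiring of the revised constructor; everything except the
# scheduler imports is a placeholder the caller would provide.
from deepsparse.v2.schedulers import ContinuousBatchingScheduler, OperatorScheduler

pipeline = Pipeline(
    ops={"process_input": preprocess_op, "engine": engine_op},  # placeholder ops
    router=my_router,  # any Router implementation
    schedulers=[OperatorScheduler()],
    continuous_batching_scheduler=ContinuousBatchingScheduler(),
)

# __call__ creates a fresh InferenceState internally; pipeline_state is now
# fixed at construction time rather than passed per call
output = pipeline(model_input)
```
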