18
18
from deepsparse .v2 .operators import Operator
19
19
from deepsparse .v2 .routers import Router
20
20
from deepsparse .v2 .schedulers import OperatorScheduler , SchedulerGroup
21
+ from deepsparse .v2 .utils import InferenceState , PipelineState
21
22
22
23
23
24
__all__ = ["Pipeline" ]
@@ -27,7 +28,7 @@ class Pipeline(Operator):
27
28
"""
28
29
Pipeline accepts a series of operators, schedulers, and a router. Calling a pipeline
29
30
will use the router to run through all the defined operators. The operators should
30
- be implemented using the Operator class and each implemented Operator should be
31
+ be implemented using the Operator class and each implemented operator should be
31
32
responsible for a functional component of the pipelines. The flow of inputs/outputs
32
33
between the operators and the steps in the pipeline should be defined by the router,
33
34
(based off of the Router class), which dictates the next operator in the pipeline.
@@ -37,6 +38,7 @@ class Pipeline(Operator):
37
38
or dictionary of operators.
38
39
:param router: A Router which dictates the next operator to call.
39
40
:param schedulers: A list of schedulers to run operators.
41
+ :param pipeline_state: pipeline_state created during pipeline initialization
40
42
41
43
"""
42
44
@@ -45,57 +47,93 @@ def __init__(
45
47
ops : Union [Dict [str , Operator ], List [Operator ]],
46
48
router : Router ,
47
49
schedulers : List [OperatorScheduler ],
50
+ pipeline_state : PipelineState = None ,
48
51
):
49
52
50
53
self .ops = ops
51
54
self .router = router
52
55
self .schedulers = schedulers
56
+ self .pipeline_state = pipeline_state
53
57
self .validate ()
54
58
55
59
# SchedulerGroup handles running all schedulers in order of priority
56
60
self ._scheduler_group = SchedulerGroup (self .schedulers )
57
61
58
def run(
    self,
    *args,
    inference_state: InferenceState,
    pipeline_state: PipelineState,
    **kwargs,
):
    """
    Execute the operators one after another, following the route chosen by the
    router, and return the output of the last operator that ran. Each operator
    is submitted through the scheduler group and the result of the returned
    future is retrieved before the next route is requested.

    The operator on the start route receives ``*args`` and ``**kwargs``. Any
    other operator receives the previous output: unpacked as keyword arguments
    when it is a dictionary, passed as a single positional argument otherwise.

    When an operator returns a tuple, its first element is used as the operator
    output and its last element is handed to ``inference_state.update_state``.

    :param inference_state: inference_state for the pipeline; forwarded to every
        operator and updated from tuple outputs.
    :param pipeline_state: pipeline_state for the pipeline; forwarded as-is to
        every operator by this method.
    :return: output of the last operator that ran, or None when the router's
        start route equals its end route
    """
    step = self.router.START_ROUTE
    result = None

    while step != self.router.END_ROUTE:
        # ops may be a dictionary (step is a key) or a list (step is an index)
        current_op = self.ops[step]

        # Pick the inputs for this step; only the start route sees the
        # caller-supplied arguments, later steps consume the previous output.
        if step == self.router.START_ROUTE:
            positional, keyword = args, kwargs
        elif isinstance(result, dict):
            positional, keyword = (), result
        else:
            positional, keyword = (result,), {}

        future = self._scheduler_group.submit(
            *positional,
            inference_state=inference_state,
            operator=current_op,
            pipeline_state=pipeline_state,
            **keyword,
        )
        result = future.result()

        # A tuple output carries a state update in its last position
        if isinstance(result, tuple):
            state_update, result = result[-1], result[0]
            inference_state.update_state(state_update)

        step = self.router.next(step, self.ops, result)

    return result
93
116
94
117
def __call__(self, *args, **kwargs):
    """
    Resolve the inference_state and pipeline_state to use for this call, then
    forward them together with all remaining inputs to run().

    A truthy ``inference_state`` keyword argument is used as given; otherwise a
    new InferenceState is created and initialized with an empty state. A
    ``pipeline_state`` keyword argument, when present, replaces the pipeline
    state stored on this pipeline before it is forwarded.

    :return: output of the pipeline operators ran with the router for the given
        input
    """
    state = kwargs.get("inference_state")
    if not state:
        # Nothing usable was supplied; start from a fresh, empty state
        state = InferenceState()
        state.create_state({})

    if "pipeline_state" in kwargs:
        self.pipeline_state = kwargs["pipeline_state"]

    kwargs.update(inference_state=state, pipeline_state=self.pipeline_state)
    return self.run(*args, **kwargs)
100
138
101
139
def validate (self ):
0 commit comments