18
18
from deepsparse .v2 .operators import Operator
19
19
from deepsparse .v2 .routers import Router
20
20
from deepsparse .v2 .schedulers import OperatorScheduler , SchedulerGroup
21
- from deepsparse .v2 .utils import Context
21
+ from deepsparse .v2 .utils import InferenceState , PipelineState
22
22
23
23
24
24
__all__ = ["Pipeline" ]
@@ -40,71 +40,76 @@ class Pipeline(Operator):
40
40
:param schedulers: A list of schedulers to run operators.
41
41
42
42
"""
43
-
43
+
44
44
def __init__ (
45
45
self ,
46
46
ops : Union [Dict [str , Operator ], List [Operator ]],
47
47
router : Router ,
48
48
schedulers : List [OperatorScheduler ],
49
+ pipeline_state : PipelineState = None ,
49
50
):
50
51
51
52
self .ops = ops
52
53
self .router = router
53
54
self .schedulers = schedulers
55
+ self .pipeline_state = pipeline_state
54
56
self .validate ()
55
57
56
58
# SchedulerGroup handles running all schedulers in order of priority
57
59
self ._scheduler_group = SchedulerGroup (self .schedulers )
58
60
59
- def run (self , * args , ** kwargs ):
61
+ def run (
62
+ self ,
63
+ * args ,
64
+ inference_state : InferenceState ,
65
+ pipeline_state : PipelineState ,
66
+ ** kwargs ,
67
+ ):
60
68
"""
61
- Run through the operators using the provided router and scheduler. Update the
62
- context to reflect each step of the router. The input to a given operator is the
63
- output of the previous operator.
69
+ Run through the operators using the provided router and scheduler.
70
+ The input to a given operator is the output of the previous operator.
64
71
65
72
:param inp: input to the operator. expected to be of any type that is
66
73
expected by the operator.
67
74
68
75
"""
69
76
next_step = self .router .START_ROUTE
70
77
operator_output = None
78
+
71
79
while next_step != self .router .END_ROUTE :
72
80
# Either a dictionary key or valid index
73
81
operator = self .ops [next_step ]
74
82
if next_step == self .router .START_ROUTE :
75
83
output_future = self ._scheduler_group .submit (
76
- * args , operator = operator , ** kwargs
84
+ * args ,
85
+ inference_state = inference_state ,
86
+ operator = operator ,
87
+ pipeline_state = pipeline_state ,
88
+ ** kwargs ,
77
89
)
78
90
else :
79
91
if isinstance (operator_output , dict ):
80
92
output_future = self ._scheduler_group .submit (
81
- operator = operator , ** operator_output
93
+ inference_state = inference_state ,
94
+ operator = operator ,
95
+ pipeline_state = pipeline_state ,
96
+ ** operator_output ,
82
97
)
83
98
else :
84
99
output_future = self ._scheduler_group .submit (
85
- operator_output , operator = operator
100
+ operator_output ,
101
+ inference_state = inference_state ,
102
+ pipeline_state = pipeline_state ,
103
+ operator = operator ,
86
104
)
87
-
88
- # print("Current State", inference_state.current_state)
89
-
90
- """
91
- output_future = self._scheduler_group.submit(
92
- operator=operator,
93
- operator_input=inp,
94
- context=context,
95
- pipeline_state=self.pipeline_state,
96
- inference_state=inference_state,
97
- )
98
- """
99
-
100
- # wait for future to resolve
101
- operator_output , state_update = output_future .result ()
102
- inference_state .update_state (state_update )
103
-
104
- next_step = self .router .next (
105
- next_step , self .ops , context , operator_output , inference_state
106
- )
107
- inp = operator_output
105
+
106
+ operator_output = output_future .result ()
107
+ if isinstance (operator_output , tuple ):
108
+ state_update = operator_output [- 1 ]
109
+ operator_output = operator_output [0 ]
110
+ inference_state .update_state (state_update )
111
+
112
+ next_step = self .router .next (next_step , self .ops , operator_output )
108
113
109
114
return operator_output
110
115
@@ -113,6 +118,18 @@ def __call__(self, *args, **kwargs):
113
118
:return: output of the pipeline operators ran with the router for the given
114
119
input
115
120
"""
121
+ if kwargs .get ("inference_state" ):
122
+ inference_state = kwargs .pop ("inference_state" )
123
+ else :
124
+ inference_state = InferenceState ()
125
+ inference_state .create_state ({})
126
+
127
+ if "pipeline_state" in kwargs :
128
+ self .pipeline_state = kwargs .get ("pipeline_state" )
129
+
130
+ kwargs ["inference_state" ] = inference_state
131
+ kwargs ["pipeline_state" ] = self .pipeline_state
132
+
116
133
return self .run (* args , ** kwargs )
117
134
118
135
def validate (self ):
0 commit comments