Commit 625a1c3

remove context, clean-up args, remove prefill_preprocess_operaator
1 parent 6007a75 commit 625a1c3

19 files changed: +249 −302 lines

src/deepsparse/v2/operators/engine_operator.py

+5 −5
@@ -51,7 +51,7 @@ def __init__(
         num_streams: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
-        engine_context: Optional[Context] = None,
+        engine_context: Optional[EngineContext] = None,
         engine_kwargs: Dict = None,
     ):
         self.model_path = model_to_path(model_path)
@@ -76,7 +76,7 @@ def __init__(
         engine_args["scheduler"] = scheduler
         engine_args["num_streams"] = num_streams

-        engine_args.updated(engine_kwargs)
+        engine_args.update(engine_kwargs)
         self.engine = self._create_engine(self.model_path, engine_type, engine_args)

     def _create_engine(
@@ -95,12 +95,12 @@ def _create_engine(

         if engine_type == DEEPSPARSE_ENGINE:
             if self.engine_context is not None and isinstance(
-                self.engine_context, Context
+                self.engine_context, EngineContext
             ):
                 engine_args.pop("num_cores", None)
                 engine_args.pop("scheduler", None)
                 engine_args.pop("num_streams", None)
-                engine_args["context"] = self.engien_context
+                engine_args["context"] = self.engine_context
                 return MultiModelEngine(
                     model=onnx_file_path,
                     **engine_args,
@@ -116,7 +116,7 @@ def _create_engine(
            f"{SUPPORTED_PIPELINE_ENGINES}"
         )

-    def run(self, inp: EngineOperatorInputs) -> Dict:
+    def run(self, inp: EngineOperatorInputs, **kwargs) -> Dict:
         inp = inp.engine_inputs
         batches, orig_batch_size = self.expand_inputs(engine_inputs=inp)
         batches_outputs = list(map(self.engine, batches))
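For orientation, here is a rough sketch of how the renamed `engine_context` argument could be used to share one engine context across several `EngineOperator` instances. The `Context` import alias, its constructor arguments, and the model stubs are assumptions for illustration only; they are not shown in this diff.

```python
# Sketch only: the Context alias and constructor arguments below are assumed.
from deepsparse import Context  # what the operator refers to as EngineContext
from deepsparse.v2.operators.engine_operator import EngineOperator

# A single shared context lets multiple operators attach to the same
# MultiModelEngine pool instead of each spinning up its own engine.
shared_context = Context(num_cores=4, num_streams=2)  # assumed signature

encoder = EngineOperator(
    model_path="zoo:hypothetical/encoder-model",  # placeholder model stub
    engine_context=shared_context,
)
decoder = EngineOperator(
    model_path="zoo:hypothetical/decoder-model",  # placeholder model stub
    engine_context=shared_context,
)
# run() now also tolerates extra keyword arguments (e.g. inference_state),
# which it simply ignores.
```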

src/deepsparse/v2/operators/operator.py

+18 −8
@@ -17,6 +17,8 @@

 from pydantic import BaseModel

+from deepsparse.v2.utils import InferenceState, PipelineState
+

 __all__ = ["Operator"]

@@ -54,6 +56,8 @@ def has_output_schema(cls) -> bool:
     def __call__(
         self,
         *args,
+        inference_state: InferenceState,
+        pipeline_state: PipelineState,
         **kwargs,
     ) -> Any:
         """
@@ -81,10 +85,18 @@ def __call__(
                    "in the form of a dictionary or an instance of the input_schema"
                    "object"
                )
-
-            run_output = self.run(inference_input)
+            run_output = self.run(
+                inference_input,
+                inference_state=inference_state,
+                pipeline_state=pipeline_state,
+            )
         else:
-            run_output = self.run(*args, **kwargs)
+            run_output = self.run(
+                *args,
+                inference_state=inference_state,
+                pipeline_state=pipeline_state,
+                **kwargs,
+            )

         if self.has_output_schema():
             return self.output_schema(**run_output)
@@ -99,13 +111,11 @@ def run(self, *args, **kwargs) -> Any:
         """
         raise NotImplementedError

-    def can_operate(
-        self, inp: Any, context: Context, inference_state: InferenceState
-    ) -> bool:
+    def can_operate(self, inp: Any) -> bool:
         """
-        Whether or not the given operator can run, based on input, context, or state
+        Whether or not the given operator can run, based on input
         """
-        raise NotImplementedError
+        return True

     def expand_inputs(self, **kwargs):
         """

src/deepsparse/v2/pipeline.py

+47 −30
@@ -18,7 +18,7 @@
 from deepsparse.v2.operators import Operator
 from deepsparse.v2.routers import Router
 from deepsparse.v2.schedulers import OperatorScheduler, SchedulerGroup
-from deepsparse.v2.utils import Context
+from deepsparse.v2.utils import InferenceState, PipelineState


 __all__ = ["Pipeline"]
@@ -40,71 +40,76 @@ class Pipeline(Operator):
     :param schedulers: A list of schedulers to run operators.

     """
-
+
     def __init__(
         self,
         ops: Union[Dict[str, Operator], List[Operator]],
         router: Router,
         schedulers: List[OperatorScheduler],
+        pipeline_state: PipelineState = None,
     ):

         self.ops = ops
         self.router = router
         self.schedulers = schedulers
+        self.pipeline_state = pipeline_state
         self.validate()

         # SchedulerGroup handles running all schedulers in order of priority
         self._scheduler_group = SchedulerGroup(self.schedulers)

-    def run(self, *args, **kwargs):
+    def run(
+        self,
+        *args,
+        inference_state: InferenceState,
+        pipeline_state: PipelineState,
+        **kwargs,
+    ):
         """
-        Run through the operators using the provided router and scheduler. Update the
-        context to reflect each step of the router. The input to a given operator is the
-        output of the previous operator.
+        Run through the operators using the provided router and scheduler.
+        The input to a given operator is the output of the previous operator.

         :param inp: input to the operator. expected to be of any type that is
         expected by the operator.

         """
         next_step = self.router.START_ROUTE
         operator_output = None
+
         while next_step != self.router.END_ROUTE:
             # Either a dictionary key or valid index
             operator = self.ops[next_step]
             if next_step == self.router.START_ROUTE:
                 output_future = self._scheduler_group.submit(
-                    *args, operator=operator, **kwargs
+                    *args,
+                    inference_state=inference_state,
+                    operator=operator,
+                    pipeline_state=pipeline_state,
+                    **kwargs,
                 )
             else:
                 if isinstance(operator_output, dict):
                     output_future = self._scheduler_group.submit(
-                        operator=operator, **operator_output
+                        inference_state=inference_state,
+                        operator=operator,
+                        pipeline_state=pipeline_state,
+                        **operator_output,
                     )
                 else:
                     output_future = self._scheduler_group.submit(
-                        operator_output, operator=operator
+                        operator_output,
+                        inference_state=inference_state,
+                        pipeline_state=pipeline_state,
+                        operator=operator,
                     )
-
-            # print("Current State", inference_state.current_state)
-
-            """
-            output_future = self._scheduler_group.submit(
-                operator=operator,
-                operator_input=inp,
-                context=context,
-                pipeline_state=self.pipeline_state,
-                inference_state=inference_state,
-            )
-            """
-
-            # wait for future to resolve
-            operator_output, state_update = output_future.result()
-            inference_state.update_state(state_update)
-
-            next_step = self.router.next(
-                next_step, self.ops, context, operator_output, inference_state
-            )
-            inp = operator_output
+
+            operator_output = output_future.result()
+            if isinstance(operator_output, tuple):
+                state_update = operator_output[-1]
+                operator_output = operator_output[0]
+                inference_state.update_state(state_update)
+
+            next_step = self.router.next(next_step, self.ops, operator_output)

         return operator_output

@@ -113,6 +118,18 @@ def __call__(self, *args, **kwargs):
         :return: output of the pipeline operators ran with the router for the given
         input
         """
+        if kwargs.get("inference_state"):
+            inference_state = kwargs.pop("inference_state")
+        else:
+            inference_state = InferenceState()
+            inference_state.create_state({})
+
+        if "pipeline_state" in kwargs:
+            self.pipeline_state = kwargs.get("pipeline_state")
+
+        kwargs["inference_state"] = inference_state
+        kwargs["pipeline_state"] = self.pipeline_state
+
         return self.run(*args, **kwargs)

     def validate(self):
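A minimal sketch of how the reworked `__call__` behaves from a caller's perspective, assuming an operator with no pydantic schemas is allowed; the `UpperCase` operator and the seeded state keys are invented for illustration:

```python
from deepsparse.v2.operators import Operator
from deepsparse.v2.pipeline import Pipeline
from deepsparse.v2.routers.router import LinearRouter
from deepsparse.v2.schedulers import OperatorScheduler
from deepsparse.v2.utils import InferenceState


class UpperCase(Operator):
    # Hypothetical operator with no input/output schemas attached.
    def run(self, inp, inference_state, pipeline_state, **kwargs):
        return inp.upper()


pipeline = Pipeline(
    ops=[UpperCase()],
    router=LinearRouter(end_route=1),
    schedulers=[OperatorScheduler()],
)

# Simplest call: __call__ creates an empty InferenceState internally.
print(pipeline("hello"))  # -> "HELLO"

# Callers that need per-request state seeded up front can pass their own.
state = InferenceState()
state.create_state({"request_id": 123})
print(pipeline("hello", inference_state=state))
```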

src/deepsparse/v2/routers/router.py

+17 −10
@@ -15,15 +15,14 @@

 import logging
 from abc import abstractmethod
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

 from deepsparse.v2.operators import Operator
-from deepsparse.v2.utils import Context, InferenceState


 _LOGGER = logging.getLogger(__name__)

-__all__ = ["Router", "LinearRouter"]
+__all__ = ["Router", "LinearRouter", "GraphRouter"]


 class Router:
@@ -36,14 +35,22 @@ class Router:

     """

-    def __init__(self, end_route: Union[str, int], start_route: Union[str, int]):
+    def __init__(
+        self,
+        end_route: Union[str, int],
+        start_route: Union[str, int],
+        route: Optional[Dict] = None,
+    ):
         self.START_ROUTE = start_route
         self.END_ROUTE = end_route
         self.route = route

     @abstractmethod
     def next(
-        self, past: Union[str, int], ops: Union[List[Operator], Dict[str, Operator]]
+        self,
+        past: Union[str, int],
+        ops: Optional[Union[List[Operator], Dict[str, Operator]]],
+        inp: Optional[Any],
     ) -> Union[str, int]:
         """
         Determines the index or dictionary key for the next operator which should run.
@@ -73,7 +80,9 @@ class LinearRouter(Router):
     def __init__(self, end_route: int, start_route: int = 0):
         super().__init__(end_route=end_route, start_route=start_route)

-    def next(self, past: int, ops: List[Operator]) -> int:
+    def next(
+        self, past: int, ops: Optional[List[Operator]] = None, inp: Optional[Any] = None
+    ) -> int:
         new_index = past + 1
         if new_index < self.END_ROUTE:
             return new_index
@@ -111,7 +120,7 @@ def validate(operators: List[Operator]) -> bool:
         return True


-class TextGenerationRouter(Router):
+class GraphRouter(Router):
     """
     Router for a DAG. Expects graphs be presented in the form of a dictionary, where
     keys are the nodes of the graph and the values are the connected nodes. For
@@ -126,17 +135,15 @@ def next(
         self,
         past: str,
         ops: Dict[str, Operator],
-        context: Context,
         inp: Any,
-        inference_state: InferenceState,
     ) -> int:
         node = past
         if isinstance(self.route[node], str):
             return self.route[node]
         else:
             for neighbour_node in self.route[node]:
                 neighbour_node_op = ops[neighbour_node]
-                if neighbour_node_op.can_operate(inp, context, inference_state):
+                if neighbour_node_op.can_operate(inp):
                     return neighbour_node
             raise ValueError("Cannot operate on any of the nodes")
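To make the renamed `GraphRouter` contract concrete, a hypothetical route dictionary is sketched below. The node names are invented; the only assumptions beyond this diff are that `GraphRouter` reuses the base `Router.__init__` shown above and that the terminal node maps to the value passed as `end_route`.

```python
from deepsparse.v2.routers.router import GraphRouter

# Each key is a node; a string value is an unconditional edge, while a list
# holds candidate nodes and the first one whose operator reports
# can_operate(inp) == True is chosen at runtime.
route = {
    "process_input": "engine",
    "engine": ["prep_for_generation", "generate"],  # branch via can_operate()
    "prep_for_generation": "generate",
    "generate": "process_output",
    "process_output": "STOP",
}

router = GraphRouter(start_route="process_input", end_route="STOP", route=route)
```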

src/deepsparse/v2/schedulers/scheduler.py

+17 −7
@@ -14,7 +14,6 @@


 from concurrent.futures import Future, ThreadPoolExecutor
-from typing import Any

 from deepsparse.v2.operators import Operator

@@ -37,19 +36,30 @@ class OperatorScheduler:
     def __init__(self, max_workers: int = 1):
         self._threadpool = ThreadPoolExecutor(max_workers=max_workers)

-    def submit(self, *args, operator: Operator, **kwargs) -> Future:
+    def submit(
+        self,
+        *args,
+        operator: Operator,
+        **kwargs,
+    ) -> Future:
         """
         :param operator: operator to run
-        :param operator_input: input schema to the operator
-        :param context: context of already run operators
         :return: future referencing the asynchronously run output of the operator
         """
-        return self._threadpool.submit(operator, *args, **kwargs)
+        return self._threadpool.submit(
+            operator,
+            *args,
+            **kwargs,
+        )

-    def can_process(self, *args, operator: Operator, **kwargs) -> bool:
+    def can_process(
+        self,
+        *args,
+        operator: Operator,
+        **kwargs,
+    ) -> bool:
         """
         :param operator: operator to check
-        :param operator_input: operator_input to check
         :return: True if this Operator can process the given operator and input.
             Base OperatorScheduler always returns True
         """
0 commit comments