Skip to content

Commit 72234a5

Browse files
committed
postprocessing_kwargs -> context
1 parent 9b916f3 commit 72234a5

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

Diff for: src/deepsparse/pipeline.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
229229
# batch size of the inputs may be `> self._batch_size` at this point
230230
engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs)
231231
if isinstance(engine_inputs, tuple):
232-
engine_inputs, postprocess_kwargs = engine_inputs
232+
engine_inputs, context = engine_inputs
233233
else:
234-
postprocess_kwargs = {}
234+
context = {}
235235

236236
timer.stop(InferenceStages.PRE_PROCESS)
237237
self.log(
@@ -248,9 +248,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
248248
)
249249

250250
# submit split batches to engine threadpool
251-
engine_forward_with_context = partial(
252-
self.engine_forward, context=postprocess_kwargs
253-
)
251+
engine_forward_with_context = partial(self.engine_forward, context=context)
254252
batch_outputs = list(
255253
self.executor.map(engine_forward_with_context, batches)
256254
)
@@ -276,9 +274,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
276274

277275
# ------ POSTPROCESSING ------
278276
timer.start(InferenceStages.POST_PROCESS)
279-
pipeline_outputs = self.process_engine_outputs(
280-
engine_outputs, **postprocess_kwargs
281-
)
277+
pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
282278
if not isinstance(pipeline_outputs, self.output_schema):
283279
raise ValueError(
284280
f"Outputs of {self.__class__} must be instances of "

0 commit comments

Comments
 (0)