Skip to content

Commit 1dc2bd1

Browse files
committed
update pipeline to use kwargs
1 parent 95a6377 commit 1dc2bd1

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

Diff for: src/deepsparse/pipeline.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
259259
)
260260

261261
# join together the batches of size `self._batch_size`
262-
engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
262+
engine_outputs = self.join_engine_outputs(
263+
batch_outputs, orig_batch_size, **context
264+
)
263265
timer.stop(InferenceStages.ENGINE_FORWARD)
264266

265267
self.log(
@@ -470,7 +472,7 @@ def to_config(self) -> "PipelineConfig":
470472
)
471473

472474
def join_engine_outputs(
473-
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
475+
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int, **kwargs
474476
) -> List[numpy.ndarray]:
475477
"""
476478
Joins list of engine outputs together into one list.

Diff for: src/deepsparse/transformers/pipelines/text_generation.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ def __init__(
273273
self.tokenizer.pad_token = self.tokenizer.eos_token
274274

275275
self.engine, self.multitoken_engine = self.initialize_engines()
276-
self.streaming = False
277276

278277
def initialize_engines(
279278
self,
@@ -419,7 +418,6 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
419418
:param inputs: the input schema for the pipeline
420419
:return: the inputs for the engine
421420
"""
422-
self.streaming = inputs.streaming
423421
if not self.cache_support_enabled and inputs.max_tokens > 1:
424422
raise ValueError(
425423
"The model used for inference does not support kv cache. It is "
@@ -488,6 +486,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
488486

489487
context = dict(
490488
prompts=original_inputs,
489+
streaming=inputs.streaming,
491490
num_generated_predictions=inputs.num_generated_predictions,
492491
return_logits=inputs.return_logits,
493492
include_prompt_logits=inputs.include_prompt_logits,
@@ -547,8 +546,9 @@ def process_engine_outputs(
547546
"""
548547

549548
prompts = kwargs.get("prompts")
549+
streaming = kwargs.get("streaming")
550550

551-
if self.streaming:
551+
if streaming:
552552
return self._stream_engine_outputs(engine_outputs, prompts, kwargs)
553553

554554
generated_tokens, generated_logits, finished_reason = list(*engine_outputs)
@@ -611,6 +611,7 @@ def engine_forward(
611611

612612
with self.timer_manager.new_timer_context(total_inference=False) as timer:
613613
finished_reason = []
614+
streaming = context.get("streaming")
614615

615616
if not self.cache_support_enabled:
616617
prompt_logits = self.multitoken_engine(engine_inputs)
@@ -688,17 +689,17 @@ def engine_forward(
688689
if len(generated_tokens) == max_tokens:
689690
finished_reason.append(FinishReason.LENGTH)
690691

691-
if self.streaming:
692+
if streaming:
692693
yield (numpy.array([token]), numpy.array([logits]), [None])
693694

694-
if self.streaming:
695+
if streaming:
695696
yield (
696697
numpy.array([token]),
697698
numpy.array([logits]),
698699
[finished_reason[-1]],
699700
)
700701

701-
if not self.streaming:
702+
if not streaming:
702703
yield (
703704
numpy.array([generated_tokens]),
704705
numpy.concatenate(generated_logits, axis=1),
@@ -895,6 +896,7 @@ def join_engine_outputs(
895896
self,
896897
batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
897898
orig_batch_size: int,
899+
**kwargs,
898900
) -> List[Union[numpy.ndarray, FinishReason]]:
899901
"""
900902
Takes a list of outputs (batches) from the engine
@@ -906,7 +908,8 @@ def join_engine_outputs(
906908
:param orig_batch_size: The original batch size
907909
:return: A list of joined outputs
908910
"""
909-
if self.streaming:
911+
streaming = kwargs.get("streaming")
912+
if streaming:
910913
for batch in batch_outputs:
911914
for outputs in batch:
912915
yield outputs

0 commit comments

Comments (0)