diff --git a/src/deepsparse/v2/operators/__init__.py b/src/deepsparse/v2/operators/__init__.py index 8f7e6a169d..9d1a9812ac 100644 --- a/src/deepsparse/v2/operators/__init__.py +++ b/src/deepsparse/v2/operators/__init__.py @@ -13,5 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from .operator import * diff --git a/src/deepsparse/v2/operators/engine_operator.py b/src/deepsparse/v2/operators/engine_operator.py index 2c61755df9..b7d920a686 100644 --- a/src/deepsparse/v2/operators/engine_operator.py +++ b/src/deepsparse/v2/operators/engine_operator.py @@ -17,7 +17,8 @@ from pydantic import BaseModel, Field -from deepsparse import Context, Engine, MultiModelEngine, Scheduler +from deepsparse import Context as EngineContext +from deepsparse import Engine, MultiModelEngine, Scheduler from deepsparse.benchmark import ORTEngine from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs from deepsparse.v2.operators import Operator @@ -54,16 +55,15 @@ def __init__( self, model_path: str, engine_type: str = DEEPSPARSE_ENGINE, - batch_size: Optional[int] = 1, num_cores: int = None, num_streams: int = None, scheduler: Scheduler = None, input_shapes: List[List[int]] = None, - engine_context: Optional[Context] = None, + engine_context: Optional[EngineContext] = None, + engine_kwargs: Optional[Dict] = None, ): - - self._batch_size = batch_size self.model_path = model_to_path(model_path) + self._batch_size = 1 self.engine_context = engine_context if self.engine_context is not None: @@ -87,7 +87,7 @@ def __init__( self._engine_args = engine_args self._engine_type = engine_type - self.engine = self.create_engine() + self.engine = self.create_engine(**(engine_kwargs or {})) @property def batch_size(self) -> int: @@ -114,12 +114,12 @@ def create_engine( if engine_type == DEEPSPARSE_ENGINE: if self.engine_context is not None and isinstance( - self.engine_context, Context + self.engine_context, EngineContext ): engine_args.pop("num_cores", None) engine_args.pop("scheduler", None) engine_args.pop("num_streams", None) - engine_args["context"] = self.engien_context + engine_args["context"] = self.engine_context return MultiModelEngine( model=onnx_file_path, **engine_args, ) @@ -135,7 +135,7 @@ f"{SUPPORTED_PIPELINE_ENGINES}" ) - def run(self, inp: EngineOperatorInputs) -> Dict: + def run(self, inp: EngineOperatorInputs, **kwargs) -> Dict: if inp.engine: # run with custom engine, do not split/join since custom engine # may run at any batch size, returning here as code below has a diff --git a/src/deepsparse/v2/operators/operator.py b/src/deepsparse/v2/operators/operator.py index c3a3e28b78..b3963d8223 100644 --- a/src/deepsparse/v2/operators/operator.py +++ b/src/deepsparse/v2/operators/operator.py @@ -17,6 +17,8 @@ from pydantic import BaseModel +from deepsparse.v2.utils import InferenceState, PipelineState + __all__ = ["Operator"] @@ -54,6 +56,8 @@ def has_output_schema(cls) -> bool: def __call__( self, *args, + inference_state: InferenceState, + pipeline_state: PipelineState, **kwargs, ) -> Any: """ @@ -61,7 +65,9 @@ def __call__( :param args: an unnamed arg may only be provided if it is of the type of the input_schema - :param context: pipeline context to pass to operator + :param inference_state: inference_state for the pipeline. + :param pipeline_state: pipeline_state for the pipeline.
The values in the state + are created during pipeline creation and are read-only during inference. :param kwargs: kwargs when not initializing from an instantiated schema :return: operator output """ @@ -81,10 +87,18 @@ def __call__( "in the form of a dictionary or an instance of the input_schema" "object" ) - - run_output = self.run(inference_input) + run_output = self.run( + inference_input, + inference_state=inference_state, + pipeline_state=pipeline_state, + ) else: - run_output = self.run(*args, **kwargs) + run_output = self.run( + *args, + inference_state=inference_state, + pipeline_state=pipeline_state, + **kwargs, + ) if self.has_output_schema(): return self.output_schema(**run_output) @@ -93,12 +107,16 @@ def __call__( @abstractmethod def run(self, *args, **kwargs) -> Any: """ - :param inp: operator input, as the defined input schema if applicable - :param context: pipeline context of already run operators :return: result of this operator as the defined output schema if applicable """ raise NotImplementedError + def can_operate(self, inp: Any) -> bool: + """ + Whether or not the given operator can run, based on the input + """ + return True + def expand_inputs(self, **kwargs): """ Generic function to handle expanding values. diff --git a/src/deepsparse/v2/pipeline.py b/src/deepsparse/v2/pipeline.py index e58f8a5191..0a8c8b2f93 100644 --- a/src/deepsparse/v2/pipeline.py +++ b/src/deepsparse/v2/pipeline.py @@ -18,6 +18,7 @@ from deepsparse.v2.operators import Operator from deepsparse.v2.routers import Router from deepsparse.v2.schedulers import OperatorScheduler, SchedulerGroup +from deepsparse.v2.utils import InferenceState, PipelineState __all__ = ["Pipeline"] @@ -27,7 +28,7 @@ class Pipeline(Operator): """ Pipeline accepts a series of operators, schedulers, and a router. Calling a pipeline will use the router to run through all the defined operators. The operators should - be implemented using the Operator class and each implemented Operator should be + be implemented using the Operator class and each implemented operator should be responsible for a functional component of the pipeline. The flow of inputs/outputs between the operators and the steps in the pipeline should be defined by the router, (based on the Router class), which dictates the next operator in the pipeline. @@ -37,6 +38,7 @@ class Pipeline(Operator): or dictionary of operators. :param router: A Router which dictates the next operator to call. :param schedulers: A list of schedulers to run operators. + :param pipeline_state: pipeline_state created during pipeline initialization """ @@ -45,57 +47,93 @@ def __init__( self, ops: Union[Dict[str, Operator], List[Operator]], router: Router, schedulers: List[OperatorScheduler], + pipeline_state: PipelineState = None, ): self.ops = ops self.router = router self.schedulers = schedulers + self.pipeline_state = pipeline_state self.validate() # SchedulerGroup handles running all schedulers in order of priority self._scheduler_group = SchedulerGroup(self.schedulers) - def run(self, *args, **kwargs): + def run( + self, + *args, + inference_state: InferenceState, + pipeline_state: PipelineState, + **kwargs, + ): """ - Run through the operators using the provided router and scheduler. Update the - context to reflect each step of the router. The input to a given operator is the - output of the previous operator. - - :param inp: input to the operator. expected to be of any type that is - expected by the operator.
- :param context: context to store the current the inputs, outputs, and operator - for each step of the router. + Run through the operators using the provided router and scheduler. + The input to a given operator is the output of the previous operator. + :param inference_state: inference_state for the pipeline. + :param pipeline_state: pipeline_state for the pipeline. The values in the state + are created during pipeline creation and are read-only during inference. """ next_step = self.router.START_ROUTE operator_output = None + while next_step != self.router.END_ROUTE: # Either a dictionary key or valid index operator = self.ops[next_step] if next_step == self.router.START_ROUTE: output_future = self._scheduler_group.submit( - *args, operator=operator, **kwargs + *args, + inference_state=inference_state, + operator=operator, + pipeline_state=pipeline_state, + **kwargs, ) else: if isinstance(operator_output, dict): output_future = self._scheduler_group.submit( - operator=operator, **operator_output + inference_state=inference_state, + operator=operator, + pipeline_state=pipeline_state, + **operator_output, ) else: output_future = self._scheduler_group.submit( - operator_output, operator=operator + operator_output, + inference_state=inference_state, + pipeline_state=pipeline_state, + operator=operator, ) - # wait for future to resolve operator_output = output_future.result() - next_step = self.router.next(next_step, self.ops) + if isinstance(operator_output, tuple): + state_update = operator_output[-1] + operator_output = operator_output[0] + inference_state.update_state(state_update) + + next_step = self.router.next(next_step, self.ops, operator_output) + return operator_output def __call__(self, *args, **kwargs): """ + Consolidate any provided inference_state or pipeline_state objects and pass + any other operator inputs to run(). 
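+        For example (hypothetical call), `pipeline(value=5)` and `pipeline(value=5, inference_state=existing_state)` are both valid; the former creates a fresh InferenceState for the run.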
+ :return: output of the pipeline operators run with the router for the given - input + input """ + if kwargs.get("inference_state"): + inference_state = kwargs.pop("inference_state") + else: + inference_state = InferenceState() + inference_state.create_state({}) + + if "pipeline_state" in kwargs: + self.pipeline_state = kwargs.get("pipeline_state") + + kwargs["inference_state"] = inference_state + kwargs["pipeline_state"] = self.pipeline_state + return self.run(*args, **kwargs) def validate(self): diff --git a/src/deepsparse/v2/routers/router.py b/src/deepsparse/v2/routers/router.py index 6050803b5e..d1110d4ca7 100644 --- a/src/deepsparse/v2/routers/router.py +++ b/src/deepsparse/v2/routers/router.py @@ -15,14 +15,14 @@ import logging from abc import abstractmethod -from typing import Dict, List, Union +from typing import Any, Dict, List, Optional, Union from deepsparse.v2.operators import Operator _LOGGER = logging.getLogger(__name__) -__all__ = ["Router", "LinearRouter"] +__all__ = ["Router", "LinearRouter", "GraphRouter"] class Router: @@ -32,23 +32,34 @@ class Router: :param start_route: the start index or key of the router :param end_route: the end index or key of the router + :param route: the route the router traverses """ - def __init__(self, end_route: Union[str, int], start_route: Union[str, int]): + def __init__( + self, + end_route: Union[str, int], + start_route: Union[str, int], + route: Optional[Dict] = None, + ): self.START_ROUTE = start_route self.END_ROUTE = end_route + self.route = route @abstractmethod def next( - self, past: Union[str, int], ops: Union[List[Operator], Dict[str, Operator]] + self, + past: Union[str, int], + ops: Optional[Union[List[Operator], Dict[str, Operator]]], + inp: Optional[Any], ) -> Union[str, int]: """ Determines the index or dictionary key for the next operator which should run. :param past: the previous index or key. This should uniquely determine the next - operator to run + operator to run :param ops: list or dictionary of operators + :param inp: operator input :returns: the next index or dictionary key for the next operator to run """ raise NotImplementedError @@ -69,7 +80,9 @@ class LinearRouter(Router): def __init__(self, end_route: int, start_route: int = 0): super().__init__(end_route=end_route, start_route=start_route) - def next(self, past: int, ops: List[Operator]) -> int: + def next( + self, past: int, ops: Optional[List[Operator]] = None, inp: Optional[Any] = None + ) -> int: new_index = past + 1 if new_index < self.END_ROUTE: return new_index @@ -105,3 +118,35 @@ def validate(operators: List[Operator]) -> bool: ) return False return True + + +class GraphRouter(Router): + """ + Router for a DAG. Expects graphs to be presented in the form of a dictionary, where + keys are the nodes of the graph and the values are the connected nodes. For + nodes with multiple output edges, all the connected nodes will be visited and the + first node where `can_operate` returns True will run. Paths should be deterministic.
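+    Example route (hypothetical node names): `{"A": "B", "B": ["C", "D"], "C": "END", "D": "END"}`. From node A the router always moves to B; from node B, `can_operate` is checked on C and then on D, and the first operator that returns True runs next.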
+ """ + + def __init__(self, end_route: str, start_route: str, route: Dict): + super().__init__(end_route=end_route, start_route=start_route, route=route) + + def next( + self, + past: str, + ops: Dict[str, Operator], + inp: Any, + ) -> int: + node = past + if isinstance(self.route[node], str): + return self.route[node] + else: + for neighbour_node in self.route[node]: + neighbour_node_op = ops[neighbour_node] + if neighbour_node_op.can_operate(inp): + return neighbour_node + raise ValueError("Cannot operate on any of the nodes") + + @staticmethod + def validate(ops) -> bool: + pass diff --git a/src/deepsparse/v2/schedulers/scheduler.py b/src/deepsparse/v2/schedulers/scheduler.py index 7d4f249444..78a58e3389 100644 --- a/src/deepsparse/v2/schedulers/scheduler.py +++ b/src/deepsparse/v2/schedulers/scheduler.py @@ -36,19 +36,30 @@ class OperatorScheduler: def __init__(self, max_workers: int = 1): self._threadpool = ThreadPoolExecutor(max_workers=max_workers) - def submit(self, *args, operator: Operator, **kwargs) -> Future: + def submit( + self, + *args, + operator: Operator, + **kwargs, + ) -> Future: """ :param operator: operator to run - :param operator_input: input schema to the operator - :param context: context of already run operators :return: future referencing the asynchronously run output of the operator """ - return self._threadpool.submit(operator, *args, **kwargs) + return self._threadpool.submit( + operator, + *args, + **kwargs, + ) - def can_process(self, *args, operator: Operator, **kwargs) -> bool: + def can_process( + self, + *args, + operator: Operator, + **kwargs, + ) -> bool: """ :param operator: operator to check - :param operator_input: operator_input to check :return: True if this Operator can process the given operator and input. Base OperatorScheduler always returns True """ diff --git a/src/deepsparse/v2/schedulers/scheduler_group.py b/src/deepsparse/v2/schedulers/scheduler_group.py index 7f00a3c17c..40b5695f22 100644 --- a/src/deepsparse/v2/schedulers/scheduler_group.py +++ b/src/deepsparse/v2/schedulers/scheduler_group.py @@ -34,25 +34,44 @@ class SchedulerGroup(OperatorScheduler): def __init__(self, schedulers: List[OperatorScheduler]): self.schedulers = schedulers - def submit(self, *args, operator: Operator, **kwargs) -> Future: + def submit( + self, + *args, + operator: Operator, + **kwargs, + ) -> Future: """ :param operator: operator to run - :param operator_input: input schema to the operator - :param context: context of already run operators :return: future referencing the asynchronously run output of the operator """ for scheduler in self.schedulers: - if scheduler.can_process(*args, operator=operator, **kwargs): - return scheduler.submit(*args, operator=operator, **kwargs) + if scheduler.can_process( + *args, + operator=operator, + **kwargs, + ): + return scheduler.submit( + *args, + operator=operator, + **kwargs, + ) - def can_process(self, *args, operator: Operator, **kwargs) -> bool: + def can_process( + self, + *args, + operator: Operator, + **kwargs, + ) -> bool: """ :param operator: operator to check - :param operator_input: operator_input to check :return: True if this Operator can process the given operator and input. 
SchedulerGroup always returns True """ return any( - scheduler.can_process(*args, operator=operator, **kwargs) + scheduler.can_process( + *args, + operator=operator, + **kwargs, + ) for scheduler in self.schedulers ) diff --git a/src/deepsparse/v2/text_generation/__init__.py b/src/deepsparse/v2/text_generation/__init__.py new file mode 100644 index 0000000000..37ac88d02f --- /dev/null +++ b/src/deepsparse/v2/text_generation/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa +from .autoregressive_preprocess_operator import * +from .compile_logits import * +from .kv_cache_operator import * +from .multi_engine_prefill_operator import * +from .nl_engine_operator import * +from .prep_for_prefill import * +from .process_inputs import * + + +from .pipeline import * # isort:skip diff --git a/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py new file mode 100644 index 0000000000..cfe7cb531b --- /dev/null +++ b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py @@ -0,0 +1,100 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any + +import numpy + +from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import PipelineState + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["AutoRegressiveOperatorPreprocess"] + + +class AutoRegressiveOperatorPreprocess(Operator): + def __init__(self, sequence_length: int, prompt_sequence_length: int): + """ + Prepare the tokens for the single-token engine. This requires creating the + attention mask, positions, and causal mask. The output contains these three + arrays to be passed into the single-token engine. + """ + self.sequence_length = sequence_length + self.prompt_sequence_length = prompt_sequence_length + self.set_capacity = False + + _LOGGER.warning( + "This operator requires the PipelineState to be set up with the " + "onnx_input_names_no_cache attribute set from the NLEngineOperator." + ) + + def can_operate(self, inp: Any) -> bool: + """ + Can run this Operator if the number of tokens left to process is greater than + 0 but less than self.prompt_sequence_length.
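+        For example (hypothetical numbers), with prompt_sequence_length=16 and a 20-token prompt of which 16 tokens have already been processed, 4 tokens remain, so this operator can run.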
+ """ + tokens = inp.get("tokens") + kv_cache = inp.get("kv_cache") + + remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens + if remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length: + return True + return False + + def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs): + + if not self.set_capacity: + self.set_capacity = True + kv_cache.set_capacity(self.sequence_length - 1) + + num_total_processed_tokens = kv_cache.total_num_processed_tokens + new_token = tokens[num_total_processed_tokens] + engine_input_names = pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ) + + # padding is added to left, so attention mask is 1s from the + # right up to the number of total tokens (prompt + generated) + attention_mask = numpy.zeros((1, self.sequence_length), dtype=numpy.int64) + num_attention_entries_to_unmask = min( + num_total_processed_tokens + 1, self.sequence_length + ) # cap by seq len + attention_mask[:, -num_attention_entries_to_unmask:] = 1 + positions = numpy.array([[num_total_processed_tokens]], dtype=numpy.int64) + input_ids = numpy.array([[new_token]]) + causal_mask = create_causal_mask(input_ids, attention_mask) + + engine_inputs_map = dict( + input_ids=input_ids, + attention_mask=attention_mask, + causal_mask=causal_mask, + positions=positions, + ) + + engine_inputs = [engine_inputs_map[name] for name in engine_input_names] + + onnx_input_names_no_cache = pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ) + engine_inputs = [engine_inputs_map[name] for name in onnx_input_names_no_cache] + + return { + "engine_inputs": engine_inputs, + "kv_cache": kv_cache, + "tokens": tokens, + } diff --git a/src/deepsparse/v2/text_generation/compile_logits.py b/src/deepsparse/v2/text_generation/compile_logits.py new file mode 100644 index 0000000000..55c87d791d --- /dev/null +++ b/src/deepsparse/v2/text_generation/compile_logits.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["CompilePromptLogits"] + + +class CompilePromptLogits(Operator): + """ + Combine the prompt logits. Currently relying on the inference state to store the + prompt logits for each token or multi-token batch processed. This operator will + take prompt logits from each iteration run and update the inference state. 
+ """ + + def run(self, logits, inference_state: InferenceState, **kwargs): + logit_type = "prompt_logits" + + if inference_state.current_state.get(logit_type) is not None: + current_logits = inference_state.current_state.get(logit_type).copy() + current_logits.append(logits) + else: + current_logits = [logits] + + state_update = {logit_type: current_logits} + return { + "kv_cache": kwargs.get("kv_cache"), + "tokens": kwargs.get("tokens"), + }, state_update diff --git a/src/deepsparse/v2/text_generation/kv_cache_operator.py b/src/deepsparse/v2/text_generation/kv_cache_operator.py new file mode 100644 index 0000000000..0b232402b3 --- /dev/null +++ b/src/deepsparse/v2/text_generation/kv_cache_operator.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from pydantic import BaseModel, Field + +from deepsparse.transformers.utils import DecoderKVCache +from deepsparse.transformers.utils.helpers import ( + initialize_kv_cache_state, + prepends_bos_token, +) +from deepsparse.v2.operators import Operator + + +__all__ = ["KVCacheCreator"] + + +class KVCacheCreatorOutput(BaseModel): + kv_cache: Any = Field(description="KV Cache Created") # DecoderKVCache + + +class KVCacheCreatorInput(BaseModel): + cache_shape: Any = Field(description="shape") + kv_cache_data_type: Any = Field(description="data type") + output_names: Any = Field(description="output names") + + +class KVCacheCreator(Operator): + input_schema = KVCacheCreatorInput + output_schema = KVCacheCreatorOutput + + def __init__( + self, + tokenizer, + sequence_length: int, + prompt_sequence_length: int, + internal_kv_cache: bool, + ): + self.tokenizer = tokenizer + self.prompt_sequence_length = prompt_sequence_length + self.internal_kv_cache = internal_kv_cache + self.sequence_length = sequence_length + + def run(self, cache_shape, kv_cache_data_type: str, output_names: list, **kwargs): + kv_cache_state = initialize_kv_cache_state( + cache_shape=cache_shape, + kv_cache_data_type=kv_cache_data_type, + output_names=output_names, + length=self.sequence_length - self.prompt_sequence_length, + empty=bool(self.internal_kv_cache), + ) + + kv_cache = DecoderKVCache(self.internal_kv_cache) + kv_cache.setup( + state=kv_cache_state, + freeze_first_position=prepends_bos_token(self.tokenizer), + ) + return {"kv_cache": kv_cache} diff --git a/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py new file mode 100644 index 0000000000..41ee830a8a --- /dev/null +++ b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py @@ -0,0 +1,135 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from enum import Enum +from typing import Any + +import numpy + +from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import PipelineState + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["MultiEnginePrefill"] + + +class OnnxInputNames(Enum): + INPUT_IDS = "input_ids" + ATTN_MASK = "attention_mask" + CAUSAL_MASK = "causal_mask" + POSITIONS = "positions" + + +# NOTE: A possible clean-up could involve combining this Operator and the +# autoregressive_preprocess_operator + + +class MultiEnginePrefill(Operator): + def __init__(self, prompt_sequence_length, sequence_length): + """ + Prepare the tokens for the multi-token engine. This requires creating the + attention mask, positions, and causal mask. The output contains these three + arrays to be passed into the multi-token engine. + """ + self.prompt_sequence_length = prompt_sequence_length + self.sequence_length = sequence_length + self.cases = { + OnnxInputNames.ATTN_MASK.value: self._case_attn_mask, + OnnxInputNames.POSITIONS.value: self._case_positions, + } + _LOGGER.warning( + "This operator requires the PipelineState to be set up with the " + "onnx_input_names_no_cache attribute set from the NLEngineOperator." + ) + + def can_operate(self, inp: Any): + """ + Can only run if the number of prompt tokens left to process is greater than + or equal to self.prompt_sequence_length.
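+        For example (hypothetical numbers), with prompt_sequence_length=16 and a 20-token prompt, this operator runs once over the first 16 tokens; the remaining 4 tokens fall to the autoregressive operator.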
+ """ + kv_cache = inp.get("kv_cache") + tokens = inp.get("tokens") + + if len(tokens) < self.prompt_sequence_length: + return False + + if ( + len(tokens) - kv_cache.total_num_processed_tokens + >= self.prompt_sequence_length + ): + return True + return False + + def _case_attn_mask(self, num_total_processed_tokens: int): + # create an empty attention mask + engine_input = numpy.zeros((1, self.sequence_length), dtype=numpy.int64) + # calculate the number of entries in attention mask that should be set to 1 + num_attention_entries_to_unmask = min( + num_total_processed_tokens + self.prompt_sequence_length, + self.sequence_length, + ) + engine_input[:, -num_attention_entries_to_unmask:] = 1 + return engine_input + + def _case_positions(self, num_total_processed_tokens: int): + return ( + numpy.arange( + num_total_processed_tokens, + num_total_processed_tokens + self.prompt_sequence_length, + ) + .reshape(1, -1) + .astype(numpy.int64) + ) + + def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs): + + onnx_input_names_no_cache = pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ) + + num_total_processed_tokens = kv_cache.total_num_processed_tokens + start = num_total_processed_tokens + end = start + self.prompt_sequence_length + token_batch = tokens[start:end] + + engine_inputs = [] + for name in onnx_input_names_no_cache: + if name == OnnxInputNames.INPUT_IDS.value: + engine_input = numpy.array([token_batch]) + elif ( + name == OnnxInputNames.ATTN_MASK.value + or name == OnnxInputNames.POSITIONS.value + ): + engine_input = self.cases[name](num_total_processed_tokens) + elif name == OnnxInputNames.CAUSAL_MASK.value: + continue + + engine_inputs.append(engine_input) + + if OnnxInputNames.CAUSAL_MASK.value in onnx_input_names_no_cache: + causal_mask = create_causal_mask( + input_ids=engine_inputs[0], + attention_mask=engine_inputs[1], + ) + engine_inputs.append(causal_mask) + + return { + "engine_inputs": engine_inputs, + "kv_cache": kv_cache, + "tokens": tokens, + } diff --git a/src/deepsparse/v2/text_generation/nl_engine_operator.py b/src/deepsparse/v2/text_generation/nl_engine_operator.py new file mode 100644 index 0000000000..6c1ad1966e --- /dev/null +++ b/src/deepsparse/v2/text_generation/nl_engine_operator.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import os +from typing import Any, List, Tuple + +from pydantic import BaseModel, Field + +from deepsparse.utils.onnx import ( + CACHE_INPUT_PREFIX, + overwrite_onnx_model_inputs_for_kv_cache_models, +) +from deepsparse.v2.operators.engine_operator import ( + DEEPSPARSE_ENGINE, + EngineOperator, + EngineOperatorInputs, +) + + +__all__ = ["NLEngineOperator"] + + +class NlEngineInput(BaseModel): + engine_inputs: List = Field(description="engine inputs") + kv_cache: Any = Field(description="kv_cache object") + tokens: List = Field(description="tokens") + + +class NLEngineOperator(EngineOperator): + """ + Operator for the NL Decoder Engine. This Operator inherits from the EngineOperator. + Specific updates to engine attributes are made through this operator, as well + as updating the kv_cache. This Operator is used for both the single-token and + multi-token case. + """ + + input_schema = NlEngineInput + output_schema = None + + def __init__( + self, + sequence_length: int, + input_ids_length: int, + internal_kv_cache: bool = False, + **kwargs, + ): + + self.kv_cache_data_type = None + ( + onnx_file_path, + output_indices_to_be_cached, + kv_cache_data_type, + ) = overwrite_onnx_model_inputs_for_kv_cache_models( + onnx_file_path=kwargs.get("model_path"), + batch_size=kwargs.get("batch_size", 1), + sequence_length=sequence_length, + input_ids_length=input_ids_length, + ) + + engine_kwargs = kwargs.get("engine_kwargs") or {} + if kwargs.get("engine_type", DEEPSPARSE_ENGINE) == DEEPSPARSE_ENGINE: + if "WAND_OPT_FLAGS" not in os.environ: + os.environ["WAND_OPT_FLAGS"] = "default,~pyramids" + + if any(output_indices_to_be_cached): + self.kv_cache_data_type = kv_cache_data_type + if ( + internal_kv_cache + and kwargs.get("engine_type", DEEPSPARSE_ENGINE) == DEEPSPARSE_ENGINE + ): + engine_kwargs["cached_outputs"] = output_indices_to_be_cached + + kwargs["engine_kwargs"] = engine_kwargs + kwargs["model_path"] = onnx_file_path + super().__init__(**kwargs) + + self.input_ids_length = input_ids_length + + def run(self, inp: NlEngineInput, **kwargs) -> Any: + engine_input = inp.engine_inputs + kv_cache = inp.kv_cache + + inputs = self._add_kv_cache_to_input(engine_input, kv_cache) + if bool(kv_cache.engine_internal_cache): + # conventionally, before dispatching + # inputs to the engine, we validate them + # if val_inp=True. However, in this case + # we want to pass the empty kv cache inputs + # (batch_size=0) to the engine.
Therefore, + # we skip the validation + out = self.engine._eng_net.execute_list_out( + inputs, kv_cache.engine_internal_cache + ) + else: + # run the engine without the LIB.kv_cache object + out = ( + super() + .run(EngineOperatorInputs(engine_inputs=inputs), **kwargs) + .get("engine_outputs") + ) + + logits, *kv_cache_state = out + self._update_kv_cache( + kv_cache_state=kv_cache_state, + input_ids_len=self.input_ids_length, + kv_cache=kv_cache, + ) + + output = {"logits": logits, "kv_cache": kv_cache, "tokens": inp.tokens} + return output + + def _add_kv_cache_to_input(self, engine_input, kv_cache): + kv_cache_state = copy.copy(kv_cache.cached_inputs) + + for idx, input_name in enumerate(self.onnx_input_names_no_cache): + kv_cache_state[input_name] = engine_input[idx] + + new_inp = [kv_cache_state[name] for name in self.engine.input_names] + return new_inp + + def _update_kv_cache(self, kv_cache_state, input_ids_len, kv_cache): + if bool(kv_cache.engine_internal_cache): + kv_cache.total_num_processed_tokens += input_ids_len + return + + kv_cache_state = { + name: array + for name, array in zip(self.onnx_input_names_cached, kv_cache_state) + } + + kv_cache.update( + state=kv_cache_state, + input_ids_len=input_ids_len, + ) + + @property + def onnx_input_names_no_cache(self) -> List[str]: + """ + :return: The input names for the onnx model, excluding + the potential kv cache inputs + """ + return [ + name + for name in self.engine.input_names + if not name.startswith(CACHE_INPUT_PREFIX) + ] + + @property + def onnx_input_names_cached(self) -> List[str]: + """ + :return: The cached input names for the onnx model + """ + return [ + name + for name in self.engine.input_names + if name.startswith(CACHE_INPUT_PREFIX) + ] + + @property + def cache_shape(self) -> Tuple[int, int, int, int]: + """ + :return: The shape of the kv cache inputs + for the onnx model. The shape is + (batch_size, num_heads, sequence_length, hidden_size) + """ + cache_engine_input_index = next( + i + for i, name in enumerate(self.engine.input_names) + if CACHE_INPUT_PREFIX in name + ) + return self.engine.input_shapes[cache_engine_input_index] + + @property + def output_names(self) -> List[str]: + """ + :return: The output names for the onnx model + """ + return self.engine.output_names diff --git a/src/deepsparse/v2/text_generation/pipeline.py b/src/deepsparse/v2/text_generation/pipeline.py new file mode 100644 index 0000000000..9878aa0061 --- /dev/null +++ b/src/deepsparse/v2/text_generation/pipeline.py @@ -0,0 +1,213 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
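+# Rough sketch of the intended flow through the graph router constructed below: process_input -> prepare_prefill, then {multi_engine_prefill -> multi_engine -> compile_logits} while at least prompt_sequence_length prompt tokens remain, then {autoregressive_preprocess -> single_engine -> compile_logits} for the remaining tokens, ending at final_step.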
+ +from typing import Dict, Optional + +from deepsparse.transformers.utils.helpers import process_generation_config +from deepsparse.v2.operators import Operator +from deepsparse.v2.pipeline import Pipeline +from deepsparse.v2.routers import GraphRouter +from deepsparse.v2.schedulers import OperatorScheduler +from deepsparse.v2.text_generation import ( + AutoRegressiveOperatorPreprocess, + CompilePromptLogits, + KVCacheCreator, + MultiEnginePrefill, + NLEngineOperator, + PrepareforPrefill, + ProcessInputsTextGeneration, +) +from deepsparse.v2.utils import PipelineState + + +class TextGenerationPipeline(Pipeline): + def __init__( + self, + model_path: str, + prompt_sequence_length: int = 16, + sequence_length: int = 1024, + internal_kv_cache: bool = True, + force_max_tokens: bool = False, + generation_config=None, + engine_kwargs: Optional[Dict] = None, + ): + + pipeline_state = PipelineState() + pipeline_state_vals = {} + + # TODO: The code below will be replaced with a transformers set-up Operator. + self.tokenizer = None + model_path = self.setup_onnx_file_path(model_path, sequence_length) + self.tokenizer.padding_side = "left" + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + + if not engine_kwargs: + engine_kwargs = {} + engine_kwargs["model_path"] = model_path + + if internal_kv_cache and engine_kwargs.get("engine_type") == "onnxruntime": + internal_kv_cache = False + + single_engine_operator = NLEngineOperator( + sequence_length=sequence_length, + internal_kv_cache=internal_kv_cache, + input_ids_length=1, + **engine_kwargs, + ) + + multi_engine_operator = NLEngineOperator( + sequence_length=sequence_length, + internal_kv_cache=internal_kv_cache, + input_ids_length=prompt_sequence_length, + **engine_kwargs, + ) + + # NOTE: Currently using pipeline state. Can swap to simply pass in the + # attributes to the specific Operator that need them, as class attributes. + pipeline_state_vals[ + "onnx_input_names_no_cache" + ] = single_engine_operator.onnx_input_names_no_cache + pipeline_state_vals["cache_shape"] = single_engine_operator.cache_shape + pipeline_state_vals["output_names"] = single_engine_operator.output_names + pipeline_state_vals[ + "kv_cache_data_type" + ] = single_engine_operator.kv_cache_data_type + pipeline_state.create_state(pipeline_state_vals) + + process_inputs = ProcessInputsTextGeneration( + generation_config=process_generation_config(generation_config), + sequence_length=sequence_length, + tokenizer=self.tokenizer, + ) + + kv_cache_creator = KVCacheCreator( + sequence_length=sequence_length, + tokenizer=self.tokenizer, + prompt_sequence_length=prompt_sequence_length, + internal_kv_cache=internal_kv_cache, + ) + + # NOTE: Can also have the KVCacheCreator be initialized inside this Operator. + # Relies on pipeline state variables set up above (can be swapped to be class + # attributes instead of using the state).
+ engine_inputs_for_prefill = PrepareforPrefill(kv_cache_creator=kv_cache_creator) + + multi_engine_prefill = MultiEnginePrefill( + prompt_sequence_length=prompt_sequence_length, + sequence_length=sequence_length, + ) + compile_prompt_logits = CompilePromptLogits() + """ + prep_for_single_engine = PrepareforSingleEngine( + prompt_sequence_length=prompt_sequence_length, + sequence_length=sequence_length, + ) + """ + autoregressive_preprocess = AutoRegressiveOperatorPreprocess( + sequence_length=sequence_length, + prompt_sequence_length=prompt_sequence_length, + ) + final_step = FinalStep() + + ops = { + "process_input": process_inputs, + "single_engine": single_engine_operator, + "multi_engine": multi_engine_operator, + "kv_cache_creator": kv_cache_creator, + "prepare_prefill": engine_inputs_for_prefill, + "multi_engine_prefill": multi_engine_prefill, + "compile_logits": compile_prompt_logits, + "autoregressive_preprocess": autoregressive_preprocess, + "final_step": final_step, + } + + routes = { + "process_input": "prepare_prefill", + "prepare_prefill": ["multi_engine_prefill", "autoregressive_preprocess"], + "multi_engine_prefill": "multi_engine", + "multi_engine": "compile_logits", + "compile_logits": [ + "multi_engine_prefill", + "autoregressive_preprocess", + "final_step", + ], + "autoregressive_preprocess": "single_engine", + "single_engine": "compile_logits", + "final_step": "STOP", + } + + router = GraphRouter( + end_route="STOP", start_route="process_input", route=routes + ) + scheduler = [OperatorScheduler()] + super().__init__( + ops=ops, router=router, schedulers=scheduler, pipeline_state=pipeline_state + ) + + # TODO: Move to be part of a generic transformers set-up Operator. + def setup_onnx_file_path(self, model_path, sequence_length) -> str: + import logging + + import transformers + from transformers import AutoTokenizer + + from deepsparse.transformers.helpers import get_deployment_path + + """ + Parses ONNX model from the `model_path` provided. It additionally + creates config and tokenizer objects from the `deployment path`, + derived from the `model_path` provided. + + :return: file path to the processed ONNX file for the engine to compile + """ + deployment_path, onnx_path = get_deployment_path(model_path) + + hf_logger = logging.getLogger("transformers") + hf_logger_level = hf_logger.level + hf_logger.setLevel(logging.ERROR) + self.config = transformers.PretrainedConfig.from_pretrained( + deployment_path, + finetuning_task=self.task if hasattr(self, "task") else None, + ) + hf_logger.setLevel(hf_logger_level) + + self._trust_remote_code = False + self.tokenizer = AutoTokenizer.from_pretrained( + deployment_path, + trust_remote_code=self._trust_remote_code, + model_max_length=sequence_length, + ) + + if not self.config or not self.tokenizer: + raise RuntimeError( + "Invalid config or tokenizer provided. Please provide " + "paths to the files or ensure they exist in the `model_path` provided. " + "See `tokenizer` and `config` arguments for details." + ) + return onnx_path + + +# NOTE: This is a dummy last step which will be removed. Used as a final step +# for the current routes. 
+class FinalStep(Operator): + def can_operate(self, *args, **kwargs): + return True + + def run(self, *args, **kwargs): + import numpy + + inference_state = kwargs.get("inference_state") + prompt_logits = inference_state.current_state.get("prompt_logits") + return numpy.concatenate(prompt_logits, axis=1) diff --git a/src/deepsparse/v2/text_generation/prep_for_prefill.py b/src/deepsparse/v2/text_generation/prep_for_prefill.py new file mode 100644 index 0000000000..2f9eb15797 --- /dev/null +++ b/src/deepsparse/v2/text_generation/prep_for_prefill.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any + +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import PipelineState + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["PrepareforPrefill"] + + +class PrepareforPrefill(Operator): + def __init__(self, kv_cache_creator: Operator): + """ + Operator before prefill. Responsible for creating the kv_cache based on engine + variables. Currently, this operator expects that the kv_cache_creator is + provided during initialization and then uses pipeline_state to run the + kv_cache_operator. + """ + # NOTE: Alternatively, we can initialize the kv_cache_creator operator here, + # instead of at the pipeline level. + self.kv_cache_creator = kv_cache_creator + + _LOGGER.warning( + "This operator requires the PipelineState to be set up with the " + "cache_shape, output_names, kv_cache_data_type attributes to be set " + "from the NLEngineOperator" + ) + + def run(self, tokens: Any, pipeline_state: PipelineState, **kwargs): + # NOTE: Can potentially just be class attributes instead of relying on + # pipeline state. + cache_shape = pipeline_state.current_state.get("cache_shape") + data_type = pipeline_state.current_state.get("kv_cache_data_type") + output_names = pipeline_state.current_state.get("output_names") + + kv_cache = self.kv_cache_creator.run( + cache_shape=cache_shape, + kv_cache_data_type=data_type, + output_names=output_names, + ).get("kv_cache") + return {"tokens": tokens, "kv_cache": kv_cache} diff --git a/src/deepsparse/v2/text_generation/process_inputs.py b/src/deepsparse/v2/text_generation/process_inputs.py new file mode 100644 index 0000000000..528dcee0b7 --- /dev/null +++ b/src/deepsparse/v2/text_generation/process_inputs.py @@ -0,0 +1,121 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib +from typing import Dict, Union + +import transformers + +from deepsparse.transformers.pipelines.text_generation import TextGenerationInput +from deepsparse.transformers.utils.helpers import ( + check_and_return_generation_config, + override_config, + repeat_inputs, +) +from deepsparse.v2.operators import Operator + + +class GenerationDefaults: + num_return_sequences = 1 + max_length = 1024 + max_new_tokens = None + output_scores = False + top_k = 0 + top_p = 0.0 + repetition_penalty = 0.0 + do_sample = False + temperature = 1.0 + + +__all__ = ["ProcessInputsTextGeneration"] + + +class ProcessInputsTextGeneration(Operator): + """ + Input processing operator. Responsible for tokenizing the input, handling the + generation_config (if provided), updating the inference_state for later use, + and returning the tokens for prompt inference. The expected input is defined by + the input_schema, which for this operator is TextGenerationInput. + """ + + input_schema = TextGenerationInput + + def __init__( + self, + tokenizer: transformers.PreTrainedTokenizerBase, + generation_config: Union[ + str, pathlib.Path, Dict, transformers.GenerationConfig + ], + sequence_length: int, + ): + self.generation_config = generation_config + self.tokenizer = tokenizer + self.sequence_length = sequence_length + + def run(self, inp: TextGenerationInput, **kwargs): + generation_config = check_and_return_generation_config( + self.generation_config, inp.generation_config, GenerationDefaults() + ) + + generation_config = override_config(inp.generation_kwargs, generation_config) + + original_inputs = inp.sequences + if generation_config.num_return_sequences > 1: + if isinstance(inp.sequences, str): + inp.sequences = [inp.sequences] + inp.sequences = repeat_inputs( + inp.sequences, generation_config.num_return_sequences + ) + + if inp.fixed_sequences_length: + # to enforce a fixed sequence length, we need to + # truncate the input to the maximum sequence length + # and/or pad it to the maximum sequence length + truncate, padding = True, "max_length" + else: + # otherwise, we do not need to truncate the input + # and we can pad it to the longest sequence + # in the batch (so that the engine can process multiple inputs + # at once) + truncate, padding = False, "longest" + + input_tokens = self.tokenizer( + inp.sequences, + return_tensors="np", + max_length=self.sequence_length, + padding=padding, + truncation=truncate, + ) + + input_ids = input_tokens["input_ids"] + attention_mask = input_tokens["attention_mask"] + + inference_state_update = dict( + prompts=original_inputs, + streaming=inp.streaming, + generation_config=generation_config, + include_prompt_logits=inp.include_prompt_logits, + callback=inp.callback, + stop=inp.stop, + top_p=generation_config.top_p, + top_k=generation_config.top_k, + presence_penalty=inp.presence_penalty, + frequency_penalty=generation_config.repetition_penalty, + ) + + # TODO: move this step to prep_for_prefill and add attention mask to the output + # this will allow us to split/join more easily when processing multiple prompts + # in parallel + tokens = input_ids[attention_mask.nonzero()].tolist() + return {"tokens": tokens}, inference_state_update diff --git a/src/deepsparse/v2/utils/__init__.py b/src/deepsparse/v2/utils/__init__.py index a36d8e92ec..358405d7af 100644 --- a/src/deepsparse/v2/utils/__init__.py +++ b/src/deepsparse/v2/utils/__init__.py @@ -13,5 +13,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from .state import * from .types import * diff --git a/src/deepsparse/v2/utils/state.py b/src/deepsparse/v2/utils/state.py new file mode 100644 index 0000000000..b54b890acf --- /dev/null +++ b/src/deepsparse/v2/utils/state.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from abc import ABC +from typing import Any, Union + + +__all__ = ["State", "PipelineState", "InferenceState"] + + +class State(ABC): + """ + Abstract class to store pipeline-level and inference-level state variables which + are generated by some Operator, and required by some other Operator. + """ + + def __init__(self): + self._current_state = None + + @property + def current_state(self): + return self._current_state + + +class PipelineState(State): + """ + Created during pipeline initialization. Pipeline state values are read-only + during inference. + """ + + def create_state(self, new_state: dict): + if self._current_state: + raise ValueError("State creation is only allowed during initialization.") + self._current_state = new_state + + +class InferenceState(State): + """ + Inference state, created during every inference run. + """ + + def create_state(self, new_state: dict): + if self._current_state: + warnings.warn("Current state already exists, overriding.") + self._current_state = new_state + + def update_value(self, attribute: str, value: Union[str, int, list]): + if attribute not in self._current_state: + raise ValueError(f"{attribute} is not a valid state attribute") + self._current_state[attribute] = value + + def update_state(self, value: Any): + self._current_state.update(value) diff --git a/tests/deepsparse/v2/test_basic_pipeline.py b/tests/deepsparse/v2/test_basic_pipeline.py index 9f85e4976e..bedddd537a 100644 --- a/tests/deepsparse/v2/test_basic_pipeline.py +++ b/tests/deepsparse/v2/test_basic_pipeline.py @@ -34,7 +34,7 @@ class AddOneOperator(Operator): input_schema = IntSchema output_schema = IntSchema - def run(self, inp: IntSchema) -> Dict: + def run(self, inp: IntSchema, **kwargs) -> Dict: return {"value": inp.value + 1} @@ -42,7 +42,7 @@ class AddTwoOperator(Operator): input_schema = IntSchema output_schema = IntSchema - def run(self, inp: IntSchema) -> Dict: + def run(self, inp: IntSchema, **kwargs) -> Dict: return {"value": inp.value + 2}
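
Below is a minimal usage sketch of the v2 API introduced in this diff, mirroring the updated test file above; the construction details (the LinearRouter bound and the single OperatorScheduler) are assumptions drawn from the signatures in this diff, not a verbatim excerpt:

```python
# Minimal sketch: two operators chained by a LinearRouter (assumes deepsparse
# at this diff's revision).
from typing import Dict

from pydantic import BaseModel

from deepsparse.v2.operators import Operator
from deepsparse.v2.pipeline import Pipeline
from deepsparse.v2.routers import LinearRouter
from deepsparse.v2.schedulers import OperatorScheduler


class IntSchema(BaseModel):
    value: int


class AddOneOperator(Operator):
    input_schema = IntSchema
    output_schema = IntSchema

    def run(self, inp: IntSchema, **kwargs) -> Dict:
        # inference_state/pipeline_state arrive via **kwargs; unused here
        return {"value": inp.value + 1}


class AddTwoOperator(Operator):
    input_schema = IntSchema
    output_schema = IntSchema

    def run(self, inp: IntSchema, **kwargs) -> Dict:
        return {"value": inp.value + 2}


# LinearRouter visits ops in order; end_route == len(ops) stops the loop
pipeline = Pipeline(
    ops=[AddOneOperator(), AddTwoOperator()],
    router=LinearRouter(end_route=2),
    schedulers=[OperatorScheduler()],
)

output = pipeline(IntSchema(value=5))  # a fresh InferenceState is created per call
print(output.value)  # 8
```

Each call constructs its own InferenceState, so concurrent invocations do not share inference-level state; PipelineState, by contrast, is created once at construction time and is treated as read-only thereafter.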