Commit b4ae2dd

dbogunowicz, bfineran, and dsikka authored
[Feature Branch] The almighty V2 (#1422)
* Pipelines Refactor - Initial Impl (#1287)
* [Pipeline Refactor] Additional functionality, engine operator, linear router and image classification pipeline/operators/example (#1325)
  * initial functionality and working example with image classification
  * remove testing image
  * update args
  * initial functionality and working example with image classification
  * remove testing image
  * pr comments
  * defines schemas for operators and test
  * add image classification test, PR comments
  * fix input/output handling in pipeline and operator base classes to be more generic; remove context
  * add additional operator input message
  * typo fix
* [v2] EngineOperator updates to make continuous batching easier (#1371)
  * [v2] EngineOperator updates to make continuous batching easier
  * test fixes
* [Pipeline Refactor] Update routes, text generation initial functionality (#1348)
  * initial functionality and working example with image classification
  * remove testing image
  * rebase fixes
  * initial functionality and working example with image classification
  * text gen
  * updates func
  * prompt inference, initial functionality
  * remove image; update state docstring
  * Fix typo
  * add todo for split/join
  * remove context, clean-up args, remove prefill_preprocess_operaator
  * fix docstrings
* [Pipeline Refactor] Additional Operators, Route update and completed generation functionality (#1356)
  * initial functionality and working example with image classification
  * remove testing image
  * rebase fixes
  * initial functionality and working example with image classification
  * text gen
  * updates func
  * prompt inference, initial functionality
  * remove image; update state docstring
  * Fix typo
  * add todo for split/join
  * remove context, clean-up args, remove prefill_preprocess_operaator
  * fix docstrings
  * initial functionality and working example with image classification
  * updates func
  * prompt inference, initial functionality
  * finish generation operators and update routes
  * further breakdown operators
  * add operators
  * fix can_operate condition
  * update can_operate to not rely on the inference_state
  * rebase + update
  * fix condition
  * fix capacity settting again
  * typo fixes
* [Pipeline Refactor] Split/Join Functionality for multiple prompts (#1384)
  * add split/join functionality
  * update router to include split/join in parent class, refactor pipeline code to remove repeat code, update map function
  * process multiple generations
  * move map to base class
* [Pipeline Refactor] Unit Testing for Text Generation Operators (#1392)
  * unit testing for text generation operators
  * additional changes
  * unit testing completion
  * remove debug
  * fix
  * add todo
  * more clean-up
  * fix test
  * add docstrings/comments
  * break out tests to individual unit test files; add conftest and make scope of fixtures module to help with speed
  * fix name
* [Continuous Batching] Queue Implementation to support batching grouping and prioritization (#1373)
  * [Continuous Batching] Queue Implementation to support batching grouping and prioritization
  * has_key method
  * thread safety
  * add blocking option for pop_batch
  * update docstring
  * allow mutex to be shared across continuous batching objects
  * revert last commit
* [Continuous Batching] Executor thread for running continuous batching (#1374)
  * [Continuous Batching] Executor thread for running continuous batching
  * quality
  * ensure that executor stops when main thread does - clean up test hack
* [ContinuousBatching] ContinuousBatchingScheduler Implementation (#1375)
  * [ContinuousBatching] ContinuousBatchingScheduler Implementation
  * cleanup unnecessary stop condition
* [continuous batching] singleton pattern for scheduler (#1391)
  * [continuous batching] singleton pattern for scheduler
  * catch from review
* [Pipeline Refactor][Text-Generation] Create a helper function for creating engine_inputs (#1364)
  * rebasing off my initial commit
  * cleanups
  * unit testing for text generation operators
  * additional changes
  * unit testing completion
  * remove debug
  * fix
  * add todo
  * more clean-up
  * fix test
  * add docstrings/comments
  * break out tests to individual unit test files; add conftest and make scope of fixtures module to help with speed
  * Delete tests/deepsparse/v2/unit/text_generation/test_msic.py
  ---------
  Co-authored-by: Dipika Sikka <[email protected]>
* [Pipeline Refactor][Text-Generation] Refactor `transformers` helpers functions (#1394)
  * add split/join functionality
  * update router to include split/join in parent class, refactor pipeline code to remove repeat code, update map function
  * process multiple generations
  * initial commit
  * fix error
  * unit testing for text generation operators
  * additional changes
  * unit testing completion
  * remove debug
  * fix
  * add todo
  * more clean-up
  * fix test
  * add docstrings/comments
  * break out tests to individual unit test files; add conftest and make scope of fixtures module to help with speed
  * Delete tests/deepsparse/v2/unit/text_generation/test_msic.py
  * pipeline runs, but incorrectly
  * Revert "pipeline runs, but incorrectly" This reverts commit 51c4ee6.
  * PR review comments
  ---------
  Co-authored-by: Dipika Sikka <[email protected]>
* [Text Generation][V2] End-to-end tests (#1402)
  * initial commit
  * initial commit
  * its working now
  * beautification
  * thank you Dipika <3
  * ready to review
* [Pipeline Refactor][Text Generation][Continuous Batching] Integration (#1409)
  * update split/join
  * use map
  * update
  * run end-to-end
  * clean-up
  * fix bug with batch size, introduce SplitRoute dataclass
  * update tests to use new inputs/outputs
  * use the normal scheduler for internal kv_cache
  * add pipeline inpuits
  * clean-up
  * change engine type, update docstrings, update override function to be more generic
  * move subgraph functionality to its own function; clean-up cont batching in text gen pipeline
  * update linear pathway to also use subgraph execution
  * rebase fix
  * fix tests
* [Pipeline Refactor] Operator Registry (#1420)
  * initial registry functionality
  * use sparsezoo mixin
* [Pipeline Refactor] Fix Operator scheduling to fix issue with slow execution (#1453)
  * fix scheduling to fix issue with engine running very slowly; introduce new completed attribute for Subgraph instead of checking instance type
  * fix warning message
* [Pipeline Refactor] Add `Pipeline.create` method to initialize pipelines (#1457)
  * add pipeline create method for pipeline creation using the operator registry
  * add instance check
* [Pipeline Refactor] async (#1380)
  * initial functionality and working example with image classification
  * remove testing image
  * rebase fixes
  * initial functionality and working example with image classification
  * text gen
  * updates func
  * prompt inference, initial functionality
  * remove image; update state docstring
  * Fix typo
  * add todo for split/join
  * remove context, clean-up args, remove prefill_preprocess_operaator
  * fix docstrings
  * initial functionality and working example with image classification
  * updates func
  * prompt inference, initial functionality
  * finish generation operators and update routes
  * further breakdown operators
  * add operators
  * fix can_operate condition
  * update can_operate to not rely on the inference_state
  * rebase + update
  * fix condition
  * async initial functionality
  * fix capacity settting again
  * add blocking
  * more testing
  * update to use split/join
  * fix
  * rebase fix
  * remove index
  * change event loop
  * rebase fix
  * update async run to use new operator scheduling properly
  * rebase fixes (#1458)
  * more fixes (#1459)

---------
Co-authored-by: Benjamin Fineran <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 317be2d commit b4ae2dd


66 files changed, +5698 -2 lines changed

src/deepsparse/transformers/helpers.py

+80 -1
@@ -17,14 +17,16 @@
 """


+import logging
 import os
 import re
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy
 import onnx
+import transformers
 from onnx import ModelProto

 from deepsparse.log import get_main_logger
@@ -38,6 +40,7 @@

 __all__ = [
     "get_deployment_path",
+    "setup_transformers_pipeline",
     "overwrite_transformer_onnx_model_inputs",
     "fix_numpy_types",
     "get_transformer_layer_init_names",
@@ -47,6 +50,82 @@
 _LOGGER = get_main_logger()


+def setup_transformers_pipeline(
+    model_path: str,
+    sequence_length: int,
+    tokenizer_padding_side: str = "left",
+    engine_kwargs: Optional[Dict] = None,
+) -> Tuple[
+    str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer, Dict[str, Any]
+]:
+    """
+    A helper function that sets up the model path, config, tokenizer,
+    and engine kwargs for a transformers model.
+    :param model_path: The path to the model to load
+    :param sequence_length: The sequence length to use for the model
+    :param tokenizer_padding_side: The side to pad on for the tokenizer,
+        either "left" or "right"
+    :param engine_kwargs: The kwargs to pass to the engine
+    :return: The model path, config, tokenizer, and engine kwargs
+    """
+    model_path, config, tokenizer = fetch_onnx_file_path(model_path, sequence_length)
+
+    tokenizer.padding_side = tokenizer_padding_side
+    if not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    engine_kwargs = engine_kwargs or {}
+    if engine_kwargs.get("model_path"):
+        raise ValueError(
+            "The engine kwargs already specify "
+            f"a model path: {engine_kwargs['model_path']}, "
+            f"but a model path was also provided: {model_path}. "
+            "Please only provide one."
+        )
+    engine_kwargs["model_path"] = model_path
+    return model_path, config, tokenizer, engine_kwargs
+
+
+def fetch_onnx_file_path(
+    model_path: str,
+    sequence_length: int,
+    task: Optional[str] = None,
+) -> Tuple[str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer]:
+    """
+    Parses ONNX model from the `model_path` provided. It additionally
+    creates config and tokenizer objects from the `deployment path`,
+    derived from the `model_path` provided.
+    :param model_path: path to the model to be parsed
+    :param sequence_length: maximum sequence length of the model
+    :return: file path to the processed ONNX file, the config, and the tokenizer
+    """
+    deployment_path, onnx_path = get_deployment_path(model_path)
+
+    hf_logger = logging.getLogger("transformers")
+    hf_logger_level = hf_logger.level
+    hf_logger.setLevel(logging.ERROR)
+
+    config = transformers.PretrainedConfig.from_pretrained(
+        deployment_path, finetuning_task=task
+    )
+    hf_logger.setLevel(hf_logger_level)
+
+    trust_remote_code = False
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        deployment_path,
+        trust_remote_code=trust_remote_code,
+        model_max_length=sequence_length,
+    )
+
+    if not config or not tokenizer:
+        raise RuntimeError(
+            "Invalid config or tokenizer provided. Please provide "
+            "paths to the files or ensure they exist in the `model_path` provided. "
+            "See `tokenizer` and `config` arguments for details."
+        )
+    return onnx_path, config, tokenizer
+
+
 def get_deployment_path(model_path: str) -> Tuple[str, str]:
     """
     Returns the path to the deployment directory
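The guard in `setup_transformers_pipeline` rejects engine kwargs that already carry a `model_path`. A minimal standalone sketch of that check follows. The helper name `merge_model_path` and the `batch_size` key are hypothetical and used only for illustration. Unlike the diff, which assigns into the caller's dictionary when one is passed, this sketch copies the dictionary first.

```python
from typing import Any, Dict, Optional


def merge_model_path(
    model_path: str, engine_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Hypothetical helper: return engine kwargs holding exactly one model path."""
    # Copy so the caller's dictionary is left untouched (an assumption of this
    # sketch; the diff mutates the dictionary it was given).
    merged = dict(engine_kwargs or {})
    if merged.get("model_path"):
        raise ValueError(
            "The engine kwargs already specify "
            f"a model path: {merged['model_path']}, "
            f"but a model path was also provided: {model_path}. "
            "Please only provide one."
        )
    merged["model_path"] = model_path
    return merged


# "batch_size" is an illustrative key, not taken from the diff.
caller_kwargs = {"batch_size": 1}
merged_kwargs = merge_model_path("model.onnx", caller_kwargs)
```

Copying is a design choice rather than a requirement: it keeps a dictionary that is reused across several pipelines from accumulating a `model_path` entry and tripping the guard on the second call.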

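`fetch_onnx_file_path` lowers the `transformers` logger to `ERROR` while the config loads and restores the level afterwards. As written in the diff, the restore is skipped if `from_pretrained` raises. One possible variant, sketched here with a hypothetical `quiet_logger` context manager built only on the standard library, restores the level on both the normal and the exceptional path.

```python
import logging
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def quiet_logger(name: str, level: int = logging.ERROR) -> Iterator[logging.Logger]:
    """Hypothetical helper: temporarily raise a logger's threshold."""
    logger = logging.getLogger(name)
    previous_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        # Runs on normal exit and on exceptions alike.
        logger.setLevel(previous_level)


# Demonstration on a throwaway logger name chosen for this sketch.
demo_logger = logging.getLogger("demo_quiet")
demo_logger.setLevel(logging.INFO)
with quiet_logger("demo_quiet") as inner:
    level_inside = inner.level
level_after = demo_logger.level
```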
src/deepsparse/transformers/pipelines/pipeline.py

+1
@@ -124,6 +124,7 @@ def setup_onnx_file_path(self) -> str:

         :return: file path to the processed ONNX file for the engine to compile
         """
+
         deployment_path, onnx_path = get_deployment_path(self.model_path)
         self._deployment_path = deployment_path

src/deepsparse/transformers/utils/helpers.py

+91 -1
@@ -14,7 +14,7 @@
 import logging
 import pathlib
 import uuid
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy
 from transformers import AutoTokenizer, GenerationConfig
@@ -33,6 +33,7 @@
     "override_config",
     "process_generation_config",
     "validate_session_ids",
+    "compute_engine_inputs",
     "set_generated_length",
 ]

@@ -82,6 +83,95 @@ def set_generated_length(
     )


+def compute_engine_inputs(
+    onnx_input_names: List[str], **kwargs
+) -> List[numpy.ndarray]:
+    """
+    Given the names of the onnx inputs, compute the inputs
+    to the engine. The inputs will be calculated from the
+    passed kwargs. The information about the required kwargs
+    can be found in the docstring of the individual compute
+    functions.
+
+    :param onnx_input_names: The names of the onnx inputs
+    :param kwargs: The kwargs to compute the inputs from
+    :return: The computed inputs to the engine
+    """
+    engine_inputs = []
+    for input_name in onnx_input_names:
+        if input_name == "causal_mask":
+            # delay the computation of the causal mask
+            continue
+        # fetch the compute function for the
+        # given input_name
+        compute_func = _get_compute_func(input_name)
+        # compute the engine input from the kwargs
+        # and append it to the engine_inputs
+        engine_inputs.append(compute_func(**kwargs))
+
+    if "causal_mask" in onnx_input_names:
+        # compute the causal mask and append it to the engine_inputs
+        input_ids, attention_mask, *_ = engine_inputs
+        engine_inputs.append(create_causal_mask(input_ids, attention_mask))
+
+    return engine_inputs
+
+
+def _get_compute_func(input_name: str) -> Callable[..., numpy.ndarray]:
+    # given the input_name, return the appropriate compute function
+    compute_func = {
+        "input_ids": _compute_input_ids,
+        "attention_mask": _compute_attention_mask,
+        "positions": _compute_positions,
+    }.get(input_name)
+    if compute_func is None:
+        raise ValueError(
+            "Could not find compute function " f"for the input_name: {input_name}"
+        )
+    return compute_func
+
+
+def _compute_input_ids(token_batch: List[int], **kwargs) -> numpy.ndarray:
+    # convert the token_batch to a numpy array
+    return numpy.array([token_batch])
+
+
+def _compute_attention_mask(
+    sequence_length: int,
+    prompt_sequence_length: int,
+    num_total_processed_tokens: int,
+    **kwargs,
+) -> numpy.ndarray:
+    # create a fully masked attention mask with the appropriate
+    # shape (equal to the sequence_length)
+    attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64)
+    # unmask the appropriate number of tokens, the sum of
+    # - the number of tokens already processed and cached (num_total_processed_tokens)
+    # - the number of tokens currently processed (prompt_sequence_length)
+    # the sum cannot exceed the maximum length of the attention_mask
+    num_attention_entries_to_unmask = min(
+        num_total_processed_tokens + prompt_sequence_length, sequence_length
+    )
+    # unmask the bits from the right-hand side
+    attention_mask[:, -num_attention_entries_to_unmask:] = 1
+    return attention_mask
+
+
+def _compute_positions(
+    num_total_processed_tokens: int, prompt_sequence_length: int, **kwargs
+) -> numpy.ndarray:
+    # create the positions array with the appropriate shape
+    # positions count starts from the number of tokens already processed
+    # and ends at the number of tokens already processed + the number of tokens
+    # currently processed
+    return (
+        numpy.arange(
+            num_total_processed_tokens,
+            num_total_processed_tokens + prompt_sequence_length,
+        )
+        .reshape(1, -1)
+        .astype(numpy.int64)
+    )
+
+
 def validate_session_ids(
     session_ids: Optional[str], other_attributes: Dict[str, Any]
 ) -> Optional[List[str]]:
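To make the arithmetic in `_compute_attention_mask` and `_compute_positions` concrete, here is a small worked example. It is a dependency-free sketch that uses plain Python lists in place of numpy arrays, and the `sketch_*` function names are hypothetical. With a sequence length of 8, three tokens already cached, and two tokens in the current step, five positions are unmasked from the right and the positions run from 3 to 4.

```python
from typing import List


def sketch_attention_mask(
    sequence_length: int,
    prompt_sequence_length: int,
    num_total_processed_tokens: int,
) -> List[List[int]]:
    # Start fully masked: one row of `sequence_length` zeros.
    mask = [0] * sequence_length
    # Unmask cached tokens plus current tokens, capped at the mask length.
    unmasked = min(
        num_total_processed_tokens + prompt_sequence_length, sequence_length
    )
    # Unmask from the right-hand side.
    for index in range(sequence_length - unmasked, sequence_length):
        mask[index] = 1
    return [mask]


def sketch_positions(
    num_total_processed_tokens: int, prompt_sequence_length: int
) -> List[List[int]]:
    # Positions continue counting from the number of tokens already processed.
    start = num_total_processed_tokens
    return [list(range(start, start + prompt_sequence_length))]


mask = sketch_attention_mask(
    sequence_length=8, prompt_sequence_length=2, num_total_processed_tokens=3
)
positions = sketch_positions(num_total_processed_tokens=3, prompt_sequence_length=2)
print(mask)       # [[0, 0, 0, 1, 1, 1, 1, 1]]
print(positions)  # [[3, 4]]
```

One difference from the numpy version is worth keeping in mind: in numpy, a slice of `-0:` selects the whole axis, so an unmask count of zero would set every entry to 1, whereas the list-based loop above leaves the mask all zeros in that case.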

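`compute_engine_inputs` passes the same `**kwargs` to every compute function, and each function picks out the arguments it names while the rest fall into its own `**kwargs`. The sketch below illustrates that dispatch pattern in isolation, using hypothetical `_sketch_*` functions that return plain lists rather than numpy arrays. It omits the delayed `causal_mask` step; in the diff that step unpacks the first two computed inputs as `input_ids` and `attention_mask`, so it relies on those two names coming first in `onnx_input_names`.

```python
from typing import Any, Callable, Dict, List


def _sketch_input_ids(token_batch: List[int], **kwargs: Any) -> List[List[int]]:
    # Only `token_batch` is used; every other shared kwarg is ignored.
    return [list(token_batch)]


def _sketch_positions(
    num_total_processed_tokens: int, prompt_sequence_length: int, **kwargs: Any
) -> List[List[int]]:
    start = num_total_processed_tokens
    return [list(range(start, start + prompt_sequence_length))]


# Lookup table from input name to compute function, mirroring the diff's dict.
_SKETCH_COMPUTE_FUNCS: Dict[str, Callable[..., List[List[int]]]] = {
    "input_ids": _sketch_input_ids,
    "positions": _sketch_positions,
}


def sketch_compute_inputs(
    input_names: List[str], **kwargs: Any
) -> List[List[List[int]]]:
    inputs = []
    for name in input_names:
        func = _SKETCH_COMPUTE_FUNCS.get(name)
        if func is None:
            raise ValueError(
                f"Could not find compute function for the input_name: {name}"
            )
        # Every function receives the full set of shared kwargs.
        inputs.append(func(**kwargs))
    return inputs


shared_kwargs = dict(
    token_batch=[11, 12],
    num_total_processed_tokens=3,
    prompt_sequence_length=2,
    sequence_length=8,
)
result = sketch_compute_inputs(["input_ids", "positions"], **shared_kwargs)
```

The trailing `**kwargs` on each compute function is what lets a single call site serve all of them: adding a new input only requires a new function and a new table entry.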
src/deepsparse/v2/__init__.py

+22
@@ -0,0 +1,22 @@
+# flake8: noqa
+
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .operators import *
+from .pipeline import *
+from .routers import *
+from .schedulers import *
+from .task import *
+from .utils import *
@@ -0,0 +1,20 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+from .postprocess_operator import *
+from .preprocess_operator import *
+
+
+from .pipeline import *  # isort:skip
@@ -0,0 +1,62 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import warnings
+from typing import Dict, Optional, Tuple, Union
+
+from deepsparse.v2.image_classification.postprocess_operator import (
+    ImageClassificationPostProcess,
+)
+from deepsparse.v2.image_classification.preprocess_operator import (
+    ImageClassificationPreProcess,
+)
+from deepsparse.v2.operators.engine_operator import EngineOperator
+from deepsparse.v2.pipeline import Pipeline
+from deepsparse.v2.routers.router import LinearRouter
+from deepsparse.v2.schedulers.scheduler import OperatorScheduler
+
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["ImageClassificationPipeline"]
+
+
+class ImageClassificationPipeline(Pipeline):
+    def __init__(
+        self,
+        model_path: str,
+        engine_kwargs: Optional[Dict] = None,
+        class_names: Union[None, str, Dict[str, str]] = None,
+        image_size: Optional[Tuple[int]] = None,
+        top_k: int = 1,
+    ):
+        if not engine_kwargs:
+            engine_kwargs = {}
+            engine_kwargs["model_path"] = model_path
+        elif engine_kwargs.get("model_path") != model_path:
+            warnings.warn(f"Updating engine_kwargs to include {model_path}")
+
+        engine = EngineOperator(**engine_kwargs)
+        preprocess = ImageClassificationPreProcess(
+            model_path=engine.model_path, image_size=image_size
+        )
+        postprocess = ImageClassificationPostProcess(
+            top_k=top_k, class_names=class_names
+        )
+
+        ops = [preprocess, engine, postprocess]
+        router = LinearRouter(end_route=len(ops))
+        scheduler = [OperatorScheduler()]
+        super().__init__(ops=ops, router=router, schedulers=scheduler)
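The pipeline wires three operators into a list and hands them to a `LinearRouter` whose `end_route` is the number of operators. The router's implementation is not part of the lines shown here, so the following is only a sketch of what linear routing plausibly means: run the operators in list order, feeding each output to the next, until the end index is reached. `SketchLinearRouter` and `run_linear` are hypothetical stand-ins and are not the deepsparse classes.

```python
from typing import Any, Callable, List


class SketchLinearRouter:
    """Hypothetical stand-in: visit indices start_route .. end_route - 1 in order."""

    def __init__(self, end_route: int, start_route: int = 0) -> None:
        self.start_route = start_route
        self.end_route = end_route

    def next(self, past: int) -> int:
        # A linear route always advances to the following operator.
        return past + 1


def run_linear(
    ops: List[Callable[[Any], Any]], router: SketchLinearRouter, value: Any
) -> Any:
    index = router.start_route
    while index != router.end_route:
        # Each operator consumes the previous operator's output.
        value = ops[index](value)
        index = router.next(index)
    return value


# Three toy stages standing in for preprocess, engine, and postprocess.
toy_ops = [lambda x: x + 1, lambda x: x * 2, str]
output = run_linear(toy_ops, SketchLinearRouter(end_route=len(toy_ops)), 3)
print(output)  # 8
```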

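In the constructor as shown, the `elif` branch warns that `engine_kwargs` is being updated to include `model_path`, but no assignment follows in the visible lines, so a non-empty `engine_kwargs` with a missing or different `model_path` reaches `EngineOperator` unchanged. A minimal sketch of one way to make the behavior match the warning is given below. The helper name `reconcile_model_path` and the `batch_size` key are hypothetical.

```python
import warnings
from typing import Any, Dict, Optional


def reconcile_model_path(
    model_path: str, engine_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Hypothetical helper: make engine kwargs agree with the given model path."""
    kwargs = dict(engine_kwargs or {})
    # Same warning condition as the diff: non-empty kwargs whose model path
    # is missing or differs from the one passed to the constructor.
    if kwargs and kwargs.get("model_path") != model_path:
        warnings.warn(f"Updating engine_kwargs to include {model_path}")
    # The assignment the warning message describes.
    kwargs["model_path"] = model_path
    return kwargs


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "batch_size" is an illustrative key, not taken from the diff.
    reconciled = reconcile_model_path(
        "model.onnx", {"model_path": "other.onnx", "batch_size": 2}
    )
```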