
Commit ea82e99

remove files to make review easier

1 parent 950c653

5 files changed: +28 -463 lines changed

src/deepsparse/pipeline.py (+21 -33)

```diff
@@ -263,7 +263,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
         batches = self.split_engine_inputs(engine_inputs, self._batch_size)
 
         # submit split batches to engine threadpool
-        batch_outputs = [self.engine_forward(x) for x in batches]
+        batch_outputs = list(self.executor.map(self.engine_forward, batches))
 
         # join together the batches of size `self._batch_size`
         engine_outputs = self.join_engine_outputs(batch_outputs)
@@ -567,34 +567,6 @@ def _register_pipeline_tasks_decorator(pipeline_class: Pipeline):
 
         return _register_pipeline_tasks_decorator
 
-    @staticmethod
-    def create_engine(
-        onnx_file_path: str,
-        engine_type: str,
-        engine_args: Dict,
-        context: Optional[Context] = None,
-    ) -> Union[Engine, MultiModelEngine, ORTEngine]:
-        engine_type = engine_type.lower()
-
-        if engine_type == DEEPSPARSE_ENGINE:
-            if context is not None and isinstance(context, Context):
-                engine_args.pop("num_cores", None)
-                engine_args.pop("scheduler", None)
-                engine_args["context"] = context
-                return MultiModelEngine(
-                    model=onnx_file_path,
-                    **engine_args,
-                )
-            return Engine(onnx_file_path, **engine_args)
-
-        if engine_type == ORT_ENGINE:
-            return ORTEngine(onnx_file_path, **engine_args)
-
-        raise ValueError(
-            f"Unknown engine_type {engine_type}. Supported values include: "
-            f"{SUPPORTED_PIPELINE_ENGINES}"
-        )
-
     @classmethod
     def from_config(
         cls,
@@ -819,10 +791,26 @@ def engine_forward(self, engine_inputs: List[numpy.ndarray]) -> List[numpy.ndarray]:
         """
         return self.engine(engine_inputs)
 
-    def _initialize_engine(self) -> Union[Engine, MultiModelEngine, ORTEngine]:
-        return Pipeline.create_engine(
-            self.onnx_file_path, self.engine_type, self._engine_args, self.context
-        )
+    def _initialize_engine(self) -> Union[Engine, ORTEngine]:
+        engine_type = self.engine_type.lower()
+
+        if engine_type == DEEPSPARSE_ENGINE:
+            if self.context is not None and isinstance(self.context, Context):
+                self._engine_args.pop("num_cores", None)
+                self._engine_args.pop("scheduler", None)
+                self._engine_args["context"] = self.context
+                return MultiModelEngine(
+                    model=self.onnx_file_path,
+                    **self._engine_args,
+                )
+            return Engine(self.onnx_file_path, **self._engine_args)
+        elif engine_type == ORT_ENGINE:
+            return ORTEngine(self.onnx_file_path, **self._engine_args)
+        else:
+            raise ValueError(
+                f"Unknown engine_type {self.engine_type}. Supported values include: "
+                f"{SUPPORTED_PIPELINE_ENGINES}"
+            )
 
     def _identifier(self):
         # get pipeline identifier; used in the context of logging
```
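The first hunk swaps a serial list comprehension for `Executor.map`, so the split batches are submitted to the pipeline's threadpool and run concurrently. A minimal standalone sketch of the pattern, using only the standard library and a hypothetical `engine_forward` stand-in rather than DeepSparse's real engine classes:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy


def engine_forward(batch: numpy.ndarray) -> numpy.ndarray:
    # stand-in for a compiled engine call, which would release the GIL
    return batch * 2


# inputs already split into batch_size-sized chunks, as split_engine_inputs does
batches = [numpy.ones((2, 4)) * i for i in range(8)]

executor = ThreadPoolExecutor(max_workers=4)
# map() yields results in submission order, so the later join of batches
# stays deterministic with no extra bookkeeping
batch_outputs = list(executor.map(engine_forward, batches))
```

Because `map` preserves input order, `join_engine_outputs` can concatenate the results exactly as the serial version did.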

src/deepsparse/tasks.py (-23)

```diff
@@ -95,12 +95,6 @@ class SupportedTasks:
         ),
     )
 
-    text_generation = namedtuple("text_generation", ["opt", "codegen", "bloom"])(
-        codegen=AliasedTask("codegen", []),
-        opt=AliasedTask("opt", []),
-        bloom=AliasedTask("bloom", []),
-    )
-
     image_classification = namedtuple("image_classification", ["image_classification"])(
         image_classification=AliasedTask(
             "image_classification",
@@ -156,9 +150,6 @@ def check_register_task(
             # custom task, register the CustomPipeline
             import deepsparse.pipelines.custom_pipeline  # noqa: F401
 
-        elif cls.is_text_generation(task):
-            import deepsparse.transformers.pipelines.text_generation  # noqa: F401
-
         elif cls.is_nlp(task):
             # trigger transformers pipelines to register with Pipeline.register
             import deepsparse.transformers.pipelines  # noqa: F401
@@ -202,20 +193,6 @@ def check_register_task(
                 f"{list(all_tasks)}"
             )
 
-    @classmethod
-    def is_text_generation(cls, task: str) -> bool:
-        """
-        :param task: the name of the task to check whether it is a text generation task
-            such as codegen
-        :return: True if it is a text generation task, False otherwise
-        """
-        return any(
-            [
-                text_generation_task.matches(task)
-                for text_generation_task in cls.text_generation
-            ]
-        )
-
     @classmethod
     def is_nlp(cls, task: str) -> bool:
         """
```

src/deepsparse/transformers/README.md (+3 -47)

````diff
@@ -10,7 +10,6 @@ methods such as [pruning](https://neuralmagic.com/blog/pruning-overview/) and [q
 These techniques result in significantly more performant and smaller models with limited to no effect on the baseline metrics.
 
 This integration currently supports several fundamental NLP tasks:
-- **Text Generation** - given the input prompt, generate an output text sequence (e.g. to fill in incomplete text or paraphrase part of the prompt)
 - **Question Answering** - posing questions about a document
 - **Sentiment Analysis** - assigning a sentiment to a piece of text
 - **Text Classification** - assigning a label or class to a piece of text (e.g duplicate question pairing)
@@ -31,12 +30,10 @@ compatible with our [hardware requirements](https://docs.neuralmagic.com/deepspa
 By default, to deploy the transformer using DeepSparse Engine it is required to supply the model in the ONNX format along with the HuggingFace supporting files.
 This grants the engine the flexibility to serve any model in a framework-agnostic environment.
 
-In general, the DeepSparse pipelines require the following files within a folder on the local server to properly load a Transformers model:
+The DeepSparse pipelines require the following files within a folder on the local server to properly load a Transformers model:
 - `model.onnx`: The exported Transformers model in the [ONNX format](https://github.com/onnx/onnx).
-- `model_kvcache.onnx` (optional): the ONNX model with the KV Cache support (akin to the Transformers model with `use_cache = True`. Specific for the `text-generation` integration.
+- `tokenizer.json`: The [HuggingFace compatible tokenizer configuration](https://huggingface.co/docs/transformers/fast_tokenizers) used with the model.
 - `config.json`: The [HuggingFace compatible configuration file](https://huggingface.co/docs/transformers/main_classes/configuration) used with the model.
-- `tokenizer_config.json`: The [HuggingFace compatible tokenizer configuration](https://huggingface.co/docs/transformers/fast_tokenizers) used with the model.
-- `tokenizer.json`, `special_tokens_map.json`, `vocab.json`, `merges.txt` (optional): Other files that may be required by a tokenizer
 
 Below we describe two possibilities to obtain the required structure.
 
@@ -51,7 +48,7 @@ sparseml.transformers.export_onnx --task question-answering --model_path model_path
 ```
 
 This creates `model.onnx` file, in the directory of your `model_path`(e.g. `/trained_model/model.onnx`).
-Any additional, required files, such as e.g.`tokenizer.json` or `config.json`, are stored under the `model_path` folder as well, so a DeepSparse pipeline ca be directly instantiated by using that folder after export (e.g. `/trained_model/`).
+The `tokenizer.json` and `config.json` are stored under the `model_path` folder as well, so a DeepSparse pipeline ca be directly instantiated by using that folder after export (e.g. `/trained_model/`).
 
 #### SparseZoo Stub
 Alternatively, you can skip the process of the ONNX model export by using Neural Magic's [SparseZoo](https://sparsezoo.neuralmagic.com/). The SparseZoo contains pre-sparsified models and SparseZoo stubs enable you to reference any model on the SparseZoo in a convenient and predictable way.
@@ -140,47 +137,6 @@ response.text
 
 >> '{"score":0.9534820914268494,"start":8,"end":14,"answer":"batman"}'
 ```
-### Text Generation
-The text generation task generates a sequence of words given the prompt. Popular text generation LLMs (Large Language Models) are used
-for the chats (the instruction models), code generation, text summarization, or filling out the missing text.
-are used for chats or following instructions are also covered in this task. The following example uses a sparsified text classification
-OPT model to complete the prompt
-
-[List of available SparseZoo Text Generation Models](
-https://sparsezoo.neuralmagic.com/?useCase=text_generation)
-
-#### Python Pipeline
-```python
-from deepsparse import Pipeline
-
-opt_pipeline = Pipeline.create(task="opt")
-
-inference = opt_pipeline("Who is the president of the United States?")
-
->> 'The president of the United States is the head of the executive branch of government...'
-```
-
-#### HTTP Server
-Spinning up:
-```bash
-deepsparse.server \
-    task text-generation \
-    --model_path # TODO: Pending until text generation models get uploaded to SparseZoo
-```
-
-Making a request:
-```python
-import requests
-
-url = "http://localhost:5543/predict" # Server's port default to 5543
-
-obj = {"sequence": "Who is the president of the United States?"}
-
-response = requests.post(url, json=obj)
-response.text
-
->> 'The president of the United States is the head of the executive branch of government...'
-```
 
 ### Sentiment Analysis
 The sentiment analysis task takes in a sentence and classifies its sentiment. The following example
````
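With the trimmed requirements (`model.onnx`, `tokenizer.json`, `config.json`) gathered in one folder, the remaining tasks load as before. A short usage sketch for the question-answering path, assuming the export step above produced `/trained_model/`:

```python
from deepsparse import Pipeline

# folder containing model.onnx, tokenizer.json, and config.json from the export
qa_pipeline = Pipeline.create(
    task="question-answering",
    model_path="/trained_model/",
)

inference = qa_pipeline(question="Who is Mark?", context="Mark is batman.")
```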

src/deepsparse/transformers/pipelines/pipeline.py (+4 -8)

```diff
@@ -126,23 +126,19 @@ def setup_onnx_file_path(self) -> str:
         return onnx_path
 
     def tokens_to_engine_input(
-        self,
-        tokens: Mapping[Any, numpy.ndarray],
-        onnx_input_names: Optional[List[str]] = None,
+        self, tokens: Mapping[Any, numpy.ndarray]
     ) -> List[numpy.ndarray]:
         """
         :param tokens: outputs of the pipeline tokenizer
         :return: list of numpy arrays in expected order for model input
         """
-        if onnx_input_names is None:
-            onnx_input_names = self.onnx_input_names
-        if not all(name in tokens for name in onnx_input_names):
+        if not all(name in tokens for name in self.onnx_input_names):
             raise ValueError(
-                f"pipeline expected arrays with names {onnx_input_names}, "
+                f"pipeline expected arrays with names {self.onnx_input_names}, "
                 f"received inputs: {list(tokens.keys())}"
             )
 
-        return [tokens[name] for name in onnx_input_names]
+        return [tokens[name] for name in self.onnx_input_names]
 
     @staticmethod
     def should_bucket(*args, **kwargs) -> bool:
```
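After this change `tokens_to_engine_input` always orders the tokenizer's arrays by the model's own `self.onnx_input_names` rather than an optional override. A standalone sketch of that reordering, with hypothetical input names standing in for whatever the ONNX graph declares:

```python
from typing import Dict, List

import numpy


def tokens_to_engine_input(
    tokens: Dict[str, numpy.ndarray], onnx_input_names: List[str]
) -> List[numpy.ndarray]:
    # validate that the tokenizer produced every array the model expects
    if not all(name in tokens for name in onnx_input_names):
        raise ValueError(
            f"pipeline expected arrays with names {onnx_input_names}, "
            f"received inputs: {list(tokens.keys())}"
        )
    # output order follows the ONNX graph's declared inputs, not dict order
    return [tokens[name] for name in onnx_input_names]


tokens = {
    "attention_mask": numpy.ones((1, 128), dtype=numpy.int64),
    "input_ids": numpy.zeros((1, 128), dtype=numpy.int64),
}
engine_input = tokens_to_engine_input(tokens, ["input_ids", "attention_mask"])
```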
