
Commit ed9b1ee
[CLIP]: Zeroshot Pipeline (#1098)
Squashed commit message:

* initial refactor
* move BasePipeline to a new file
* test fix
* another test fix
* fix import
* revert
* initial refactor
* add tests for BasePipeline
* move BasePipeline to a new file
* initial refactor
* update test; finish off initial refactoring changes post local testing
* initial commit for clip zero-shot
* add basic structure for text branch and zeroshot
* add schema details
* update pipelines after running mock engine tests
* add zeroshot tests
* rebase fix
* clean up comments; add note about onnx export issue
* add clip dependency
* move paths to fixtures
* rebase fix
* rebase fix
* refactor pipelines to separate visual, text, and zeroshot; also add pytest skips until model issues are resolved
* make zeroshot arguments explicit; deal with quality
* update workflow to install clip for base test
* update pipelines after using MLR's zeroshot models
* add readme with examples, update setup.py, and clean up return types
* quality fix
* Update visual_pipeline.py: update model loading (Co-authored-by: dbogunowicz <[email protected]>)
* add docstring
* move docstring; add params
* fix rebase
* quality
1 parent 3254ca8 commit ed9b1ee

File tree: 9 files changed (+562, -2 lines)

.github/workflows/test-check.yaml (+1, -1)

```diff
@@ -32,7 +32,7 @@ jobs:
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
       - name: ⚙️ Install dependencies
-        run: pip3 install .[dev,server,image_classification,transformers] opencv-python
+        run: pip3 install .[dev,server,image_classification,transformers,clip] opencv-python
       - name: Run base tests
         run: make test
   cli-smoke-tests:
```

setup.py (+2, -1)

```diff
@@ -162,7 +162,7 @@ def _parse_requirements_file(file_path):
     "haystack_reqs.txt",
 )
 _haystack_integration_deps = _parse_requirements_file(_haystack_requirements_file_path)
-
+_clip_deps = ["open_clip_torch==2.20.0", "scipy==1.10.1"]
 
 _torch_deps = ["torch>=1.7.0,<=2.0"]
 
@@ -280,6 +280,7 @@ def _setup_extras() -> Dict:
         "yolov8": _yolov8_integration_deps,
         "transformers": _transformers_integration_deps,
         "torch": _torch_deps,
+        "clip": _clip_deps,
     }
 
 
```

src/deepsparse/clip/README.md (new file, +75)

# CLIP Inference Pipelines

DeepSparse allows inference on [CLIP](https://github.com/mlfoundations/open_clip) models.

The CLIP integration currently supports the following task:
- **Zero-shot Image Classification** - Classifying images given possible classes

## Getting Started

Before you start your adventure with the DeepSparse Engine, make sure that your machine is compatible with our [hardware requirements](https://docs.neuralmagic.com/deepsparse/source/hardware.html).

### Installation

```pip install deepsparse[clip]```

### Model Format

By default, deploying CLIP models with the DeepSparse Engine requires supplying the models in the ONNX format. This grants the engine the flexibility to serve any model in a framework-agnostic environment. To see examples of pulling CLIP models and exporting them to ONNX, please see the [sparseml documentation](https://github.com/neuralmagic/sparseml/tree/main/integrations/clip). For the zero-shot image classification workflow, two ONNX models are required: a visual model for CLIP's visual branch and a text model for CLIP's text branch. Both of these models should be produced through the sparseml integration linked above.
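For a rough idea of what that export involves, the sketch below traces the two branches of an `open_clip` model directly with `torch.onnx.export`. This is illustrative only: the model name (`ViT-B-32`), output paths, and input shapes are assumptions made for this example, and a direct export like this may run into the ONNX export issues that the sparseml integration works around, so prefer the sparseml flow for real deployments.

```python
# Illustrative sketch only: the supported export path is the sparseml CLIP integration.
# Assumes open_clip_torch and torch are installed; model name and output paths are examples.
from pathlib import Path

import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
model.eval()


class VisualBranch(torch.nn.Module):
    """Wraps CLIP's image encoder so it can be traced on its own."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, pixel_values):
        return self.clip_model.encode_image(pixel_values)


class TextBranch(torch.nn.Module):
    """Wraps CLIP's text encoder so it can be traced on its own."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, input_ids):
        return self.clip_model.encode_text(input_ids)


# Create the example output directories used later in this README.
for folder in ("zeroshot_research/visual", "zeroshot_research/text"):
    Path(folder).mkdir(parents=True, exist_ok=True)

torch.onnx.export(
    VisualBranch(model),
    torch.randn(1, 3, 224, 224),
    "zeroshot_research/visual/model.onnx",
    input_names=["pixel_values"],
    output_names=["image_embeddings"],
    opset_version=14,
)
torch.onnx.export(
    TextBranch(model),
    open_clip.tokenize(["a photo of a dog"]),
    "zeroshot_research/text/model.onnx",
    input_names=["input_ids"],
    output_names=["text_embeddings"],
    opset_version=14,
)
```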
### Deployment examples:

The following example uses pipelines to run the CLIP models for inference. As input, the pipeline ingests a list of images and a list of possible classes. A class is returned for each of the provided images.

If you don't have images ready, pull down the sample images using the following commands:

```bash
wget -O basilica.jpg https://raw.githubusercontent.com/neuralmagic/deepsparse/main/src/deepsparse/yolo/sample_images/basilica.jpg

wget -O buddy.jpeg https://raw.githubusercontent.com/neuralmagic/deepsparse/main/tests/deepsparse/pipelines/sample_images/buddy.jpeg
```

This will pull down two images, one with a happy dog and one with St. Peter's Basilica.

#### Zero-shot Prediction

Let's run an example to classify the images. We'll provide the images in a list with their file names as well as a list of possible classes. We'll also provide paths to the exported ONNX models.

```python
import numpy as np

from deepsparse import BasePipeline
from deepsparse.clip import (
    CLIPTextInput,
    CLIPVisualInput,
    CLIPZeroShotInput
)

possible_classes = ["ice cream", "an elephant", "a dog", "a building", "a church"]
images = ["basilica.jpg", "buddy.jpeg"]

model_path_text = "zeroshot_research/text/model.onnx"
model_path_visual = "zeroshot_research/visual/model.onnx"

kwargs = {
    "visual_model_path": model_path_visual,
    "text_model_path": model_path_text,
}
pipeline = BasePipeline.create(task="clip_zeroshot", **kwargs)

pipeline_input = CLIPZeroShotInput(
    image=CLIPVisualInput(images=images),
    text=CLIPTextInput(text=possible_classes),
)

output = pipeline(pipeline_input).text_scores
for i in range(len(output)):
    prediction = possible_classes[np.argmax(output[i])]
    print(f"Image {images[i]} is a picture of {prediction}")
```

Running the code above, we get the following output:

```
DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230727 COMMUNITY | (3cb4a3e5) (optimized) (system=avx2, binary=avx2)

Image basilica.jpg is a picture of a church
Image buddy.jpeg is a picture of a dog
```
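The `text_scores` above are per-image similarity scores over the candidate classes. If you want a full ranking rather than just the argmax, a small follow-up such as the sketch below can be appended to the example. It assumes the scores behave like logits and uses `scipy.special.softmax` (scipy is already installed by the `clip` extra).

```python
# Appended to the example above; assumes each row of text_scores is a vector of
# similarity logits over possible_classes.
from scipy.special import softmax

for image, scores in zip(images, output):
    probs = softmax(np.asarray(scores), axis=-1)
    ranked = sorted(zip(possible_classes, probs), key=lambda kv: kv[1], reverse=True)
    print(image, [(label, round(float(p), 3)) for label, p in ranked])
```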

src/deepsparse/clip/__init__.py (new file, +31)

```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from deepsparse.clip.text_pipeline import (
    CLIPTextInput,
    CLIPTextOutput,
    CLIPTextPipeline,
)
from deepsparse.clip.visual_pipeline import (
    CLIPVisualInput,
    CLIPVisualOutput,
    CLIPVisualPipeline,
)
from deepsparse.clip.zeroshot_pipeline import (
    CLIPZeroShotInput,
    CLIPZeroShotOutput,
    CLIPZeroShotPipeline,
)
```

src/deepsparse/clip/constants.py (new file, +19)

```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


__all__ = ["CLIP_RGB_MEANS", "CLIP_RGB_STDS"]

CLIP_RGB_MEANS = [0.48145466, 0.4578275, 0.40821073]
CLIP_RGB_STDS = [0.26862954, 0.26130258, 0.27577711]
```
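These are the standard OpenAI CLIP normalization statistics. As context for how they are typically consumed, here is a minimal sketch of CLIP-style image preprocessing built on these constants; the 224x224 resolution, Pillow dependency, and NCHW layout are assumptions for illustration, not a description of the actual `CLIPVisualPipeline` implementation.

```python
# Minimal sketch of CLIP-style image normalization using the constants above.
# Assumes Pillow is available, a 224x224 input resolution, and NCHW layout.
import numpy as np
from PIL import Image

from deepsparse.clip.constants import CLIP_RGB_MEANS, CLIP_RGB_STDS


def preprocess_image(path: str) -> np.ndarray:
    image = Image.open(path).convert("RGB").resize((224, 224))
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # HWC in [0, 1]
    pixels = (pixels - np.array(CLIP_RGB_MEANS, dtype=np.float32)) / np.array(
        CLIP_RGB_STDS, dtype=np.float32
    )
    return pixels.transpose(2, 0, 1)[None]  # NCHW batch of one
```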

src/deepsparse/clip/text_pipeline.py (new file, +102)

```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Type, Union

import numpy as np
from pydantic import BaseModel, Field

from deepsparse.pipeline import Pipeline
from deepsparse.utils import model_to_path
from open_clip.tokenizer import tokenize


__all__ = ["CLIPTextInput", "CLIPTextOutput", "CLIPTextPipeline"]


class CLIPTextInput(BaseModel):
    """
    Input for the CLIP Text Branch
    """

    text: Union[str, List[str]] = Field(description="List of text to process")


class CLIPTextOutput(BaseModel):
    """
    Output for the CLIP Text Branch
    """

    text_embeddings: List[Any] = Field(
        description="Text embeddings for the single text or list of embeddings for "
        "multiple."
    )


@Pipeline.register(task="clip_text", default_model_path=None)
class CLIPTextPipeline(Pipeline):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.tokenizer = tokenize

    @property
    def input_schema(self) -> Type[CLIPTextInput]:
        """
        :return: pydantic model class that inputs to this pipeline must comply to
        """
        return CLIPTextInput

    @property
    def output_schema(self) -> Type[CLIPTextOutput]:
        """
        :return: pydantic model class that outputs of this pipeline must comply to
        """
        return CLIPTextOutput

    def setup_onnx_file_path(self):
        """
        Performs any setup to unwrap and process the given `model_path` and other
        class properties into an inference ready onnx file to be compiled by the
        engine of the pipeline

        :return: file path to the ONNX file for the engine to compile
        """
        return model_to_path(self.model_path)

    def process_inputs(self, inputs: CLIPTextInput) -> List[np.ndarray]:
        """
        Preprocess inputs for CLIP's Text Branch to comply with the DeepSparse Engine

        :param inputs: CLIPTextInput
        :return: list of preprocessed numpy arrays
        """
        if isinstance(inputs.text, str):
            inputs.text = [inputs.text]

        tokens = self.tokenizer(inputs.text)
        tokens = [np.array(t).astype(np.int32) for t in tokens]
        tokens = np.stack(tokens, axis=0)
        return [tokens]

    def process_engine_outputs(
        self, engine_outputs: List[np.array], **kwargs
    ) -> CLIPTextOutput:
        """
        :param engine_outputs: list of numpy arrays that are the output of the engine
            forward pass
        :return: outputs of engine post-processed into an object in the `output_schema`
            format of this pipeline
        """
        return self.output_schema(text_embeddings=engine_outputs)
```
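Because `CLIPTextPipeline` registers itself under the `clip_text` task, the text branch can also be run on its own to produce embeddings. A brief sketch, assuming a text ONNX model exported as in the README (the model path below is an example, not a shipped artifact):

```python
# Sketch: running the CLIP text branch standalone to get text embeddings.
# The model path is an example; supply your own exported text ONNX model.
import numpy as np

from deepsparse import Pipeline
from deepsparse.clip import CLIPTextInput

text_pipeline = Pipeline.create(
    task="clip_text",
    model_path="zeroshot_research/text/model.onnx",
)

embeddings = text_pipeline(CLIPTextInput(text=["a dog", "a church"])).text_embeddings
print(np.asarray(embeddings[0]).shape)  # embedding shape depends on the exported model
```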
