Commit b0695f6
clean-up pipelines, update typing and descriptions
1 parent c1003be
6 files changed, +52 -24 lines

Diff for: src/deepsparse/clip/README.md (+1 -1)

@@ -110,7 +110,7 @@ kwargs = {
 pipeline = BasePipeline.create(task="clip_caption", **kwargs)
 
 pipeline_input = CLIPCaptionInput(image=CLIPVisualInput(images="thailand.jpg"))
-output = pipeline(pipeline_input)
+output = pipeline(pipeline_input).caption
 print(output[0])
 ```
 Running the code above, we get the following caption:
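For context, here is the full flow that README hunk is editing, as a minimal sketch. The import paths and the .onnx file names are assumptions for illustration; only the task name, the input classes, and the new `.caption` accessor come from this commit (the parameter names match `CLIPCaptionPipeline.__init__` in the next file).

```python
# Minimal sketch of the caption flow shown in the README hunk above.
# Import paths and model-path values are assumptions, not taken from this commit.
from deepsparse import BasePipeline
from deepsparse.clip import CLIPCaptionInput, CLIPVisualInput

kwargs = {
    "visual_model_path": "clip_visual.onnx",    # placeholder path
    "text_model_path": "clip_text.onnx",        # placeholder path
    "decoder_model_path": "clip_decoder.onnx",  # placeholder path
}
pipeline = BasePipeline.create(task="clip_caption", **kwargs)

pipeline_input = CLIPCaptionInput(image=CLIPVisualInput(images="thailand.jpg"))
output = pipeline(pipeline_input).caption  # .caption is the accessor added in this commit
print(output[0])
```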

Diff for: src/deepsparse/clip/captioning_pipeline.py (+26 -11)

@@ -53,6 +53,25 @@ class CLIPCaptionOutput(BaseModel):
 
 @BasePipeline.register(task="clip_caption", default_model_path=None)
 class CLIPCaptionPipeline(BasePipeline):
+    """
+    Pipeline designed to generate a caption for a given image. The CLIPCaptionPipeline
+    relies on 3 other pipelines: CLIPVisualPipeline, CLIPTextPipeline, and the
+    CLIPDecoderPipeline. The pipeline takes in a single image and then uses these
+    pipelines along with Beam Search to generate a caption.
+
+    :param visual_model_path: either a local path or sparsezoo stub for the CLIP visual
+        branch onnx model
+    :param text_model_path: either a local path or sparsezoo stub for the CLIP text
+        branch onnx model
+    :param decoder_model_path: either a local path or sparsezoo stub for the CLIP
+        decoder branch onnx model
+    :param num_beams: number of beams to use in Beam Search
+    :param num_beam_groups: number of beam groups to use in Beam Search
+    :param min_seq_len: the minimum length of the caption sequence
+    :param max_seq_len: the maximum length of the caption sequence
+
+    """
+
     def __init__(
         self,
         visual_model_path: str,
@@ -61,17 +80,13 @@ def __init__(
         num_beams: int = 10,
         num_beam_groups: int = 5,
         min_seq_len: int = 5,
-        seq_len: int = 20,
-        fixed_output_length: bool = False,
+        max_seq_len: int = 20,
         **kwargs,
     ):
         self.num_beams = num_beams
         self.num_beam_groups = num_beam_groups
-        self.seq_len = seq_len
+        self.max_seq_len = max_seq_len
         self.min_seq_len = min_seq_len
-        self.fixed_output_length = fixed_output_length
-
-        super().__init__(**kwargs)
 
         self.visual = Pipeline.create(
             task="clip_visual",
@@ -86,8 +101,9 @@ def __init__(
             **{"model_path": decoder_model_path},
         )
 
-    # TODO: have to verify all input types
-    def _encode_and_decode(self, text, image_embs):
+        super().__init__(**kwargs)
+
+    def _encode_and_decode(self, text: torch.Tensor, image_embs: torch.Tensor):
         original_size = text.shape[-1]
         padded_tokens = F.pad(text, (15 - original_size, 0))
         text_embeddings = self.text(
@@ -104,16 +120,15 @@ def _encode_and_decode(self, text, image_embs):
         }
 
     # Adapted from open_clip
-    def _generate(self, pipeline_inputs):
-        # Make these input values?
+    def _generate(self, pipeline_inputs: CLIPCaptionInput):
         sot_token_id = 49406
         eos_token_id = 49407
         pad_token_id = 0
         batch_size = 1
         repetition_penalty = 1.0
         device = "cpu"
 
-        stopping_criteria = [MaxLengthCriteria(max_length=self.seq_len)]
+        stopping_criteria = [MaxLengthCriteria(max_length=self.max_seq_len)]
         stopping_criteria = StoppingCriteriaList(stopping_criteria)
 
         logits_processor = LogitsProcessorList(
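A note on the `seq_len` to `max_seq_len` rename above: the length controls in `_generate` are standard Hugging Face `transformers` utilities. The sketch below shows how the renamed parameters bound caption length; the contents of the `LogitsProcessorList` are an assumption patterned on the open_clip `generate()` recipe the comment references, since that call is truncated in this hunk.

```python
# Hedged sketch: how max_seq_len / min_seq_len map onto transformers utilities.
# The LogitsProcessorList contents are an assumption, not code from this commit.
from transformers import (
    LogitsProcessorList,
    MaxLengthCriteria,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    StoppingCriteriaList,
)

max_seq_len = 20  # default from CLIPCaptionPipeline.__init__
min_seq_len = 5
eos_token_id = 49407
repetition_penalty = 1.0

# Generation stops once the running token sequence reaches max_seq_len.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_seq_len)])

# Logit processors keep the caption from ending before min_seq_len tokens
# and penalize repeated tokens during beam search.
logits_processor = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_seq_len, eos_token_id),
        RepetitionPenaltyLogitsProcessor(repetition_penalty),
    ]
)
```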

Diff for: src/deepsparse/clip/decoder_pipeline.py (+7 -3)

@@ -29,8 +29,12 @@ class CLIPDecoderInput(BaseModel):
     Input for the CLIP Decoder Branch
     """
 
-    text_embeddings: Any = Field(description="Text emebddings from the text branch")
-    image_embeddings: Any = Field(description="Image embeddings from the visual branch")
+    text_embeddings: Any = Field(
+        description="np.array of text embeddings from the " "text branch"
+    )
+    image_embeddings: Any = Field(
+        description="np.array of image embeddings from the " "visual branch"
+    )
 
 
 class CLIPDecoderOutput(BaseModel):
@@ -39,7 +43,7 @@ class CLIPDecoderOutput(BaseModel):
     """
 
     logits: List[Any] = Field(
-        description="Logits produced from the text and image emebeddings."
+        description="np.array of logits produced from the decoder."
     )
 
Diff for: src/deepsparse/clip/text_pipeline.py (+5 -3)

@@ -31,7 +31,7 @@ class CLIPTextInput(BaseModel):
     """
 
     text: Union[str, List[str], Any, List[Any]] = Field(
-        description="Either raw text or text embeddings"
+        description="Either raw strings or an np.array with tokenized text"
     )
 
 
@@ -41,8 +41,10 @@ class CLIPTextOutput(BaseModel):
     """
 
     text_embeddings: List[Any] = Field(
-        description="Text embeddings for the single text or list of embeddings for "
-        "multiple."
+        description="np.array of text embeddings. For the caption "
+        "pipeline, a list of two embeddings is produced. For zero-shot "
+        "classification, one array is produced with the embeddings stacked along "
+        "the batch axis."
     )
 
Diff for: src/deepsparse/clip/visual_pipeline.py (+4 -2)

@@ -44,8 +44,10 @@ class CLIPVisualOutput(BaseModel):
     """
 
     image_embeddings: List[Any] = Field(
-        description="Image embeddings for the single image or list of embeddings for "
-        "multiple images"
+        description="np.arrays consisting of image embeddings. For the caption "
+        "pipeline, a list of two image embeddings is produced. For zero-shot "
+        "classification, one array is produced with the embeddings stacked along "
+        "the batch axis."
     )
 
Diff for: src/deepsparse/clip/zeroshot_pipeline.py (+9 -4)

@@ -32,18 +32,23 @@ class CLIPZeroShotInput(BaseModel):
     """
 
     image: CLIPVisualInput = Field(
-        description="Path to image to run zero-shot prediction on."
+        description="Image(s) to run zero-shot prediction. See CLIPVisualPipeline "
+        "for details."
+    )
+    text: CLIPTextInput = Field(
+        description="Text/classes to run zero-shot prediction. "
+        "See CLIPTextPipeline for details."
     )
-    text: CLIPTextInput = Field(description="List of text to process")
 
 
 class CLIPZeroShotOutput(BaseModel):
     """
     Output for the CLIP Zero Shot Model
     """
 
-    # TODO: Maybe change this to a dictionary where keys are text inputs
-    text_scores: List[Any] = Field(description="Probability of each text class")
+    text_scores: List[Any] = Field(
+        description="np.array consisting of probabilities " "for each class provided."
+    )
 
 
 @BasePipeline.register(task="clip_zeroshot", default_model_path=None)
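The updated fields above describe how zero-shot inputs are assembled from the visual and text input models. A minimal sketch follows; the constructor kwargs, import paths, and model paths are assumptions, while the `clip_zeroshot` task name, the `CLIPZeroShotInput` fields, and the `text_scores` output come from this diff.

```python
# Hedged sketch of constructing the zero-shot inputs described above.
# Constructor kwargs and import paths are assumptions, not taken from this commit.
from deepsparse import BasePipeline
from deepsparse.clip import CLIPTextInput, CLIPVisualInput, CLIPZeroShotInput

pipeline = BasePipeline.create(
    task="clip_zeroshot",
    visual_model_path="clip_visual.onnx",  # placeholder path
    text_model_path="clip_text.onnx",      # placeholder path
)

pipeline_input = CLIPZeroShotInput(
    image=CLIPVisualInput(images="thailand.jpg"),
    text=CLIPTextInput(text=["a cat", "an elephant", "a building"]),
)

# text_scores holds the per-class probabilities described in CLIPZeroShotOutput
print(pipeline(pipeline_input).text_scores)
```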
