
Commit e0652b0

ywang96 and DarkLight1337 authored and committed
[Doc][V1] Update model implementation guide for V1 support (vllm-project#11998)
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent e7d2de4 commit e0652b0

File tree

2 files changed: +83 -16 lines changed


docs/source/contributing/model/basic.md

Lines changed: 11 additions & 1 deletion
````diff
@@ -57,7 +57,17 @@ class MyModelForCausalLM(nn.Module):
 
 ### Computation Code
 
-Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
+- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
+
+  ```python
+  class MyModel(nn.Module):
+      ...
+
+      def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+          ...
+  ```
+
+- Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
 
 ```python
 def forward(
````
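Taken together, the guidance above for `basic.md` amounts to something like the following minimal sketch. It is illustrative only: `MyModel`, the `embed_tokens` layer, and the simplified `forward` signature (vLLM-specific arguments such as KV caches and attention metadata are omitted) are assumptions, not code taken from the diff above.

```python
import torch
from torch import nn


class MyModel(nn.Module):
    """Decoder backbone sketch; layer names are placeholders."""

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # ... decoder layers, final norm, etc.

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Unified accessor for the text embedding layer, so that a composite
        # multimodal model can reuse it when merging image embeddings.
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,  # flattened to shape (num_tokens,)
        positions: torch.Tensor,  # flattened to shape (num_tokens,)
    ) -> torch.Tensor:
        hidden_states = self.get_input_embeddings(input_ids)
        # ... run the decoder layers over the flattened token stream ...
        return hidden_states
```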

docs/source/contributing/model/multimodal.md

Lines changed: 72 additions & 15 deletions
````diff
@@ -9,7 +9,78 @@ This document walks you through the steps to extend a basic model so that it acc
 It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic).
 Further update the model as follows:
 
-- Implement the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
+- Reserve a keyword parameter in {meth}`~torch.nn.Module.forward` for each input tensor that corresponds to a multi-modal input, as shown in the following example:
+
+  ```diff
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+  +     pixel_values: torch.Tensor,
+    ) -> SamplerOutput:
+  ```
+
+  More conveniently, you can simply pass `**kwargs` to the {meth}`~torch.nn.Module.forward` method and retrieve the keyword parameters for multimodal inputs from it.
+
+- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings` that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
+
+  ```python
+  class YourModelForImage2Seq(nn.Module):
+      ...
+
+      def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
+
+          assert self.vision_encoder is not None
+          image_features = self.vision_encoder(image_input)
+          return self.multi_modal_projector(image_features)
+
+      def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]:
+
+          # Validate the multimodal input keyword arguments
+          image_input = self._parse_and_validate_image_input(**kwargs)
+          if image_input is None:
+              return None
+
+          # Run multimodal inputs through encoder and projector
+          vision_embeddings = self._process_image_input(image_input)
+          return vision_embeddings
+  ```
+
+  ```{important}
+  The returned `multimodal_embeddings` must be either a **3D {class}`torch.Tensor`** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D {class}`torch.Tensor`'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
+  ```
+
+- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings` to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
+
+  ```python
+  from .utils import merge_multimodal_embeddings
+
+  class YourModelForImage2Seq(nn.Module):
+      ...
+
+      def get_input_embeddings(
+          self,
+          input_ids: torch.Tensor,
+          multimodal_embeddings: Optional[NestedTensors] = None,
+      ) -> torch.Tensor:
+
+          # `get_input_embeddings` should already be implemented for the language
+          # model as one of the requirements of basic vLLM model implementation.
+          inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
+          if multimodal_embeddings is not None:
+              inputs_embeds = merge_multimodal_embeddings(
+                  input_ids=input_ids,
+                  inputs_embeds=inputs_embeds,
+                  multimodal_embeddings=multimodal_embeddings,
+                  placeholder_token_id=self.config.image_token_index)
+
+          return inputs_embeds
+  ```
+
+- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
 
 ```diff
 + from vllm.model_executor.models.interfaces import SupportsMultiModal
````
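The boilerplate above calls `self._parse_and_validate_image_input(**kwargs)` without showing it. Below is a minimal sketch of what such a helper might look like, assuming a single `pixel_values` keyword input; the `YourModelImageInputs` TypedDict and the validation details are illustrative assumptions rather than code from the diff above.

```python
from typing import Optional, TypedDict

import torch
from torch import nn


class YourModelImageInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: (num_images, num_channels, height, width)"""


class YourModelForImage2Seq(nn.Module):
    ...

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[YourModelImageInputs]:
        # The multimodal keyword parameters reserved in `forward` (or passed
        # through **kwargs) arrive here; text-only requests carry none.
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return YourModelImageInputs(pixel_values=pixel_values)
```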
````diff
@@ -23,20 +94,6 @@ Further update the model as follows:
 Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples.
 ```
 
-- If you haven't already done so, reserve a keyword parameter in {meth}`~torch.nn.Module.forward`
-  for each input tensor that corresponds to a multi-modal input, as shown in the following example:
-
-  ```diff
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-  +     pixel_values: torch.Tensor,
-    ) -> SamplerOutput:
-  ```
-
 ## 2. Specify processing information
 
 Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo`
````
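Putting the multimodal pieces together, a composite model's `forward` typically retrieves the multimodal keyword inputs, computes the vision embeddings, and runs the language model on the merged embeddings. The fragment below sketches that flow on top of the methods shown in the diff; the exact `language_model` call signature and the omitted vLLM-specific arguments (KV caches, attention metadata) are simplifying assumptions, not code from the diff above.

```python
import torch
from torch import nn


class YourModelForImage2Seq(nn.Module):
    ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        # vLLM-specific arguments (KV caches, attention metadata, ...) omitted
        **kwargs: object,
    ) -> torch.Tensor:
        # Multimodal keyword inputs (e.g. `pixel_values`) arrive via **kwargs.
        vision_embeddings = self.get_multimodal_embeddings(**kwargs)

        # Merge the image embeddings into the text embeddings at the
        # placeholder-token positions.
        inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings)

        # Run the language model on the merged embeddings rather than raw ids.
        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states
```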
