@@ -180,19 +180,66 @@ def apply(
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)

+        image_token_id = self.info.get_hf_config().image_token_index
         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data
-        num_image_tokens = mm_inputs['prompt_token_ids'].count(
-            self.info.get_hf_config().image_token_index)
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
         image_data = mm_data.get("image", [])
         num_images = 1 if isinstance(image_data, Image) else len(image_data)
         if num_image_tokens != num_images:
             raise ValueError(
                 f"The number of image tokens ({num_image_tokens}) must be"
                 f" the same as the number of images ({num_images})")

+        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6 ... (P-prefill, D-decode)  # noqa: E501
+        # P0 & P1 do cross attention with the placeholder of <IMG0>
+        # P3 P4 D5 D6 do cross attention with placeholders of <IMG1> and <IMG2>
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        # }
+
+        if mm_data:
+            # Since only the last group of consecutive images is attended to
+            # by the decoded tokens, we only need to get the number of tokens
+            # for those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set the encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate the correct number
+            # of slots for the encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
+
         return mm_inputs

+    def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
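+        # Count the last consecutive group of image tokens, scanning from the
+        # end of the prompt. E.g. for the decoder prompt shown above,
+        # [128000, 128256, 128000, 3923, ...], the reverse scan skips the
+        # trailing text tokens, counts the single image token (128256),
+        # stops at the preceding 128000, and returns 1.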
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,
@@ -210,19 +257,7 @@ def _call_hf_processor(
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])