@@ -199,6 +199,56 @@ def sample_sonnet_requests(
199
199
return sampled_requests
200
200
201
201
202
def sample_mmmu_pro_vision_requests(
    dataset,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
    """Build image benchmark requests from MMMU-Pro vision samples.

    Every request uses the same fixed text prompt; only the attached image
    differs. Each image is converted to RGB, re-encoded as JPEG and embedded
    as a base64 ``data:`` URL.

    Args:
        dataset: Iterable of records; each record is indexed with
            ``"image"`` and that value must be a ``PIL.Image.Image``.
        num_requests: Number of requests to collect before iteration stops.
        tokenizer: Tokenizer used to measure the text prompt length.
        fixed_output_len: Output length recorded for every request. When
            ``None``, a notice is printed and 128 is used.

    Returns:
        A list of ``(prompt, prompt_len, output_len, mm_content)`` tuples.
        ``prompt_len`` is the token count of the text prompt only and
        ``mm_content`` is an ``image_url`` content dict. The list is shorter
        than ``num_requests`` if the dataset runs out first.

    Raises:
        AssertionError: If a record's ``"image"`` value is not an ``Image``.
    """
    # NOTE(review): the declared return annotation lists the second tuple
    # element as `str`, but the tuples built below carry the prompt token
    # count (`int`). The signature is kept unchanged here; confirm and align
    # it together with the callers' annotations.

    # MMMU-Pro vision direct prompt
    # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
    prompt = (
        "Answer with the option letter from the given choices directly. "
        "The last line of your response should be of the following "
        "format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
        "options.")

    # The prompt is identical for every sample, so tokenize it once instead
    # of once per dataset row.
    prompt_len = len(tokenizer(prompt).input_ids)

    if fixed_output_len is None:
        # Default max output len is set to 128
        print("--hf-output-len is not provided. Using default value 128.")
        fixed_output_len = 128
    output_len = fixed_output_len

    sampled_requests: List[Tuple[str, int, int, Dict[str,
                                                     Collection[str]]]] = []
    for data in dataset:
        if len(sampled_requests) == num_requests:
            break

        assert isinstance(
            data["image"],
            Image), ("Input image format must be `PIL.Image.Image`, "
                     f"given {type(data['image'])}.")
        # JPEG has no alpha channel or palette mode, so normalize to RGB
        # before encoding.
        image: Image = data["image"].convert("RGB")
        with io.BytesIO() as image_data:
            image.save(image_data, format='JPEG')
            image_base64 = base64.b64encode(
                image_data.getvalue()).decode("utf-8")

        mm_content = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
        }

        sampled_requests.append((prompt, prompt_len, output_len, mm_content))

    return sampled_requests
250
+
251
+
202
252
def sample_hf_requests (
203
253
dataset_path : str ,
204
254
dataset_subset : str ,
@@ -208,6 +258,21 @@ def sample_hf_requests(
208
258
random_seed : int ,
209
259
fixed_output_len : Optional [int ] = None ,
210
260
) -> List [Tuple [str , str , int , Optional [Dict [str , Collection [str ]]]]]:
261
+
262
+ # Special case for MMMU-Pro vision dataset
263
+ if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision' :
264
+ assert dataset_split == "test"
265
+ dataset = load_dataset (dataset_path ,
266
+ name = dataset_subset ,
267
+ split = dataset_split ,
268
+ streaming = True )
269
+ assert "image" in dataset .features , (
270
+ "MMMU/MMMU_Pro vision dataset must have 'image' column." )
271
+ filter_func = lambda x : isinstance (x ["image" ], Image )
272
+ dataset = dataset .shuffle (seed = random_seed ).filter (filter_func )
273
+ return sample_mmmu_pro_vision_requests (dataset , num_requests ,
274
+ tokenizer , fixed_output_len )
275
+
211
276
dataset = load_dataset (dataset_path ,
212
277
name = dataset_subset ,
213
278
split = dataset_split ,
0 commit comments