6
6
import re
7
7
import types
8
8
from pathlib import PosixPath
9
- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
9
+ from typing import Callable , List , Optional , Tuple , Union
10
10
11
11
import torch
12
12
from PIL .Image import Image
17
17
from vllm .transformers_utils .tokenizer import patch_padding_side
18
18
from vllm .utils import STR_DTYPE_TO_TORCH_DTYPE
19
19
20
- from .....conftest import (HfRunner , ImageAsset , PromptAudioInput ,
21
- PromptImageInput , PromptVideoInput , _ImageAssets )
22
- from ....utils import TokensTextLogprobs
20
+ from .....conftest import HfRunner , ImageAsset , _ImageAssets
23
21
from .types import RunnerOutput
24
22
25
23
@@ -522,74 +520,7 @@ def _generate(self, *args, **kwargs):
522
520
return hf_model
523
521
524
522
525
- def _generate_greedy_logprobs_limit (
526
- self ,
527
- prompts : List [str ],
528
- max_tokens : int ,
529
- num_logprobs : int ,
530
- images : Optional [PromptImageInput ] = None ,
531
- audios : Optional [PromptAudioInput ] = None ,
532
- videos : Optional [PromptVideoInput ] = None ,
533
- ** kwargs : Any ,
534
- ) -> List [TokensTextLogprobs ]:
535
- all_inputs = self .get_inputs (prompts ,
536
- images = images ,
537
- videos = videos ,
538
- audios = audios )
539
-
540
- # Process in batches for inference.
541
- if len (all_inputs ):
542
- input_ids_lst = []
543
- images_lst = []
544
- images_input_idx_lst = []
545
- imges_masks_lst = []
546
- for inputs in all_inputs :
547
- input_ids_lst .append (inputs ["input_ids" ])
548
- images_lst .append (inputs ["images" ])
549
- images_input_idx_lst .append (inputs ["image_input_idx" ])
550
- imges_masks_lst .append (inputs ["image_masks" ])
551
- batch_inputs = {}
552
- batch_inputs ['input_ids' ] = torch .cat (input_ids_lst , dim = 0 )
553
- batch_inputs ['images' ] = torch .cat (images_lst , dim = 0 )
554
- batch_inputs ['image_input_idx' ] = torch .cat (images_input_idx_lst ,
555
- dim = 0 )
556
- batch_inputs ['image_masks' ] = torch .cat (imges_masks_lst , dim = 0 )
557
-
558
- outputs = self .model .generate_from_batch (
559
- batch = self .wrap_device (batch_inputs ,
560
- device = self .model .device .type ),
561
- generation_config = GenerationConfig (
562
- max_new_tokens = max_tokens ,
563
- stop_strings = "<|endoftext|>" ,
564
- do_sample = False ,
565
- ),
566
- tokenizer = self .tokenizer ,
567
- output_hidden_states = True ,
568
- return_dict_in_generate = True ,
569
- )
570
-
571
- all_logprobs : List [List [Dict [int , float ]]] = []
572
- all_output_ids : List [List [int ]] = []
573
- all_output_strs : List [str ] = []
574
-
575
- for index in range (len (all_inputs )):
576
- (
577
- seq_logprobs_lst ,
578
- output_len ,
579
- ) = self ._hidden_states_to_logprobs (outputs .hidden_states ,
580
- num_logprobs )
581
- all_logprobs .append (seq_logprobs_lst )
582
- seq_ids = outputs .sequences [index ]
583
- output_ids = seq_ids [- output_len :]
584
- all_output_ids .append (output_ids .tolist ())
585
- all_output_strs .append (self .tokenizer .decode (output_ids ))
586
- outputs = zip (all_output_ids , all_output_strs , all_logprobs )
587
- return [(output_ids , output_str , output_logprobs )
588
- for output_ids , output_str , output_logprobs in outputs ]
589
-
590
-
591
- ####### Molmo-specific HuggingFace runner patchers
592
- def mlomo_patch_hf_runner (hf_model : HfRunner ) -> HfRunner :
523
+ def molmo_patch_hf_runner (hf_model : HfRunner ) -> HfRunner :
593
524
"""Patches and returns an instance of the HfRunner to use for Molmo."""
594
525
hf_processor = hf_model .processor
595
526
@@ -598,10 +529,23 @@ def _processor(*args, **kwargs):
598
529
599
530
hf_model .processor = _processor
600
531
601
- setattr ( # noqa: B010
602
- hf_model ,
603
- "generate_greedy_logprobs_limit" ,
604
- types .MethodType (_generate_greedy_logprobs_limit , hf_model ),
605
- )
532
+ def _generate (self , max_new_tokens = None , do_sample = None , ** kwargs ):
533
+ batch = {
534
+ k : kwargs .pop (k )
535
+ for k in ("input_ids" , "images" , "image_input_idx" , "image_masks" )
536
+ if k in kwargs
537
+ }
538
+
539
+ return self .generate_from_batch (
540
+ batch ,
541
+ generation_config = GenerationConfig (
542
+ max_new_tokens = max_new_tokens ,
543
+ stop_strings = "<|endoftext|>" ,
544
+ do_sample = do_sample ,
545
+ ),
546
+ ** kwargs ,
547
+ )
548
+
549
+ hf_model .model .generate = types .MethodType (_generate , hf_model .model )
606
550
607
551
return hf_model
0 commit comments