@@ -92,7 +92,7 @@ def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit
92
92
from .pre_process import padded_prepare_decode_inputs
93
93
94
94
kwargs , run_reqs , padded_req_num = padded_prepare_decode_inputs (
95
- decode_reqs , max_decode_num , is_multimodal = False
95
+ decode_reqs , max_decode_num , is_multimodal = self . is_multimodal
96
96
)
97
97
logits = self .model .forward (** kwargs )
98
98
@@ -118,7 +118,7 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini
118
118
micro_batch1 ,
119
119
run_reqs1 ,
120
120
padded_req_num1 ,
121
- ) = padded_overlap_prepare_decode_inputs (decode_reqs , max_decode_num , is_multimodal = False )
121
+ ) = padded_overlap_prepare_decode_inputs (decode_reqs , max_decode_num , is_multimodal = self . is_multimodal )
122
122
logits , logits1 = self .model .microbatch_overlap_decode (micro_batch , micro_batch1 )
123
123
self ._overlap_req_init_and_filter (uninit_reqs = uninit_reqs , ok_finished_reqs = ok_finished_reqs , clear_list = True )
124
124
req_num , req_num1 = len (run_reqs ), len (run_reqs1 )
@@ -147,7 +147,7 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
147
147
micro_batch1 ,
148
148
run_reqs1 ,
149
149
padded_req_num1 ,
150
- ) = padded_overlap_prepare_prefill_inputs (prefill_reqs , max_prefill_num , is_multimodal = False )
150
+ ) = padded_overlap_prepare_prefill_inputs (prefill_reqs , max_prefill_num , is_multimodal = self . is_multimodal )
151
151
logits , logits1 = self .model .microbatch_overlap_prefill (micro_batch , micro_batch1 )
152
152
self ._overlap_req_init_and_filter (uninit_reqs = uninit_reqs , ok_finished_reqs = ok_finished_reqs , clear_list = True )
153
153
req_num , req_num1 = len (run_reqs ), len (run_reqs1 )
0 commit comments