Commit 6c46415

shihaobai and baishihao authored
Fix ep (#814)
Co-authored-by: baishihao <[email protected]>
1 parent 5fcf0fa · commit 6c46415
File tree: 4 files changed, +6 −4 lines changed
Diff for: lightllm/common/basemodel/basemodel.py (+1 −1)

@@ -460,7 +460,7 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
         infer_state.b_ready_cache_len = torch.zeros_like(
             cur_batch.b_seq_len, dtype=cur_batch.b_seq_len.dtype, device=cur_batch.b_seq_len.device
         )
-        infer_state.multimodal_params = None
+        infer_state.multimodal_params = cur_batch.multimodal_params
         infer_state.microbatch_index = batch_index
 
         infer_state.mem_manager = self.mem_manager
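This one-line change is the core of the fix for multimodal requests in microbatch-overlap prefill: the inference state used to discard the batch's multimodal parameters by hardcoding None. A minimal sketch of the data flow, with simplified, hypothetical types (MicroBatch and InferState here are stand-ins, not lightllm's real classes):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class MicroBatch:
    input_ids: List[int]
    multimodal_params: list  # e.g. one params entry per request carrying images

@dataclass
class InferState:
    multimodal_params: Optional[list] = None
    microbatch_index: int = 0

def create_inferstate(cur_batch: MicroBatch, batch_index: int) -> InferState:
    state = InferState()
    # Before the fix: state.multimodal_params = None
    # -> any downstream vision-embedding step saw no image inputs at all.
    state.multimodal_params = cur_batch.multimodal_params
    state.microbatch_index = batch_index
    return state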

Diff for: lightllm/common/basemodel/microbatch_overlap_objs.py (+1)

@@ -25,3 +25,4 @@ class PrefillMicroBatch:
     b_start_loc: torch.Tensor
     b_seq_len: torch.Tensor
     b_ready_cache_len: torch.Tensor
+    multimodal_params: list
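For context, PrefillMicroBatch is a plain dataclass of per-micro-batch inputs; the patch adds one field so multimodal inputs travel with the batch instead of being dropped at construction time. A hedged sketch of the object's shape after this change (earlier fields elided; only those visible in the hunk are shown):

from dataclasses import dataclass
import torch

@dataclass
class PrefillMicroBatch:
    # ...earlier fields elided...
    b_start_loc: torch.Tensor        # start offset of each request in the packed batch
    b_seq_len: torch.Tensor          # sequence length per request
    b_ready_cache_len: torch.Tensor  # tokens already present in the KV cache
    multimodal_params: list          # new: per-request multimodal inputs (e.g. images)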

Diff for: lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py (+3 −3)

@@ -92,7 +92,7 @@ def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit
         from .pre_process import padded_prepare_decode_inputs
 
         kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs(
-            decode_reqs, max_decode_num, is_multimodal=False
+            decode_reqs, max_decode_num, is_multimodal=self.is_multimodal
         )
         logits = self.model.forward(**kwargs)
 
@@ -118,7 +118,7 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini
             micro_batch1,
             run_reqs1,
             padded_req_num1,
-        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=False)
+        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal)
         logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1)
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         req_num, req_num1 = len(run_reqs), len(run_reqs1)
@@ -147,7 +147,7 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
             micro_batch1,
             run_reqs1,
             padded_req_num1,
-        ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=False)
+        ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal)
         logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1)
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         req_num, req_num1 = len(run_reqs), len(run_reqs1)
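All three hunks make the same substitution: the DP backend was pinning is_multimodal=False when preparing inputs, so multimodal models running under this backend never packed their image data. A minimal sketch of the pattern, assuming a backend that already knows whether it serves a multimodal model (prepare_inputs and the request shape below are illustrative, not lightllm's exact API):

from typing import List, Tuple

def prepare_inputs(reqs: List[dict], is_multimodal: bool) -> Tuple[dict, List[dict]]:
    # Stand-in for the padded_prepare_* helpers: gather multimodal
    # params only when the flag is set.
    kwargs = {"batch_size": len(reqs)}
    if is_multimodal:
        kwargs["multimodal_params"] = [r.get("multimodal_params") for r in reqs]
    return kwargs, reqs

class DPBackend:
    def __init__(self, is_multimodal: bool):
        self.is_multimodal = is_multimodal  # set once from the model config

    def normal_decode(self, decode_reqs: List[dict]) -> dict:
        # Before the fix the call site passed is_multimodal=False
        # unconditionally; now it defers to the backend's own flag.
        kwargs, run_reqs = prepare_inputs(decode_reqs, is_multimodal=self.is_multimodal)
        return kwargs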

Diff for: lightllm/server/router/model_infer/mode_backend/dp_backend/pre_process.py (+1)

@@ -336,6 +336,7 @@ def _padded_prepare_prefill_micro_batch(req_objs: List[InferReq], is_multimodal=
         b_start_loc=nopad_b_start_loc,
         b_seq_len=nopad_b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
+        multimodal_params=batch_multimodal_params,
     )
 
     return micro_batch, run_reqs, padded_req_num
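This final hunk closes the loop: the prefill pre-processing has the per-request multimodal params in scope as batch_multimodal_params, but they were never attached to the micro batch. A sketch of how such a list is typically gathered so it stays index-aligned with the other per-request tensors (gather_multimodal_params and the request attribute are hypothetical helpers for illustration):

def gather_multimodal_params(req_objs: list, is_multimodal: bool) -> list:
    # One entry per request, in batch order, so indexing stays aligned
    # with b_seq_len, b_start_loc, and the other per-request tensors.
    if not is_multimodal:
        return []
    return [req.multimodal_params for req in req_objs]

# Inside the micro-batch builder, the collected list is then forwarded:
# micro_batch = PrefillMicroBatch(..., multimodal_params=batch_multimodal_params)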
