Fix ep #814

Merged · 4 commits · Apr 9, 2025
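
This PR fixes multimodal inference on the padded / microbatch-overlap path: PrefillMicroBatch gains a multimodal_params field, the prefill prepare step fills it from the batch, basemodel.py copies it onto the inference state instead of hard-coding None, and the decode/prefill entry points pass the backend's is_multimodal flag instead of a hard-coded False.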

lightllm/common/basemodel/basemodel.py (1 addition, 1 deletion)

@@ -460,7 +460,7 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
     infer_state.b_ready_cache_len = torch.zeros_like(
         cur_batch.b_seq_len, dtype=cur_batch.b_seq_len.dtype, device=cur_batch.b_seq_len.device
     )
-    infer_state.multimodal_params = None
+    infer_state.multimodal_params = cur_batch.multimodal_params
     infer_state.microbatch_index = batch_index

     infer_state.mem_manager = self.mem_manager

lightllm/common/basemodel/microbatch_overlap_objs.py (1 addition)

@@ -25,3 +25,4 @@ class PrefillMicroBatch:
     b_start_loc: torch.Tensor
     b_seq_len: torch.Tensor
     b_ready_cache_len: torch.Tensor
+    multimodal_params: list
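
Together with the basemodel.py hunk above, this new field closes the gap where multimodal inputs were silently dropped in overlap mode. Below is a minimal, self-contained sketch of the restored flow; the class and field names mirror the diff, while the type hints and the InferState stand-in are illustrative, not lightllm's real definitions.

from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class PrefillMicroBatch:
    # Trimmed to the fields visible in this hunk; the real class has more.
    b_start_loc: Any = None
    b_seq_len: Any = None
    b_ready_cache_len: Any = None
    multimodal_params: List[Any] = field(default_factory=list)  # new in this PR

class InferState:
    """Illustrative stand-in for the real inference-state object."""
    def __init__(self) -> None:
        self.multimodal_params: List[Any] = []
        self.microbatch_index: int = 0

def create_inferstate(cur_batch: PrefillMicroBatch, batch_index: int) -> InferState:
    infer_state = InferState()
    # Before this PR the next line read `infer_state.multimodal_params = None`,
    # so image/audio params never reached the model in microbatch-overlap prefill.
    infer_state.multimodal_params = cur_batch.multimodal_params
    infer_state.microbatch_index = batch_index
    return infer_state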

(file name not shown in this capture)

@@ -92,7 +92,7 @@ def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit
         from .pre_process import padded_prepare_decode_inputs

         kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs(
-            decode_reqs, max_decode_num, is_multimodal=False
+            decode_reqs, max_decode_num, is_multimodal=self.is_multimodal
         )
         logits = self.model.forward(**kwargs)

@@ -118,7 +118,7 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini
             micro_batch1,
             run_reqs1,
             padded_req_num1,
-        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=False)
+        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal)
         logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1)
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         req_num, req_num1 = len(run_reqs), len(run_reqs1)

@@ -147,7 +147,7 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
             micro_batch1,
             run_reqs1,
             padded_req_num1,
-        ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=False)
+        ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal)
         logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1)
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         req_num, req_num1 = len(run_reqs), len(run_reqs1)
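
All three call-site changes follow one pattern: replace a hard-coded is_multimodal=False with the backend's own flag, so multimodal models stop falling back to the text-only path. A minimal sketch of that pattern follows; the Backend class and the stub in place of lightllm's real prepare helper are illustrative, not the project's actual signatures.

from typing import Any, Dict, List, Tuple

def padded_prepare_decode_inputs(
    decode_reqs: List[Any], max_decode_num: int, is_multimodal: bool = False
) -> Tuple[Dict[str, Any], List[Any], int]:
    # Stub: the real helper also gathers each request's multimodal params
    # into the batch kwargs when is_multimodal is True.
    kwargs: Dict[str, Any] = {"is_multimodal": is_multimodal}
    return kwargs, decode_reqs, 0

class Backend:
    def __init__(self, is_multimodal: bool) -> None:
        # Set once from the model/server config, e.g. True for vision-language models.
        self.is_multimodal = is_multimodal

    def normal_decode(self, decode_reqs: List[Any], max_decode_num: int):
        # Before this PR, is_multimodal=False was hard-coded here.
        return padded_prepare_decode_inputs(
            decode_reqs, max_decode_num, is_multimodal=self.is_multimodal
        )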

(file name not shown in this capture)

@@ -336,6 +336,7 @@ def _padded_prepare_prefill_micro_batch(req_objs: List[InferReq], is_multimodal=
         b_start_loc=nopad_b_start_loc,
         b_seq_len=nopad_b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
+        multimodal_params=batch_multimodal_params,
     )

     return micro_batch, run_reqs, padded_req_num
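
With this last hunk the chain is complete: the prepare step packs the batch's multimodal params into the PrefillMicroBatch it builds, create_inferstate copies them onto the inference state, and the decode/prefill entry points opt into the multimodal path whenever the backend itself is multimodal.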