
Commit 051a2bc

hiworldwzj, baishihao, billywww, and flyinglandlord authored
DeepseekV3: support DeepEP, DeepGEMM, PD, and DP/TP/SP mix mode. (#783)
Co-authored-by: baishihao <[email protected]>
Co-authored-by: wanghao <[email protected]>
Co-authored-by: FlyingFlame <[email protected]>
1 parent 6234bd3 commit 051a2bc

105 files changed: +4248, -1837 lines

Diff for: .gitignore (+1)

@@ -5,3 +5,4 @@ dist
 *.egg-info
 .idea
 .vscode
+tmp/

Diff for: README.md (+2, -5)

@@ -59,11 +59,6 @@ We welcome any coopoeration and contribution. If there is a project requires lig
 
 </details>
 
-
-## Star History
-
-[![Star History Chart](https://api.star-history.com/svg?repos=ModelTC/lightllm&type=Timeline)](https://star-history.com/#ModelTC/lightllm&Timeline)
-
 ## Community
 
 For further information and discussion, [join our discord server](https://discord.gg/WzzfwVSguU). Welcome to be a member and look forward to your contribution!

@@ -78,5 +73,7 @@ We learned a lot from the following projects when developing LightLLM.
 - [Faster Transformer](https://github.com/NVIDIA/FasterTransformer)
 - [Text Generation Inference](https://github.com/huggingface/text-generation-inference)
 - [vLLM](https://github.com/vllm-project/vllm)
+- [SGLang](https://github.com/sgl-project/sglang)
+- [flashinfer](https://github.com/flashinfer-ai/flashinfer/tree/main)
 - [Flash Attention 1&2](https://github.com/Dao-AILab/flash-attention)
 - [OpenAI Triton](https://github.com/openai/triton)

Diff for: docs/CN/source/models/test.rst (-5)

@@ -135,13 +135,8 @@ internlm2-1_8b
 
 $ python -m lightllm.server.api_server
 $ --model_dir ~/models/internlm2-1_8b \
-$ --enable_chunked_prefill \
 $ --trust_remote_code
 
-.. tip::
-
-   ``--enable_chunked_prefill`` means chunked prefill is used for long-text inference.
-
 
 **Test the service**
 

Diff for: docs/EN/source/models/test.rst (-5)

@@ -213,13 +213,8 @@ internlm2-1_8b
 .. code-block:: console
 
 $ python -m lightllm.server.api_server --model_dir ~/models/internlm2-1_8b \
-$ --enable_chunked_prefill \
 $ --trust_remote_code
 
-.. tip::
-
-   ``--enable_chunked_prefill`` Indicates the use of chunkedprefill for long context.
-
 
 **Test Server**
 
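Both doc examples now launch the server without ``--enable_chunked_prefill``, matching the basemodel.py change below that drops the corresponding enable_chunked_prefill kvarg. For a quick smoke test of a server started as above, a minimal client sketch follows; the 127.0.0.1:8000 address and the TGI-style /generate payload are assumptions about the deployment, not something this commit defines, so adjust them to your setup.

# Hypothetical smoke test for a server launched with the command above.
# Assumes the server listens on 127.0.0.1:8000 and exposes a /generate
# endpoint accepting {"inputs": ..., "parameters": ...}; adjust as needed.
import json
import urllib.request

payload = {"inputs": "What is AI?", "parameters": {"max_new_tokens": 32}}
req = urllib.request.Request(
    "http://127.0.0.1:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))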

Diff for: format_out/grammer/json.ebnf (+12)

@@ -0,0 +1,12 @@
+root ::= basic_array | basic_object
+basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
+basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
+basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
+basic_string ::= (([\"] basic_string_1 [\"]))
+basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
+escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
+basic_boolean ::= "true" | "false"
+basic_null ::= "null"
+basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
+basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
+ws ::= [ \n\t]*
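The new grammar (presumably used for grammar-constrained JSON output) only admits a document whose root is an array or an object; bare numbers, strings, booleans, and null are excluded by the root rule. The sketch below only mirrors that root-rule intent with Python's json module as a stand-in; it does not run the EBNF itself.

# Toy check of the grammar's intent: only arrays/objects are valid roots.
import json

def accepted_by_root_rule(text: str) -> bool:
    try:
        value = json.loads(text)
    except json.JSONDecodeError:
        return False
    # root ::= basic_array | basic_object
    return isinstance(value, (list, dict))

assert accepted_by_root_rule('{"name": "lightllm", "tags": ["llm", "serving"]}')
assert accepted_by_root_rule('[1, 2.5e3, true, null]')
assert not accepted_by_root_rule('42')        # bare number: not an array/object root
assert not accepted_by_root_rule('"hello"')   # bare string: not an array/object root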

Diff for: lightllm/common/basemodel/basemodel.py (+130, -13)

@@ -17,6 +17,9 @@
 from lightllm.common.quantization import Quantcfg
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
+from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
 
 logger = init_logger(__name__)
 
@@ -53,16 +56,15 @@ def __init__(self, kvargs):
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
-        enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)  # chunked prefill is default on.
-        self.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache or enable_chunked_prefill
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
-        self.quant_type = kvargs.get("quant_type", None)
+        self.quant_type = kvargs.get("quant_type", "none")
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.mem_fraction = kvargs.get("mem_fraction", 0.9)
         self.tp_world_size_ = get_dp_world_size()
+        self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
 
         self._init_datatype()
         self._init_config()
@@ -98,7 +100,6 @@ def _init_config(self):
         repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
-
         return
 
     @final
@@ -207,7 +208,10 @@ def _init_cudagraph(self):
             None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
         )
         if self.graph is not None:
-            self.graph.warmup(self)
+            if get_env_start_args().enable_decode_microbatch_overlap:
+                self.graph.warmup_overlap(self)
+            else:
+                self.graph.warmup(self)
 
     def _init_custom(self):
         pass
@@ -296,6 +300,7 @@ def _prefill(
             dtype=self.data_type,
             device="cuda",
         )
+        infer_state.dist_group = dist_group_manager.get_default_group()
 
         init_req_to_token_indexes(
             self.req_manager.req_to_token_indexs,
@@ -346,6 +351,7 @@ def _decode(
             dtype=self.data_type,
             device="cuda",
         )
+        infer_state.dist_group = dist_group_manager.get_default_group()
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
 
         infer_state.init_some_extra_state(self, input_ids)
@@ -359,32 +365,143 @@ def _decode(
         predict_logics = self._token_forward(input_ids, infer_state)
         return predict_logics
 
+    @torch.no_grad()
+    def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch):
+        assert batch.batch_size == batch1.batch_size
+        assert batch.mem_indexes.is_cuda
+        assert batch1.mem_indexes.is_cuda
+        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
+
+        def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
+            infer_state = self.infer_state_class()
+            infer_state.is_prefill = False
+            infer_state.batch_size = cur_batch.batch_size
+            infer_state.total_token_num = cur_batch.total_token_num
+            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
+            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
+            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_start_loc.shape[0] == cur_batch.b_seq_len.shape[0]
+            infer_state.b_req_idx = cur_batch.b_req_idx
+            infer_state.b_start_loc = cur_batch.b_start_loc
+            infer_state.b_seq_len = cur_batch.b_seq_len
+            infer_state.multimodal_params = None
+            infer_state.microbatch_index = batch_index
+
+            infer_state.mem_manager = self.mem_manager
+            infer_state.req_manager = self.req_manager
+
+            # When the cuda graph feature is used, every inference run must follow the same code path,
+            # so the optimization from allocating contiguous mem is dropped to keep the flow consistent.
+            infer_state.mem_is_contiguous = False
+            infer_state.mem_index = cur_batch.mem_indexes
+            infer_state.kv_buffer = torch.empty(
+                (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
+                dtype=self.data_type,
+                device="cuda",
+            )
+            infer_state.dist_group = dist_group_manager.get_group(batch_index)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index
+            )
+            return infer_state
+
+        infer_state = create_inferstate(batch, 0)
+        infer_state1 = create_inferstate(batch1, 1)
+
+        infer_state.init_some_extra_state(self, input_ids)
+        infer_state1.init_some_extra_state(self, input_ids1)
+
+        batch_size = batch.batch_size
+        max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch)
+
+        if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
+            if self.graph.need_capture(batch_size):
+                infer_state.is_cuda_graph = True
+                infer_state1.is_cuda_graph = True
+
+                predict_logics, predict_logics1 = self.graph.capture_decode(
+                    self._overlap_tpsp_token_forward,
+                    input_ids,
+                    infer_state,
+                    input_ids1=input_ids1,
+                    infer_state1=infer_state1,
+                )
+            else:
+                predict_logics, predict_logics1 = self.graph.replay(
+                    input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+                )
+        else:
+            predict_logics, predict_logics1 = self._overlap_tpsp_token_forward(
+                input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+            )
+        return predict_logics, predict_logics1
+
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
+        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         g_cache_manager.cache_env_in()
         cuda_input_ids = input_ids
-        input_embs = self.pre_infer.context_forward(cuda_input_ids, infer_state, self.pre_post_weight)
-        for i in range(0, self.layers_num):
-            input_embs = self.layers_infer[i].context_forward(input_embs, infer_state, self.trans_layers_weight[i])
-        predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight)
+
+        pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index]
+        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+
+        for i in range(self.layers_num):
+            layer = self.layers_infer[i]
+            layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
+            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
+
        g_cache_manager.cache_env_out()
         return predict_logics
 
     @final
     def _token_forward(self, input_ids, infer_state: InferStateInfo):
+        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         g_cache_manager.cache_env_in(
             is_cuda_graph=infer_state.is_cuda_graph,
             cur_batch_size=infer_state.batch_size,
             cuda_graph_max_batch_size=self.graph_max_batch_size,
         )
         cuda_input_ids = input_ids
-        input_embs = self.pre_infer.token_forward(cuda_input_ids, infer_state, self.pre_post_weight)
-        for i in range(0, self.layers_num):
-            input_embs = self.layers_infer[i].token_forward(input_embs, infer_state, self.trans_layers_weight[i])
-        predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight)
+        pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
+        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+        for i in range(self.layers_num):
+            layer = self.layers_infer[i]
+            layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
+            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
+
         g_cache_manager.cache_env_out()
         return predict_logics
 
+    @final
+    def _overlap_tpsp_token_forward(
+        self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo
+    ):
+        g_cache_manager.cache_env_in(
+            is_cuda_graph=infer_state.is_cuda_graph,
+            cur_batch_size=infer_state.batch_size,
+            cuda_graph_max_batch_size=self.graph_max_batch_size,
+        )
+        input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward(
+            input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight
+        )
+
+        for i in range(self.layers_num):
+            input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward(
+                input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i]
+            )
+
+        predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward(
+            input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
+        )
+
+        g_cache_manager.cache_env_out()
+        return predict_logics, predict_logics1
+
     @final
     @torch.no_grad()
     def _check_max_len_infer(self):
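The reworked _context_forward and _token_forward choose between the regular path and the TP-SP mix-mode ("tpsp") path by indexing a 2-tuple of bound methods with run_mode_index. The sketch below is a self-contained toy of that dispatch pattern only; ToyLayer and its methods are invented for illustration and are not LightLLM APIs.

# Standalone illustration of the (plain, tpsp)[run_mode_index] dispatch used above.
class ToyLayer:
    def token_forward(self, x):
        return x + 1          # stands in for the regular tensor-parallel path

    def tpsp_token_forward(self, x):
        return x * 2          # stands in for the TP + sequence-parallel mixed path

def forward(layers, x, enable_tpsp_mix_mode: bool):
    run_mode_index = 1 if enable_tpsp_mix_mode else 0
    for layer in layers:
        # Pick the implementation per layer, the same way _token_forward does.
        layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
        x = layer_method(x)
    return x

layers = [ToyLayer(), ToyLayer()]
print(forward(layers, 1, enable_tpsp_mix_mode=False))  # 3  (two +1 steps)
print(forward(layers, 1, enable_tpsp_mix_mode=True))   # 4  (two *2 steps)

The new microbatch_overlap_decode applies the same split one level up: each DecodeMicroBatch gets its own InferStateInfo and its own communication group via dist_group_manager.get_group(0) / get_group(1), and the pair is driven through _overlap_tpsp_token_forward, either directly or captured and replayed as a CUDA graph.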
