 from lightllm.common.quantization import Quantcfg
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
+from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
 
 logger = init_logger(__name__)
 
@@ -53,16 +56,15 @@ def __init__(self, kvargs):
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
-        enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)  # chunked prefill is default on.
-        self.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache or enable_chunked_prefill
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
-        self.quant_type = kvargs.get("quant_type", None)
+        self.quant_type = kvargs.get("quant_type", "none")
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.mem_fraction = kvargs.get("mem_fraction", 0.9)
         self.tp_world_size_ = get_dp_world_size()
+        self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
 
         self._init_datatype()
         self._init_config()
@@ -98,7 +100,6 @@ def _init_config(self):
         repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
-
         return
 
     @final
@@ -207,7 +208,10 @@ def _init_cudagraph(self):
             None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
         )
         if self.graph is not None:
-            self.graph.warmup(self)
+            if get_env_start_args().enable_decode_microbatch_overlap:
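+                # Editor's note: warmup_overlap is assumed to pre-capture decode graphs for the
+                # two-micro-batch overlapped path, mirroring what warmup does for single batches.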
+                self.graph.warmup_overlap(self)
+            else:
+                self.graph.warmup(self)
 
     def _init_custom(self):
         pass
@@ -296,6 +300,7 @@ def _prefill(
             dtype=self.data_type,
             device="cuda",
         )
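+        # The regular (non-overlap) prefill path binds the default communication group to this inference state.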
+        infer_state.dist_group = dist_group_manager.get_default_group()
 
         init_req_to_token_indexes(
             self.req_manager.req_to_token_indexs,
@@ -346,6 +351,7 @@ def _decode(
             dtype=self.data_type,
             device="cuda",
         )
+        infer_state.dist_group = dist_group_manager.get_default_group()
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
 
         infer_state.init_some_extra_state(self, input_ids)
@@ -359,32 +365,143 @@ def _decode(
         predict_logics = self._token_forward(input_ids, infer_state)
         return predict_logics
 
+    @torch.no_grad()
+    def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch):
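+        # Decode two micro-batches of equal batch size in one call: each micro-batch gets its own
+        # InferStateInfo and its own communication group, and both are driven through
+        # _overlap_tpsp_token_forward (via CUDA graphs when possible). Editor's note: the pairing is
+        # presumably what lets one micro-batch's collectives overlap the other's compute.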
+        assert batch.batch_size == batch1.batch_size
+        assert batch.mem_indexes.is_cuda
+        assert batch1.mem_indexes.is_cuda
+        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
+
+        def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
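+            # Build the InferStateInfo for one micro-batch; batch_index (0 or 1) selects its
+            # communication group and is recorded as infer_state.microbatch_index.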
+            infer_state = self.infer_state_class()
+            infer_state.is_prefill = False
+            infer_state.batch_size = cur_batch.batch_size
+            infer_state.total_token_num = cur_batch.total_token_num
+            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
+            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
+            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_start_loc.shape[0] == cur_batch.b_seq_len.shape[0]
+            infer_state.b_req_idx = cur_batch.b_req_idx
+            infer_state.b_start_loc = cur_batch.b_start_loc
+            infer_state.b_seq_len = cur_batch.b_seq_len
+            infer_state.multimodal_params = None
+            infer_state.microbatch_index = batch_index
+
+            infer_state.mem_manager = self.mem_manager
+            infer_state.req_manager = self.req_manager
+
+            # When the cuda graph feature is used, every inference run must follow exactly the same
+            # execution path, so the contiguous-memory allocation optimization is disabled here to
+            # keep the flow identical across runs.
+            infer_state.mem_is_contiguous = False
+            infer_state.mem_index = cur_batch.mem_indexes
+            infer_state.kv_buffer = torch.empty(
+                (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
+                dtype=self.data_type,
+                device="cuda",
+            )
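+            # Each micro-batch binds its own process group (get_group(0) / get_group(1)); editor's
+            # note: separate groups are presumably what allow the two micro-batches' collectives to
+            # be issued independently of each other.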
+            infer_state.dist_group = dist_group_manager.get_group(batch_index)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index
+            )
+            return infer_state
+
+        infer_state = create_inferstate(batch, 0)
+        infer_state1 = create_inferstate(batch1, 1)
+
+        infer_state.init_some_extra_state(self, input_ids)
+        infer_state1.init_some_extra_state(self, input_ids1)
+
+        batch_size = batch.batch_size
+        max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch)
+
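+        # Use the CUDA-graph path when both micro-batches fit the captured limits; otherwise fall
+        # back to the eager _overlap_tpsp_token_forward call below.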
+        if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
+            if self.graph.need_capture(batch_size):
+                infer_state.is_cuda_graph = True
+                infer_state1.is_cuda_graph = True
+
+                predict_logics, predict_logics1 = self.graph.capture_decode(
+                    self._overlap_tpsp_token_forward,
+                    input_ids,
+                    infer_state,
+                    input_ids1=input_ids1,
+                    infer_state1=infer_state1,
+                )
+            else:
+                predict_logics, predict_logics1 = self.graph.replay(
+                    input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+                )
+        else:
+            predict_logics, predict_logics1 = self._overlap_tpsp_token_forward(
+                input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+            )
+        return predict_logics, predict_logics1
+
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
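+        # run_mode_index picks between the two forward variants: 0 = the regular per-layer
+        # context/token methods, 1 = the tpsp_* methods used when enable_tpsp_mix_mode is set
+        # (editor's note: TPSP is assumed to denote the mixed tensor-/sequence-parallel mode).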
+        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         g_cache_manager.cache_env_in()
         cuda_input_ids = input_ids
-        input_embs = self.pre_infer.context_forward(cuda_input_ids, infer_state, self.pre_post_weight)
-        for i in range(0, self.layers_num):
-            input_embs = self.layers_infer[i].context_forward(input_embs, infer_state, self.trans_layers_weight[i])
-        predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight)
+
+        pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index]
+        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+
+        for i in range(self.layers_num):
+            layer = self.layers_infer[i]
+            layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
+            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
+
         g_cache_manager.cache_env_out()
         return predict_logics
 
     @final
     def _token_forward(self, input_ids, infer_state: InferStateInfo):
+        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         g_cache_manager.cache_env_in(
             is_cuda_graph=infer_state.is_cuda_graph,
             cur_batch_size=infer_state.batch_size,
             cuda_graph_max_batch_size=self.graph_max_batch_size,
         )
         cuda_input_ids = input_ids
-        input_embs = self.pre_infer.token_forward(cuda_input_ids, infer_state, self.pre_post_weight)
-        for i in range(0, self.layers_num):
-            input_embs = self.layers_infer[i].token_forward(input_embs, infer_state, self.trans_layers_weight[i])
-        predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight)
+        pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
+        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+        for i in range(self.layers_num):
+            layer = self.layers_infer[i]
+            layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
+            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
+
         g_cache_manager.cache_env_out()
         return predict_logics
 
+    @final
+    def _overlap_tpsp_token_forward(
+        self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo
+    ):
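+        # Run the two micro-batches through pre-, per-layer, and post-inference in lockstep, using
+        # the overlap_tpsp_token_forward variants of each stage. Editor's note: interleaving the two
+        # streams is presumably what hides each micro-batch's communication behind the other's compute.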
+        g_cache_manager.cache_env_in(
+            is_cuda_graph=infer_state.is_cuda_graph,
+            cur_batch_size=infer_state.batch_size,
+            cuda_graph_max_batch_size=self.graph_max_batch_size,
+        )
+        input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward(
+            input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight
+        )
+
+        for i in range(self.layers_num):
+            input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward(
+                input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i]
+            )
+
+        predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward(
+            input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
+        )
+
+        g_cache_manager.cache_env_out()
+        return predict_logics, predict_logics1
+
     @final
     @torch.no_grad()
     def _check_max_len_infer(self):