forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodeling_deepseekv3.py
1208 lines (1070 loc) · 52.4 KB
/
modeling_deepseekv3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import math
import os
import warnings
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
DeepseekAllReduce, ParallelConfig, allgather)
from ..model_config import ModelConfig
from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
from ..modules.attention import MLA
from ..modules.decoder_layer import DecoderLayer
from ..modules.fused_moe import BaseMoeRoutingMethod, FusedMoE
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..pipeline_interface import PipelineInterface
from ..pyexecutor.cuda_graph_runner import is_graph_capturing
from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
disable_fp4_allgather)
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
EagerFusionConfig, register_auto_model)
class DeepseekV3MTPHead(nn.Module):
    """Output head of an MTP layer: turns each sequence's last token into logits.

    ``self.norm`` is the head's RMSNorm. It is not applied inside ``forward``;
    callers apply it (fused or standalone) before invoking the head.
    """

    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
        super().__init__()
        cfg = model_config.pretrained_config
        self.norm = RMSNorm(hidden_size=cfg.hidden_size,
                            eps=cfg.rms_norm_eps,
                            dtype=cfg.torch_dtype)

    def forward(self, hidden_states: torch.Tensor, lm_head: Linear,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Return ``lm_head`` logits for the last token of every sequence.

        With ``attn_metadata``, the row of the last token of sequence ``i`` is
        taken as ``cumsum(seq_lens_cuda)[i] - 1`` (this assumes sequences are
        packed back to back along dim 0). Without metadata, only the final row
        of ``hidden_states`` is used, kept with a leading dimension of 1.
        """
        if attn_metadata is None:
            selected = hidden_states[-1].unsqueeze(0)
        else:
            seq_ends = torch.cumsum(attn_metadata.seq_lens_cuda,
                                    dim=0,
                                    dtype=torch.long)
            selected = hidden_states[seq_ends - 1]
        return lm_head(selected)
class DeepseekV3RotaryEmbedding(RotaryEmbedding):
    """Default-type rotary embedding whose head size is derived from the config.

    The head dimension is ``hidden_size // num_attention_heads``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 device: Optional[torch.device] = None):
        num_heads = config.num_attention_heads
        super().__init__(
            config,
            head_dim=config.hidden_size // num_heads,
            num_attention_heads=num_heads,
            max_position_embeddings=config.max_position_embeddings,
            device=device,
            rope_type="default",
        )
class DeepseekV3Attention(MLA):
    """Multi-head latent attention (MLA) configured from a DeepSeek-V3 config."""

    def __init__(
        self,
        model_config: ModelConfig[PretrainedConfig],
        layer_idx: Optional[int] = None,
        aux_stream: Optional[torch.cuda.Stream] = None,
    ):
        config = model_config.pretrained_config

        # Positional-embedding params (YaRN rope read from the config) are only
        # handed to MLA when the model config asks for fused positional embedding.
        pos_embd_params = None
        if model_config.fuse_pos_embd:
            pos_embd_params = PositionalEmbeddingParams(
                type=PositionEmbeddingType.yarn,
                rope=RopeParams.from_config(config),
            )

        # One token per sequence without a speculative config; otherwise one
        # extra token per ``num_nextn_predict_layers``.
        spec_config = model_config.spec_config
        if spec_config is None:
            predicted_tokens_per_seq = 1
        else:
            predicted_tokens_per_seq = spec_config.num_nextn_predict_layers + 1

        super().__init__(hidden_size=config.hidden_size,
                         num_attention_heads=config.num_attention_heads,
                         num_key_value_heads=config.num_key_value_heads,
                         qk_rope_head_dim=config.qk_rope_head_dim,
                         qk_nope_head_dim=config.qk_nope_head_dim,
                         q_lora_rank=config.q_lora_rank,
                         kv_lora_rank=config.kv_lora_rank,
                         v_head_dim=config.v_head_dim,
                         predicted_tokens_per_seq=predicted_tokens_per_seq,
                         max_position_embeddings=config.max_position_embeddings,
                         bias=False,
                         rotary_emb=DeepseekV3RotaryEmbedding(config),
                         pos_embd_params=pos_embd_params,
                         layer_idx=layer_idx,
                         dtype=config.torch_dtype,
                         config=model_config,
                         aux_stream=aux_stream)
class Deepseekv3RoutingImpl:
    """DeepSeek-V3 "noaux_tc" routing: group-limited top-k over sigmoid scores.

    Experts are *selected* using ``sigmoid(logits) + e_score_correction_bias``.
    The returned *weights* are the unbiased sigmoid scores of the selected
    experts, renormalised to sum to one and multiplied by
    ``routed_scaling_factor``.
    """

    def __init__(
        self,
        top_k: int,
        n_group: int,
        topk_group: int,
        routed_scaling_factor: float,
        is_thop: bool = True,
    ):
        """
        Args:
            top_k: number of experts selected per token.
            n_group: number of equal-size groups the experts are split into.
            topk_group: number of groups a token may route into.
            routed_scaling_factor: multiplier applied to the normalised weights.
            is_thop: if True, use the ``torch.ops.trtllm.noaux_tc_op`` kernel;
                otherwise use the pure-PyTorch reference path.
        """
        super().__init__()
        self.top_k = top_k
        self.topk_group = topk_group
        self.n_group = n_group
        self.routed_scaling_factor = routed_scaling_factor
        self.is_thop = is_thop

    def noaux_tc(self, logits, e_score_correction_bias):
        """Return ``(topk_values, topk_indices)`` for each row of ``logits``."""
        n_group = self.n_group
        scores = F.sigmoid(logits)
        scores_with_bias = scores + e_score_correction_bias
        scores_shape = list(scores_with_bias.shape)

        if enable_llm_debug():
            has_nan = torch.isnan(scores_with_bias).any()
            if has_nan:
                warnings.warn(
                    "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
                )

        if self.is_thop:
            topk_values, topk_indices = torch.ops.trtllm.noaux_tc_op(
                scores, scores_with_bias, n_group, self.topk_group, self.top_k,
                self.routed_scaling_factor)
            return topk_values, topk_indices

        # Reference path. View the expert axis as (n_group, experts_per_group).
        grouped_shape = scores_shape[:-1] + [
            n_group, scores_shape[-1] // n_group
        ]
        # Score each group by the sum of its two best biased expert scores.
        group_scores = torch.sum(torch.topk(
            scores_with_bias.view(grouped_shape),
            k=2,
            dim=-1,
            largest=True,
            sorted=True)[0],
                                 dim=-1)
        _, group_idx = torch.topk(group_scores,
                                  k=self.topk_group,
                                  dim=-1,
                                  largest=True,
                                  sorted=True)
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(-1, group_idx, 1)
        score_mask = group_mask.unsqueeze(-1).expand(grouped_shape).reshape(
            scores_shape)
        # Exclude experts outside the chosen groups with -inf, not 0. Biased
        # scores can be negative, and a 0 would let a masked-out expert outrank
        # them. This matches the DeepSeek-V3 reference gate.
        scores_with_bias = scores_with_bias.masked_fill(
            score_mask == 0, float("-inf"))
        _, topk_idx = torch.topk(scores_with_bias,
                                 k=self.top_k,
                                 dim=-1,
                                 largest=True,
                                 sorted=True)
        # Weights come from the unbiased scores of the selected experts only.
        new_mask = torch.zeros_like(scores)
        new_mask.scatter_(-1, topk_idx, 1)
        scores = scores * new_mask
        # The epsilon guards against division by zero.
        score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20
        scores = scores / score_sum * self.routed_scaling_factor
        topk_values, topk_indices = torch.topk(scores,
                                               k=self.top_k,
                                               dim=-1,
                                               largest=True)
        return topk_values, topk_indices

    def apply(
        self, logits: torch.Tensor, e_score_correction_bias: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Route ``logits`` and return ``(indices as int32, weights as float32)``.

        Note the order is swapped relative to :meth:`noaux_tc`, which returns
        ``(values, indices)``.
        """
        topk_values, topk_indices = self.noaux_tc(logits,
                                                  e_score_correction_bias)
        return topk_indices.to(torch.int32), topk_values.to(torch.float32)
class Deepseekv3Gate(BaseMoeRoutingMethod):
    """Router of a DeepSeek-V3 MoE layer.

    Owns the router weight and ``e_score_correction_bias``, so both load from
    the checkpoint under this module. It also serves as its own routing
    method, delegating top-k selection to :class:`Deepseekv3RoutingImpl`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int,
        n_group: int,
        topk_group: int,
        routed_scaling_factor: float,
        dtype: Optional[torch.dtype] = None,
        is_thop: bool = True,
    ):
        super().__init__()
        router_weight = torch.empty((num_experts, hidden_size), dtype=dtype)
        self.weight = nn.Parameter(router_weight, requires_grad=False)
        # The correction bias is always kept in float32, whatever ``dtype`` is.
        correction_bias = torch.empty((num_experts, ), dtype=torch.float32)
        self.e_score_correction_bias = nn.Parameter(correction_bias,
                                                    requires_grad=False)
        # TODO: e_score_correction_bias makes sense to live in this gate class, but it is needed for the routing impl
        # So we don't run into issues with weight loading, we make this gate object the BaseMoeRoutingMethod
        # and then dispatch to the routing impl for the actual implementation.
        # This is a bit of a hack and we should clean this up in the future.
        self.routing_impl = Deepseekv3RoutingImpl(
            top_k=top_k,
            n_group=n_group,
            topk_group=topk_group,
            routed_scaling_factor=routed_scaling_factor,
            is_thop=is_thop,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return router logits ``hidden_states @ weight.T`` computed in float32."""
        return torch.ops.trtllm.cublas_mm(hidden_states,
                                          self.weight.t(),
                                          bias=None,
                                          out_dtype=torch.float32)

    def load_weights(self, weights: List[Dict]):
        """Copy ``weight`` and ``e_score_correction_bias`` from a one-entry list."""
        assert len(weights) == 1
        entry = weights[0]
        self.weight.copy_(entry["weight"][:])
        bias = entry["e_score_correction_bias"][:].to(torch.float32)
        self.e_score_correction_bias.copy_(bias)

    def apply(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Select experts for ``logits``; returns ``(indices, weights)``."""
        return self.routing_impl.apply(logits, self.e_score_correction_bias)

    @property
    def routing_method(self) -> BaseMoeRoutingMethod:
        """This gate itself acts as the routing method."""
        return self

    def get_experts_per_token(self):
        """Number of experts each token is routed to."""
        return self.routing_impl.top_k
class Deepseekv3MoE(nn.Module):
    """DeepSeek-V3 MoE block: routed experts (FusedMoE) plus shared experts (GatedMLP)."""

    def __init__(self,
                 *,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 shared_expert_intermediate_size: int,
                 aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
                 dtype: Optional[torch.dtype] = None,
                 model_config: Optional[ModelConfig] = None):
        """
        Args:
            num_experts: number of routed experts.
            top_k: experts selected per token.
            hidden_size: model hidden size.
            intermediate_size: intermediate size of each routed expert.
            shared_expert_intermediate_size: total intermediate size of the
                shared experts.
            aux_stream_dict: side streams; ``MoeChunkingOverlap`` is given to
                FusedMoE and ``MoeShared`` runs the routed experts during
                CUDA-graph capture.
            dtype: parameter dtype.
            model_config: model config. A fresh ``ModelConfig()`` is built
                when omitted, instead of sharing one default instance created
                at import time.
        """
        super().__init__()
        if model_config is None:
            model_config = ModelConfig()
        config = model_config.pretrained_config
        self.top_k = top_k
        self.use_dp = model_config.mapping.enable_attention_dp
        self.gate = Deepseekv3Gate(
            hidden_size,
            num_experts,
            top_k=top_k,
            n_group=config.n_group,
            topk_group=config.topk_group,
            routed_scaling_factor=config.routed_scaling_factor,
            dtype=dtype)
        self.experts = FusedMoE(
            num_experts=num_experts,
            routing_method=self.gate.routing_method,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dtype=dtype,
            # In both low latency and attention dp scenarios, FusedMoE needs
            # not to do allreduce inside op.
            reduce_results=False,
            model_config=model_config,
            aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap])

        self.shared_output_scale = None
        if self.use_dp:
            shared_tp_size = 1
        else:
            # Shard the shared experts at 128-column granularity, using the
            # largest TP size that divides both the block count and tp_size.
            assert shared_expert_intermediate_size % 128 == 0
            shared_tp_size = math.gcd(
                shared_expert_intermediate_size // 128,
                model_config.mapping.tp_size,
            )
            if shared_tp_size != model_config.mapping.tp_size:
                # The combined output is later all-reduced across all
                # tp_size ranks. Scale down so the shared experts are not
                # over-counted, presumably because GatedMLP replicates them
                # across tp_size / shared_tp_size rank groups.
                self.shared_output_scale = shared_tp_size / model_config.mapping.tp_size
        self.shared_experts = GatedMLP(
            hidden_size=hidden_size,
            intermediate_size=shared_expert_intermediate_size,
            bias=False,
            dtype=dtype,
            config=model_config,
            overridden_tp_size=shared_tp_size,
            is_expert=True)

        self.parallel_config = ParallelConfig(
            tensor_parallel_rank=model_config.mapping.tp_rank,
            tensor_parallel_size=model_config.mapping.tp_size,
            gpus_per_node=model_config.mapping.gpus_per_node,
            pipeline_parallel_size=model_config.mapping.pp_size,
            parallel_rank=model_config.mapping.rank)
        # AllReduce is imported at module level.
        self.all_reduce = AllReduce(self.parallel_config)
        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
        self.event_dict = {
            key: torch.cuda.Event()
            for key in [EventType.Main, EventType.MoeShared]
        }

    def compute_routed_output(self, hidden_states, hidden_states_fp4,
                              all_rank_num_tokens, min_latency_mode):
        """Route tokens and run the routed experts.

        With attention DP and TP > 1, local tokens are first padded to the
        maximum count across ranks. They are also all-gathered when FP4
        all-gather is disabled. If ``hidden_states_fp4`` is given, the experts
        consume it and output in ``hidden_states.dtype``.
        """
        if self.use_dp and self.parallel_config.tensor_parallel_size > 1:
            max_num_token = max(all_rank_num_tokens)
            hidden_states = torch.nn.functional.pad(
                hidden_states,
                (0, 0, 0, max_num_token - hidden_states.shape[0]))
            if disable_fp4_allgather():
                hidden_states = allgather(hidden_states,
                                          self.parallel_config,
                                          gather_dim=0)
        router_logits = self.gate(hidden_states)
        if hidden_states_fp4 is not None:
            routed_output = self.experts(hidden_states_fp4,
                                         router_logits,
                                         min_latency_mode,
                                         output_dtype=hidden_states.dtype)
        else:
            routed_output = self.experts(
                hidden_states,
                router_logits,
                min_latency_mode,
                all_rank_num_tokens=all_rank_num_tokens)
        return routed_output

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
        all_rank_num_tokens=None,
        final_all_reduce_params: Optional[AllReduceParams] = None,
        min_latency_mode: Optional[bool] = False,
    ) -> torch.Tensor:
        """Compute ``shared_experts(x) + routed_experts(x)``.

        Args:
            hidden_states: input activations.
            hidden_states_fp4: optional NVFP4-quantized input for the routed
                experts (min-latency path).
            all_rank_num_tokens: token counts of all ranks (attention DP).
            final_all_reduce_params: params for the TP all-reduce of the sum.
            min_latency_mode: return the unreduced pieces instead of the sum.

        Returns:
            In min-latency mode, ``[shared_output, *routed_output]``. Otherwise
            the summed output, passed through ``self.all_reduce`` with
            ``final_all_reduce_params`` unless attention DP is on or TP size is 1.
        """
        if min_latency_mode:
            assert not self.use_dp

        # Only enable multi-stream for cuda graph since switch stream has extra host overhead
        # This design is mainly for low latency use case. Need to improve for max throughput use case.
        do_multi_stream = is_graph_capturing()
        if do_multi_stream:
            self.event_dict[EventType.Main].record()

        # Shared experts run on the current stream.
        shared_output = self.shared_experts(hidden_states)
        if self.shared_output_scale is not None:
            shared_output *= self.shared_output_scale

        if do_multi_stream:
            # Routed experts run on the aux stream after the Main event; the
            # current stream then waits for them via the MoeShared event.
            with torch.cuda.stream(self.aux_stream):
                self.event_dict[EventType.Main].wait()
                routed_output = self.compute_routed_output(
                    hidden_states, hidden_states_fp4, all_rank_num_tokens,
                    min_latency_mode)
                self.event_dict[EventType.MoeShared].record()
            self.event_dict[EventType.MoeShared].wait()
        else:
            routed_output = self.compute_routed_output(hidden_states,
                                                       hidden_states_fp4,
                                                       all_rank_num_tokens,
                                                       min_latency_mode)

        if min_latency_mode:
            return [shared_output, *routed_output]

        assert shared_output.size() == routed_output.size(), (
            f"unmatched tensor shape: shared {tuple(shared_output.size())} "
            f"vs routed {tuple(routed_output.size())}")
        final_hidden_states = shared_output + routed_output
        if not self.use_dp and self.parallel_config.tensor_parallel_size > 1:
            final_hidden_states = self.all_reduce(
                final_hidden_states, all_reduce_params=final_all_reduce_params)

        return final_hidden_states
class DeepseekV3DecoderLayer(DecoderLayer):
    """One DeepSeek-V3 decoder layer: MLA self-attention, then an MoE or dense MLP.

    The layer whose global index ``i`` (local ``layer_idx`` plus this pipeline
    stage's first layer) satisfies ``config.n_routed_experts is not None``,
    ``i >= config.first_k_dense_replace`` and ``i % config.moe_layer_freq == 0``
    uses :class:`Deepseekv3MoE`. Every other layer uses a dense :class:`GatedMLP`.

    Under tensor parallelism, the all-reduces after attention and after the
    MLP can be fused with the following RMSNorm (``self.fusion_config``). The
    fused all-reduce uses either the generic ``AllReduce`` or, for small token
    counts, ``DeepseekAllReduce``. After the MLP, ``next_layer_layernorm`` is
    applied, either fused into the all-reduce or standalone when it is not
    ``None``. It is initialised to ``None`` here and is expected to be
    assigned by the owning model.
    """

    def __init__(self, model_config: ModelConfig[PretrainedConfig],
                 layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
                                                       torch.cuda.Stream]):
        super().__init__()
        config = model_config.pretrained_config
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size
        self.num_experts = config.n_routed_experts
        self.num_shared_experts = config.n_shared_experts
        self.top_k = config.num_experts_per_tok

        self.self_attn = DeepseekV3Attention(
            model_config,
            layer_idx=layer_idx,
            aux_stream=aux_stream_dict[AuxStreamType.Attention])

        self.fusion_config = EagerFusionConfig()
        self.enable_attention_dp = model_config.mapping.enable_attention_dp
        self.mlp_tp_size = model_config.mapping.tp_size
        # Eager fusion stays on unless the env var is set to something other than "0".
        self.enable_fusion = os.environ.get(
            "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0"

        # layer_idx is local to this pipeline stage. MoE placement depends on
        # the global layer index.
        pp_layer_offset = model_config.mapping.pp_layers_torch(
            config.num_hidden_layers)[0]
        global_layer_idx = pp_layer_offset + layer_idx

        self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4()

        if (config.n_routed_experts is not None
                and global_layer_idx >= config.first_k_dense_replace
                and global_layer_idx % config.moe_layer_freq == 0):
            self.fusion_config.PRE_MOE_FUSION = (
                self.enable_fusion and model_config.mapping.has_tp()
                and not self.enable_attention_dp)
            self.fusion_config.POST_MOE_FUSION = (
                self.enable_fusion and model_config.mapping.has_tp()
                and not self.enable_attention_dp
                and not model_config.mapping.has_pp())
            self.mlp = Deepseekv3MoE(
                num_experts=self.num_experts,
                top_k=self.top_k,
                hidden_size=self.hidden_size,
                intermediate_size=self.moe_intermediate_size,
                shared_expert_intermediate_size=self.moe_intermediate_size *
                self.num_shared_experts,
                dtype=config.torch_dtype,
                model_config=model_config,
                aux_stream_dict=aux_stream_dict)
        else:
            if self.enable_attention_dp:
                self.mlp_tp_size = 1
            else:
                # Shard the dense MLP at 128-column granularity, within a node.
                assert config.intermediate_size % 128 == 0
                self.mlp_tp_size = math.gcd(
                    math.gcd(
                        config.intermediate_size // 128,
                        model_config.mapping.tp_size,
                    ),
                    model_config.mapping.
                    gpus_per_node,  # Avoid costly inter-node TP
                )
            self.fusion_config.PRE_MLP_FUSION = (
                self.enable_fusion and model_config.mapping.has_tp()
                and self.is_nvfp4 and not self.enable_attention_dp)
            self.fusion_config.POST_MLP_FUSION = (
                self.enable_fusion and self.mlp_tp_size > 1
                and not self.enable_attention_dp
                and not model_config.mapping.has_pp())
            self.mlp = GatedMLP(hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                bias=False,
                                dtype=config.torch_dtype,
                                config=model_config,
                                overridden_tp_size=self.mlp_tp_size,
                                is_expert=False)

        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                       eps=config.rms_norm_eps,
                                       dtype=config.torch_dtype)

        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                eps=config.rms_norm_eps,
                                                dtype=config.torch_dtype)

        self.parallel_config = ParallelConfig(
            tensor_parallel_rank=model_config.mapping.tp_rank,
            tensor_parallel_size=model_config.mapping.tp_size,
            gpus_per_node=model_config.mapping.gpus_per_node,
            pipeline_parallel_size=model_config.mapping.pp_size,
            parallel_rank=model_config.mapping.rank)
        self.layer_idx = layer_idx
        self.all_reduce = AllReduce(self.parallel_config)
        self.next_layer_layernorm: RMSNorm = None

        # DeepseekAllReduce is disabled by env var or on multi-node runs. In
        # that case ``self.deepseek_allreduce`` is never created.
        self.deepseek_allreduce_disabled = os.environ.get(
            "TRTLLM_DEEPSEEK_ALLREDUCE_FUSION_DISABLED", "0") == "1"
        if model_config.mapping.is_multi_node():
            self.deepseek_allreduce_disabled = True

        if not self.deepseek_allreduce_disabled:
            self.deepseek_allreduce = DeepseekAllReduce(self.parallel_config)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the MLP of one layer.

        Args:
            position_ids: positions passed to attention.
            hidden_states: this layer's input, already normalised by
                ``input_layernorm`` (by the caller or the previous layer).
            attn_metadata: attention metadata; its ``all_rank_num_tokens`` is
                forwarded to the MLP.
            residual: residual stream entering this layer.
            **kwargs: forwarded to ``self_attn``.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # deepseek allreduce kernel is better when m < 512, but its two-shot
        # path (128~512) has an accuracy bug, so it is only used for m <= 128.
        using_prev_fusion = (self.deepseek_allreduce_disabled
                             or hidden_states.size(0) > 128)
        # The min-latency MoE path needs DeepseekAllReduce both before the MoE
        # (which also produces the NVFP4 input) and after it (which consumes
        # the MoE's list output). It is therefore only enabled when that
        # kernel is in use; ``not using_prev_fusion`` already implies
        # m <= 128. Without this guard, a disabled DeepseekAllReduce would
        # leave ``hidden_states_fp4`` unassigned and hand a list to
        # ``self.all_reduce``.
        min_latency_mode = bool(not using_prev_fusion
                                and self.fusion_config.POST_MOE_FUSION
                                and self.is_nvfp4)
        hidden_states_fp4 = None

        # Self Attention. It all-reduces its own output only when no fused
        # all-reduce follows.
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            all_reduce_params=AllReduceParams(enable_allreduce=not (
                self.fusion_config.PRE_MOE_FUSION
                or self.fusion_config.PRE_MLP_FUSION
                or self.parallel_config.tensor_parallel_size == 1
                or self.enable_attention_dp)),
            **kwargs,
        )

        if self.fusion_config.PRE_MOE_FUSION:
            # Custom AR Fusion for DeepseekV3
            if using_prev_fusion:
                hidden_states, residual = self.all_reduce(
                    hidden_states,
                    all_reduce_params=AllReduceParams(
                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                        residual=residual,
                        norm_weight=self.post_attention_layernorm.weight,
                        eps=self.post_attention_layernorm.variance_epsilon,
                    ))
            elif min_latency_mode:
                # All-reduce + residual + RMSNorm, also emitting an NVFP4 copy
                # (activations + scale factors) for the routed experts.
                hidden_states, hidden_states_act, hidden_states_sf, residual = self.deepseek_allreduce(
                    hidden_states,
                    [
                        residual, self.post_attention_layernorm.weight,
                        self.mlp.experts.fc31_input_scale
                    ],
                    self.post_attention_layernorm.variance_epsilon,
                    AllReduceFusionOp.RESIDUAL_RMS_NORM_AND_QUANT_NVFP4,
                )
                hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
                                                       hidden_states_sf)
            else:
                hidden_states, residual = self.deepseek_allreduce(
                    hidden_states,
                    [residual, self.post_attention_layernorm.weight],
                    self.post_attention_layernorm.variance_epsilon,
                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
                )
        elif self.fusion_config.PRE_MLP_FUSION:
            # Custom AR Fusion for DeepseekV3 with quant_fp4
            if using_prev_fusion:
                hidden_states, residual = self.all_reduce(
                    hidden_states,
                    all_reduce_params=AllReduceParams(
                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                        residual=residual,
                        norm_weight=self.post_attention_layernorm.weight,
                        eps=self.post_attention_layernorm.variance_epsilon,
                    ))
                act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
                    hidden_states, self.mlp.gate_up_proj.input_scale,
                    self.mlp.gate_up_proj.scaling_vector_size, False)
            else:
                act_fp4, act_sf, residual = self.deepseek_allreduce(
                    hidden_states,
                    [
                        residual, self.post_attention_layernorm.weight,
                        self.mlp.gate_up_proj.input_scale
                    ],
                    self.post_attention_layernorm.variance_epsilon,
                    AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
                )
            # The dense MLP consumes the NVFP4-quantized activations directly.
            hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
        else:
            # No fusion
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual)

        # The MLP all-reduces its own output only when no fused all-reduce follows.
        mlp_all_reduce_params = AllReduceParams(enable_allreduce=not (
            self.fusion_config.POST_MOE_FUSION
            or self.fusion_config.POST_MLP_FUSION
            or self.parallel_config.tensor_parallel_size == 1
            or self.enable_attention_dp))
        if self.fusion_config.PRE_MOE_FUSION and min_latency_mode:
            hidden_states = self.mlp(
                hidden_states,
                hidden_states_fp4,
                all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
                final_all_reduce_params=mlp_all_reduce_params,
                min_latency_mode=min_latency_mode,
            )
        else:
            hidden_states = self.mlp(
                hidden_states,
                all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
                final_all_reduce_params=mlp_all_reduce_params,
                min_latency_mode=min_latency_mode,
            )

        if self.fusion_config.POST_MOE_FUSION:
            if using_prev_fusion:
                hidden_states, residual = self.all_reduce(
                    hidden_states,
                    all_reduce_params=AllReduceParams(
                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                        residual=residual,
                        norm_weight=self.next_layer_layernorm.weight,
                        eps=self.next_layer_layernorm.variance_epsilon,
                    ))
            elif min_latency_mode:
                # Min-latency MoE returns [shared_output, *routed_output]; the
                # kernel combines the pieces, all-reduces and applies the norm.
                shared_output = hidden_states[0]
                hidden_states_activated_experts = hidden_states[1]
                num_activated_experts_per_node = hidden_states[2]
                experts_to_token_score = hidden_states[3]
                activated_expert_global_ids = hidden_states[4]
                hidden_states, residual = self.deepseek_allreduce(
                    hidden_states_activated_experts,  # not used
                    [
                        residual, self.next_layer_layernorm.weight,
                        num_activated_experts_per_node,
                        experts_to_token_score,
                        hidden_states_activated_experts, shared_output,
                        activated_expert_global_ids
                    ],
                    self.next_layer_layernorm.variance_epsilon,
                    AllReduceFusionOp.MOE_ALLREDUCE_RESIDUAL_RMS_NORM,
                )
            else:
                hidden_states, residual = self.deepseek_allreduce(
                    hidden_states,
                    [residual, self.next_layer_layernorm.weight],
                    self.next_layer_layernorm.variance_epsilon,
                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
                )
        elif self.fusion_config.POST_MLP_FUSION:
            if using_prev_fusion:
                # Custom AR Fusion for DeepseekV3
                hidden_states, residual = self.all_reduce(
                    hidden_states,
                    all_reduce_params=AllReduceParams(
                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                        residual=residual,
                        norm_weight=self.next_layer_layernorm.weight,
                        eps=self.next_layer_layernorm.variance_epsilon,
                    ))
            else:
                hidden_states, residual = self.deepseek_allreduce(
                    hidden_states,
                    [residual, self.next_layer_layernorm.weight],
                    self.next_layer_layernorm.variance_epsilon,
                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
                )
        elif self.next_layer_layernorm is not None:
            hidden_states, residual = self.next_layer_layernorm(
                hidden_states, residual)

        return hidden_states, residual
class DeepseekV3MTP(DeepseekV3DecoderLayer):
    """Multi-token-prediction (MTP) layer built on a DeepSeek-V3 decoder layer.

    ``forward`` works as follows:
    - RMS-normalises the embedding of ``input_ids`` (``enorm``) and the
      incoming hidden states (``hnorm``), concatenates them, and projects
      them back to ``hidden_size`` (``eh_proj``).
    - Runs attention plus the MLP on the result.
    - Produces logits through ``shared_head``.
    """

    def __init__(self, model_config: ModelConfig[PretrainedConfig],
                 layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
                                                       torch.cuda.Stream]):
        # The parent already sets moe_intermediate_size, num_experts,
        # num_shared_experts and top_k from the same config fields.
        super().__init__(model_config, layer_idx, aux_stream_dict)
        config = model_config.pretrained_config
        self.hidden_dim = config.hidden_size

        self.enorm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)

        self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)

        # Projects concat(embedding, hidden) of size 2*H back to H.
        self.eh_proj = Linear(
            config.hidden_size * 2,
            config.hidden_size,
            bias=False,
            dtype=config.torch_dtype,
            skip_create_weights=model_config.skip_create_weights,
        )

        self.shared_head = DeepseekV3MTPHead(model_config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        lm_head: Linear,
        embed_tokens: nn.Embedding,
        attn_metadata: AttentionMetadata,
        spec_metadata: MTPSpecMetadata,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one MTP step.

        Returns:
            ``(hidden_states, logits)``. ``hidden_states`` are normalised by
            ``shared_head.norm``. ``logits`` are float32 and come from
            ``shared_head`` (one row per sequence when ``attn_metadata`` is given).
        """
        # Same policy as DeepseekV3DecoderLayer.forward: the DeepseekAllReduce
        # kernel is only used for m <= 128, because its two-shot path
        # (128~512) has an accuracy bug.
        using_prev_fusion = (self.deepseek_allreduce_disabled
                             or hidden_states.size(0) > 128)

        inputs_embeds = self.enorm(embed_tokens(input_ids))
        hidden_states = self.hnorm(hidden_states)
        hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
        hidden_states = self.eh_proj(hidden_states)

        # Input layer norm
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            all_reduce_params=AllReduceParams(enable_allreduce=not (
                self.fusion_config.PRE_MOE_FUSION
                or self.parallel_config.tensor_parallel_size == 1
                or self.enable_attention_dp)),
            **kwargs,
        )

        # MTP Layer Must have sparse MOE
        if self.fusion_config.PRE_MOE_FUSION:
            # Custom AR Fusion for DeepseekV3
            if using_prev_fusion:
                hidden_states, residual = self.all_reduce(
                    hidden_states,
                    all_reduce_params=AllReduceParams(
                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                        residual=residual,
                        norm_weight=self.post_attention_layernorm.weight,
                        eps=self.post_attention_layernorm.variance_epsilon,
                    ))
            else:
                hidden_states, residual = self.deepseek_allreduce(
                    hidden_states,
                    [residual, self.post_attention_layernorm.weight],
                    self.post_attention_layernorm.variance_epsilon,
                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(
            hidden_states,
            all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
            final_all_reduce_params=AllReduceParams(enable_allreduce=not (
                self.fusion_config.POST_MOE_FUSION
                or self.parallel_config.tensor_parallel_size == 1
                or self.enable_attention_dp)),
        )

        # The final norm is the head's norm, fused into the all-reduce when enabled.
        if self.fusion_config.POST_MOE_FUSION:
            if using_prev_fusion:
                hidden_states, residual = self.all_reduce(
                    hidden_states,
                    all_reduce_params=AllReduceParams(
                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                        residual=residual,
                        norm_weight=self.shared_head.norm.weight,
                        eps=self.shared_head.norm.variance_epsilon,
                    ))
            else:
                hidden_states, residual = self.deepseek_allreduce(
                    hidden_states,
                    [residual, self.shared_head.norm.weight],
                    self.shared_head.norm.variance_epsilon,
                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
                )
        else:
            hidden_states, _ = self.shared_head.norm(hidden_states, residual)

        logits = self.shared_head(hidden_states, lm_head, attn_metadata).float()

        return hidden_states, logits
@support_pp
class DeepseekV3Model(DecoderModel):
    """DeepSeek-V3 decoder stack: token embedding, decoder layers and a final norm."""

    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
        super().__init__(model_config)
        config = model_config.pretrained_config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        # Side streams shared by every layer (attention, shared experts and
        # MoE chunking overlap).
        self.aux_stream_dict = {
            key: torch.cuda.Stream()
            for key in [
                AuxStreamType.Attention, AuxStreamType.MoeShared,
                AuxStreamType.MoeChunkingOverlap
            ]
        }

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size,
                                         dtype=config.torch_dtype)

        self.layers = nn.ModuleList([
            DeepseekV3DecoderLayer(model_config, layer_idx,
                                   self.aux_stream_dict)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(hidden_size=config.hidden_size,
                            eps=config.rms_norm_eps,
                            dtype=config.torch_dtype)

    def _embed_inputs(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: Optional[torch.FloatTensor],
    ) -> torch.Tensor:
        """Return ``inputs_embeds``, or embed ``input_ids`` if it was not given.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return inputs_embeds

    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Run the main decoder stack and return its final hidden states."""
        hidden_states = self._embed_inputs(input_ids, inputs_embeds)
        residual = hidden_states

        # Only the first layer's input norm is applied here. Each layer then
        # applies the following norm itself and returns (hidden_states, residual).
        hidden_states = self.layers[0].input_layernorm(hidden_states)
        # MTP layers may be appended to self.layers after the main stack;
        # only the first num_hidden_layers layers run here.
        for decoder_layer in self.layers[:self.num_hidden_layers]:
            hidden_states, residual = decoder_layer(position_ids=position_ids,
                                                    hidden_states=hidden_states,
                                                    attn_metadata=attn_metadata,
                                                    residual=residual)

        return hidden_states

    def _pp_forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pipeline_interface: Optional[PipelineInterface] = None,
    ) -> torch.Tensor:
        """Pipeline-parallel forward over this rank's layers.

        The first rank embeds the inputs; later ranks unpack
        ``(hidden_states, residual)`` from ``pipeline_interface``. Non-last
        ranks return a ``PipelineInterface`` and the last rank returns the
        hidden states.

        Raises:
            ValueError: if a non-first rank gets no ``pipeline_interface``, or
                the first rank gets invalid ``input_ids``/``inputs_embeds``.
        """
        if self.pp_rank != 0:
            if pipeline_interface is None:
                raise ValueError(
                    "pipeline_interface is required for non-first pp rank.")
            hidden_states, residual = pipeline_interface
        else:
            hidden_states = self._embed_inputs(input_ids, inputs_embeds)
            residual = hidden_states

        local_decoder_layers = ([
            self.layers[layer_id] for layer_id in self.pp_layer_list
        ] if self.pp_size > 1 else self.layers)

        if self.pp_rank == 0:
            hidden_states = local_decoder_layers[0].input_layernorm(
                hidden_states)
        else:
            # Received hidden states are not yet normalised; fold in the residual.
            hidden_states, residual = local_decoder_layers[0].input_layernorm(
                hidden_states, residual)

        for decoder_layer in local_decoder_layers:
            hidden_states, residual = decoder_layer(position_ids=position_ids,
                                                    hidden_states=hidden_states,
                                                    attn_metadata=attn_metadata,
                                                    residual=residual)

        if self.pp_rank != self.pp_size - 1:
            return PipelineInterface(hidden_states, residual)
        return hidden_states
@register_auto_model("DeepseekV3ForCausalLM")
class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
PretrainedConfig]):
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
    """Wrap DeepseekV3Model as a causal LM, adding MTP layers when configured.

    When ``model_config.spec_config`` is set, MTP layers are appended to
    ``self.model.layers`` after the ``num_hidden_layers`` main decoder layers,
    and ``self.mtp_worker`` is created.
    """
    super().__init__(DeepseekV3Model(model_config),
                     config=model_config,
                     hidden_size=model_config.pretrained_config.hidden_size,
                     vocab_size=model_config.pretrained_config.vocab_size)

    # NOTE(review): self.model_nextn is never updated from the local
    # ``model_nextn`` below, yet forward() derives
    # attn_metadata.num_generations_per_batch from it. Confirm that keeping
    # it at 0 is intended when MTP is enabled.
    self.model_nextn = 0
    if model_config.spec_config is not None:
        assert not model_config.mapping.has_pp(
        ), "PP + MTP combination is not supported"
        # MTP layers requested by the runtime vs. MTP layers present in the checkpoint.
        model_nextn = model_config.spec_config.num_nextn_predict_layers
        ckpt_nextn = self.config.num_nextn_predict_layers
        self.num_hidden_layers = self.config.num_hidden_layers
        assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
        if ckpt_nextn == 1:
            # A single MTP layer placed right after the main stack, driven by
            # MTPEagleWorker.
            mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
                                      self.model.aux_stream_dict)
            self.model.layers.append(mtp_layer)
            self.mtp_worker = MTPEagleWorker(model_config.spec_config)
        else:
            # TODO: fix the accuracy issue and remove this assert.
            # Everything below in this branch only runs if asserts are
            # disabled (python -O).
            assert False, "Cannot support num_nextn_predict_layers>1 in checkpoint now. Will fix it soon"
            # One MTP layer per requested next-n prediction, with layer
            # indices continuing after the main stack.
            mtp_layers = nn.ModuleList([
                DeepseekV3MTP(model_config,
                              layer_idx + self.num_hidden_layers,
                              self.model.aux_stream_dict)
                for layer_idx in range(model_nextn)
            ])
            self.model.layers.extend(mtp_layers)
            self.mtp_worker = MTPWorker(model_config.spec_config)
            # modify the QuantConfig to support duplicated mtp layers
            # Model MTP layer N+k reuses checkpoint MTP layer N + (k % ckpt_nextn).
            # Every exclude_modules entry naming that checkpoint layer is
            # copied to the model layer's prefix, so quantization exclusions
            # also apply to the duplicated layers.
            # NOTE(review): the nesting of this block (inside the
            # ckpt_nextn > 1 branch) was reconstructed from an unindented
            # copy; confirm against upstream.
            if model_config.quant_config.exclude_modules is not None:
                extend_exclude_modules = []
                for model_mtp_idx in range(
                        self.num_hidden_layers,
                        self.num_hidden_layers + model_nextn):
                    ckpt_mtp_idx = (model_mtp_idx - self.num_hidden_layers
                                    ) % ckpt_nextn + self.num_hidden_layers
                    model_prefix = f"model.layers.{model_mtp_idx}"
                    ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
                    for exclude_module in model_config.quant_config.exclude_modules:
                        if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
                            extend_exclude_modules.append(
                                exclude_module.replace(
                                    ckpt_prefix, model_prefix))
                self.model_config.quant_config.exclude_modules.extend(
                    extend_exclude_modules)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[MTPSpecMetadata] = None,
pipeline_interface: Optional[PipelineInterface] = None,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
attn_metadata.num_generations_per_batch = self.model_nextn + 1
if self._supports_pp and self.pp_size > 1:
output = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
pipeline_interface=pipeline_interface,
)
# No need to compute logits for non-last PP ranks
if self.pp_rank < self.pp_size - 1:
return output
else:
hidden_states = output
else:
hidden_states = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
)
if spec_metadata and spec_metadata.spec_dec_mode.is_mtp():
# get logits
logits = self.logits_processor.forward(
hidden_states[spec_metadata.gather_ids],
self.lm_head,
attn_metadata,
True,
)
# get accepetd tokens and next draft tokens
return self.mtp_worker(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
logits=logits,
lm_head=self.lm_head,
embed_tokens=self.model.embed_tokens,