
Commit 41f9ff3

lulmer authored and mgoin committed
[Misc] Remove duplicated DeepSeek V2/V3 model definition (vllm-project#12793)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent c0db76b commit 41f9ff3

4 files changed (+36 -821 lines)


vllm/config.py (-1 line)
@@ -754,7 +754,6 @@ def get_hidden_size(self) -> int:
 
     @property
     def is_deepseek_mla(self) -> bool:
-        # TODO add deepseek_v3
         return (hasattr(self.hf_text_config, "model_type")) \
             and (self.hf_text_config.model_type in \
                     ('deepseek_v2', 'deepseek_v3'))\
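
For reference, a minimal standalone sketch of the check above, using a stand-in namespace in place of the real hf_text_config (not vLLM code): both model types now pass, which is why the stale TODO could be dropped.

from types import SimpleNamespace

# Stand-in for self.hf_text_config; the real object is a transformers config.
hf_text_config = SimpleNamespace(model_type="deepseek_v3")

is_mla = (hasattr(hf_text_config, "model_type")
          and hf_text_config.model_type in ("deepseek_v2", "deepseek_v3"))
print(is_mla)  # True for "deepseek_v2" and "deepseek_v3", False otherwise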

vllm/model_executor/models/deepseek_v2.py (+35 -13 lines)
@@ -21,7 +21,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only DeepseekV2 model."""
+"""Inference-only DeepseekV2/DeepseekV3 model."""
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -115,23 +115,32 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
 
-        self.experts = FusedMoE(num_experts=config.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                use_grouped_topk=True,
-                                num_expert_group=config.n_group,
-                                topk_group=config.topk_group,
-                                prefix=f"{prefix}.experts")
-
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")
+        if config.topk_method == "noaux_tc":
+            self.gate.e_score_correction_bias = nn.Parameter(
+                torch.empty(config.n_routed_experts))
+        else:
+            self.gate.e_score_correction_bias = None
+
+        self.experts = FusedMoE(
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            prefix=f"{prefix}.experts",
+            scoring_func=config.scoring_func,
+            e_score_correction_bias=self.gate.e_score_correction_bias)
+
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
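
The new e_score_correction_bias parameter is what DeepSeek-V3's "noaux_tc" top-k method needs: the bias nudges which experts get selected (aux-loss-free load balancing) while the routing weights themselves stay uncorrected, which is why it lives on the gate and is passed to FusedMoE separately. A minimal illustrative sketch of that idea; it ignores the grouped top-k that num_expert_group/topk_group add and is not FusedMoE's real kernel:

import torch

def noaux_tc_select(router_logits: torch.Tensor,
                    e_score_correction_bias: torch.Tensor,
                    top_k: int):
    # Sigmoid scoring, as selected by scoring_func in DeepSeek-V3 configs.
    scores = router_logits.sigmoid()
    # The bias is only used to *choose* the experts ...
    topk_ids = (scores + e_score_correction_bias).topk(top_k, dim=-1).indices
    # ... while the returned routing weights come from the raw scores.
    topk_weights = scores.gather(-1, topk_ids)
    return topk_weights, topk_ids

# Toy usage: 2 tokens routed over 8 experts, top-2 selection.
logits = torch.randn(2, 8)
bias = torch.zeros(8)
weights, ids = noaux_tc_select(logits, bias, top_k=2)
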
@@ -732,6 +741,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
+            # TODO(simon): support nextn predict layers
+            if hasattr(self.config, "num_nextn_predict_layers"
+                       ) and self.config.num_nextn_predict_layers > 0:
+                assert self.config.num_nextn_predict_layers == 1
+                layer_idx = self.config.num_hidden_layers
+                if name.startswith(f"model.layers.{layer_idx}"):
+                    continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
@@ -793,3 +811,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
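
The other two changed files are not shown in this excerpt; presumably the bulk of the -821 lines is the deleted deepseek_v3.py, with the model registry now resolving the DeepseekV3ForCausalLM architecture to this module instead. A sketch of what that mapping likely looks like, assuming the registry's usual (module name, class name) tuple convention:

# Assumed registry entries (vllm/model_executor/models/registry.py);
# not part of the diff shown above.
_TEXT_GENERATION_MODELS = {
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    # V3 resolves to the subclass defined at the end of deepseek_v2.py.
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
}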
