 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only DeepseekV2 model."""
+"""Inference-only DeepseekV2/DeepseekV3 model."""
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch
@@ -115,23 +115,32 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")

-        self.experts = FusedMoE(num_experts=config.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                use_grouped_topk=True,
-                                num_expert_group=config.n_group,
-                                topk_group=config.topk_group,
-                                prefix=f"{prefix}.experts")
-
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")
+        if config.topk_method == "noaux_tc":
+            self.gate.e_score_correction_bias = nn.Parameter(
+                torch.empty(config.n_routed_experts))
+        else:
+            self.gate.e_score_correction_bias = None
+
+        self.experts = FusedMoE(
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            prefix=f"{prefix}.experts",
+            scoring_func=config.scoring_func,
+            e_score_correction_bias=self.gate.e_score_correction_bias)
+
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
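A note on the reordering in this hunk: the `FusedMoE` construction now follows the gate because, when `config.topk_method == "noaux_tc"` (DeepSeek-V3's auxiliary-loss-free load balancing), the gate owns a learned per-expert `e_score_correction_bias` that `FusedMoE` needs at construction time. The bias shifts which experts are *selected*, while the combine weights still come from the unbiased scores. A minimal sketch of bias-corrected grouped top-k under those assumptions (illustrative only; the real selection happens inside `FusedMoE`, and the top-2-per-group heuristic is my reading of DeepSeek-V3, not this diff's code):

```python
import torch

def noaux_tc_grouped_topk(scores: torch.Tensor, bias: torch.Tensor,
                          n_group: int, topk_group: int, top_k: int):
    """Sketch: scores [T, E] are sigmoid gate outputs, bias is [E]."""
    num_tokens, n_experts = scores.shape
    biased = scores + bias  # the bias influences selection only
    # Rank expert groups by the sum of their two strongest biased experts.
    group_scores = (biased.view(num_tokens, n_group, -1)
                    .topk(2, dim=-1).values.sum(dim=-1))        # [T, G]
    top_groups = group_scores.topk(topk_group, dim=-1).indices  # [T, Kg]
    # Mask out every expert whose group was not selected.
    keep = torch.zeros(num_tokens, n_group, dtype=torch.bool,
                       device=scores.device)
    keep.scatter_(1, top_groups, True)
    keep = keep.unsqueeze(-1).expand(-1, -1, n_experts // n_group)
    masked = biased.masked_fill(~keep.reshape(num_tokens, n_experts),
                                float("-inf"))
    topk_ids = masked.topk(top_k, dim=-1).indices               # [T, K]
    # Combine weights use the *unbiased* scores, renormalized over top-k.
    weights = scores.gather(1, topk_ids)
    return weights / weights.sum(dim=-1, keepdim=True), topk_ids
```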
@@ -732,6 +741,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
+            # TODO(simon): support nextn predict layers
+            if hasattr(self.config, "num_nextn_predict_layers"
+                       ) and self.config.num_nextn_predict_layers > 0:
+                assert self.config.num_nextn_predict_layers == 1
+                layer_idx = self.config.num_hidden_layers
+                if name.startswith(f"model.layers.{layer_idx}"):
+                    continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
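The new block skips the multi-token-prediction ("nextn"/MTP) weights, which DeepSeek-V3 checkpoints store as one extra layer indexed just past the last decoder layer; the model cannot load them yet, hence the TODO. A toy illustration of the filter (the weight names below are made up):

```python
# DeepSeek-V3 configs have num_hidden_layers = 61 and
# num_nextn_predict_layers = 1, so MTP weights appear under
# "model.layers.61." and are dropped during loading.
num_hidden_layers = 61
names = [
    "model.layers.60.mlp.gate.weight",  # last decoder layer: loaded
    "model.layers.61.eh_proj.weight",   # nextn/MTP layer: skipped
]
kept = [n for n in names
        if not n.startswith(f"model.layers.{num_hidden_layers}")]
assert kept == ["model.layers.60.mlp.gate.weight"]
```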
@@ -793,3 +811,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
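The empty subclass is deliberate: every V3-specific branch above keys off checkpoint config fields (`topk_method`, `scoring_func`, `num_nextn_predict_layers`), so the class exists only so that the `DeepseekV3ForCausalLM` entry in a V3 checkpoint's `config.json` `architectures` list resolves to this implementation. Conceptually (an illustrative mapping, not vLLM's actual registry API):

```python
# Illustrative only: vLLM's real model registry is more involved, but
# the idea is a name -> class map keyed on config.json's "architectures".
_ARCH_TO_MODEL = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,  # same code path
}

def resolve_model_cls(hf_config):
    return _ARCH_TO_MODEL[hf_config.architectures[0]]
```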