Commit 9e743e7

SzymonOzog authored and lulmer committed
[Model] Deepseek GGUF support (vllm-project#13167)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 964c06d commit 9e743e7

8 files changed, +198 -10 lines changed

docs/source/features/quantization/gguf.md

Lines changed: 7 additions & 0 deletions

@@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
 We recommend using the tokenizer from the base model instead of the GGUF model, because tokenizer conversion from GGUF is time-consuming and unstable, especially for models with a large vocab size.
 :::
 
+GGUF loading assumes that Hugging Face can convert the GGUF metadata to a config file. If Hugging Face does not support your model, you can manually create a config and pass it as `--hf-config-path`:
+
+```console
+# If your model is not supported by Hugging Face, you can manually provide a Hugging Face compatible config path
+vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path TinyLlama/TinyLlama-1.1B-Chat-v1.0
+```
+
 You can also use the GGUF model directly through the LLM entrypoint:
 
 ```python
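
For the same override through the Python `LLM` entrypoint, a minimal sketch (it assumes `hf_config_path` is forwarded through `LLM(**kwargs)` into `EngineArgs`, which gains this field later in this commit; the paths mirror the TinyLlama example above):

```python
from vllm import LLM, SamplingParams

# Sketch only: hf_config_path is assumed to pass through to EngineArgs.
llm = LLM(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    hf_config_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```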

vllm/config.py

Lines changed: 6 additions & 3 deletions

@@ -229,6 +229,7 @@ def __init__(
         trust_remote_code: bool,
         dtype: Union[str, torch.dtype],
         seed: int,
+        hf_config_path: Optional[str] = None,
         allowed_local_media_path: str = "",
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
@@ -259,6 +260,7 @@ def __init__(
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
         self.model = model
+        self.hf_config_path = hf_config_path
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
@@ -321,8 +323,9 @@ def __init__(
         if self.enable_sleep_mode and not current_platform.is_cuda():
             raise ValueError("Sleep mode is only supported on CUDA devices.")
 
-        hf_config = get_config(self.model, trust_remote_code, revision,
-                               code_revision, config_format)
+        hf_config = get_config(self.hf_config_path or self.model,
+                               trust_remote_code, revision, code_revision,
+                               config_format)
 
         if hf_overrides_kw:
             logger.info("Overriding HF config with %s", hf_overrides_kw)
@@ -947,7 +950,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":
     def try_get_generation_config(self) -> Dict[str, Any]:
         if self.generation_config is None or self.generation_config == "auto":
             config = try_get_generation_config(
-                self.model,
+                self.hf_config_path or self.model,
                 trust_remote_code=self.trust_remote_code,
                 revision=self.revision,
             )

vllm/engine/arg_utils.py

Lines changed: 8 additions & 0 deletions

@@ -93,6 +93,7 @@ class EngineArgs:
     model: str = 'facebook/opt-125m'
     served_model_name: Optional[Union[str, List[str]]] = None
     tokenizer: Optional[str] = None
+    hf_config_path: Optional[str] = None
     task: TaskOption = "auto"
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
@@ -262,6 +263,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.tokenizer,
             help='Name or path of the huggingface tokenizer to use. '
             'If unspecified, model name or path will be used.')
+        parser.add_argument(
+            "--hf-config-path",
+            type=nullable_str,
+            default=EngineArgs.hf_config_path,
+            help='Name or path of the huggingface config to use. '
+            'If unspecified, model name or path will be used.')
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
@@ -1076,6 +1083,7 @@ def create_model_config(self) -> ModelConfig:
 
         return ModelConfig(
             model=self.model,
+            hf_config_path=self.hf_config_path,
             task=self.task,
             # We know this is not None because we set it in __post_init__
             tokenizer=cast(str, self.tokenizer),
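
End to end, the new field travels from the CLI (or `EngineArgs`) into `ModelConfig`, which then resolves the Hugging Face config from that path instead of the GGUF checkpoint. A hedged sketch (placeholder paths; calling `create_model_config` actually fetches the config, so the referenced files must be reachable):

```python
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    hf_config_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)
model_config = args.create_model_config()  # loads the HF config from hf_config_path
print(model_config.hf_config_path)
```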

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 21 additions & 1 deletion

@@ -5,6 +5,7 @@
 from typing import Callable, List, Optional, Tuple
 
 import torch
+from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
 from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -514,7 +515,12 @@ def weight_loader(self, param: torch.nn.Parameter,
         # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
-        expert_data = param.data[expert_id]
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.weight_type = loaded_weight.item()
+            param.data.copy_(loaded_weight)
+            return
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -524,6 +530,20 @@ def weight_loader(self, param: torch.nn.Parameter,
         if is_transposed:
             shard_dim = int(not shard_dim)
 
+        full_load = len(loaded_weight.shape) == 3
+        if full_load:
+            shard_dim += 1
+
+        # Materialize GGUF UninitializedParameter
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            final_shape = list(loaded_weight.shape)
+            if shard_id in ["w1", "w3"]:
+                final_shape[1] *= 2
+            final_shape[shard_dim] = final_shape[
+                shard_dim] // get_tensor_model_parallel_world_size()
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
+
+        expert_data = param.data if full_load else param.data[expert_id]
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
             # this is needed for compressed-tensors only
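
For readers unfamiliar with lazily shaped parameters, here is a self-contained sketch of the materialization pattern used above. The shapes, shard dimension, and tensor-parallel world size are invented for illustration, and `get_tensor_model_parallel_world_size()` is replaced by a constant:

```python
import torch
from torch.nn.parameter import UninitializedParameter

# GGUF expert weights arrive as one merged 3-D tensor (num_experts, out, in),
# so the parameter starts without a shape and takes it from the loaded weight.
param = UninitializedParameter(requires_grad=False)
loaded_weight = torch.empty(8, 256, 1024, dtype=torch.uint8)  # fake quantized blob

tp_world_size = 2   # stand-in for get_tensor_model_parallel_world_size()
shard_dim = 1       # shard the output dimension across tensor-parallel ranks
final_shape = list(loaded_weight.shape)
final_shape[shard_dim] //= tp_world_size
param.materialize(final_shape, dtype=loaded_weight.dtype)
print(param.shape)  # torch.Size([8, 128, 1024])
```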

vllm/model_executor/layers/linear.py

Lines changed: 14 additions & 1 deletion

@@ -235,10 +235,23 @@ def __init__(self,
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         # If the weight on disk does not have a shape, give it one
         # (such as scales for AutoFp8).
+        # Special case for GGUF
+
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.weight_type = loaded_weight.item()
+
+        # Materialize GGUF UninitializedParameter
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param.size() == loaded_weight.size()
+        assert param.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()} "
+            f"to a parameter of size {param.size()}")
         param.data.copy_(loaded_weight)
 
     def forward(self,

vllm/model_executor/layers/quantization/gguf.py

Lines changed: 125 additions & 2 deletions

@@ -1,13 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import gguf
 import torch
 from gguf import GGMLQuantizationType as WeightType
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -29,7 +32,7 @@ def get_name(self) -> str:
         return "gguf"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
-        return [torch.half, torch.bfloat16]
+        return [torch.half]
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -49,6 +52,8 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
             return GGUFEmbeddingMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return GGUFMoEMethod(self)
         return None
 
 
@@ -184,6 +189,124 @@ def apply(self,
         return out
 
 
+class GGUFMoEMethod(FusedMoEMethodBase):
+    """MoE method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
+                        hidden_size)
+        # gate up proj
+        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w13_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            })
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        layer.register_parameter("w13_qweight", w13_qweight)
+
+        w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                     requires_grad=False)
+        set_weight_attrs(w13_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w13_qweight_type", w13_qweight_type)
+
+        tensor_shape = (num_experts, intermediate_size_per_partition,
+                        hidden_size)
+        # gate down proj
+        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w2_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            })
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        layer.register_parameter("w2_qweight", w2_qweight)
+
+        w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                    requires_grad=False)
+        set_weight_attrs(w2_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+
+        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w2_qweight_type", w2_qweight_type)
+        self.act = SiluAndMul()
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ):
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+        final_hidden_states = torch.empty_like(x)
+        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
+            inp = x[tok].reshape((1, ) + x.shape[1:])
+            current_hidden_state = None
+            for ww, ii in zip(w, idx):
+                expert_up = layer.w13_qweight[ii]
+
+                out = _fuse_mul_mat(inp, expert_up,
                                     layer.w13_qweight_type.weight_type)
+                out = self.act(out)
+
+                expert_down = layer.w2_qweight[ii]
+                current_state = _fuse_mul_mat(
+                    out, expert_down,
+                    layer.w2_qweight_type.weight_type).mul_(ww)
+                if current_hidden_state is None:
+                    current_hidden_state = current_state
+                else:
+                    current_hidden_state.add_(current_state)
+            final_hidden_states[tok] = current_hidden_state
+        return final_hidden_states
+
+
 class GGUFEmbeddingMethod(GGUFLinearMethod):
     """Embedding method for GGUF.
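
To make the per-token loop in GGUFMoEMethod.apply easier to follow, here is a dense (unquantized) reference of the same computation. It is an illustrative sketch only: the quantized `_fuse_mul_mat` kernel is replaced by plain matmuls, and the weight shapes ((num_experts, 2 * intermediate, hidden) for w13 and (num_experts, hidden, intermediate) for w2) are assumptions made for the example:

```python
import torch
import torch.nn.functional as F


def moe_apply_reference(x, w13, w2, topk_weights, topk_ids):
    """Dense reference of GGUFMoEMethod.apply's per-token expert loop.

    x: (num_tokens, hidden); w13: (num_experts, 2 * intermediate, hidden);
    w2: (num_experts, hidden, intermediate). _fuse_mul_mat becomes a matmul.
    """
    out = torch.empty_like(x)
    for tok in range(x.shape[0]):
        inp = x[tok]                               # (hidden,)
        acc = torch.zeros_like(inp)
        for w, e in zip(topk_weights[tok], topk_ids[tok]):
            h = w13[e] @ inp                       # fused gate+up projection
            gate, up = h.chunk(2, dim=-1)
            h = F.silu(gate) * up                  # SiluAndMul
            acc += w * (w2[e] @ h)                 # down projection, routing-weighted
        out[tok] = acc
    return out


# Tiny smoke test with random weights.
n_exp, hidden, inter, n_tok, k = 4, 8, 16, 3, 2
x = torch.randn(n_tok, hidden)
w13 = torch.randn(n_exp, 2 * inter, hidden)
w2 = torch.randn(n_exp, hidden, inter)
topk_weights, topk_ids = torch.topk(torch.softmax(torch.randn(n_tok, n_exp), dim=-1), k)
print(moe_apply_reference(x, w13, w2, topk_weights, topk_ids).shape)  # torch.Size([3, 8])
```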

vllm/model_executor/model_loader/loader.py

Lines changed: 17 additions & 2 deletions

@@ -1245,9 +1245,24 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
         """
         config = model_config.hf_config
         model_type = config.model_type
+        gguf_to_hf_name_map = {}
         # hack: ggufs have a different name than transformers
         if model_type == "cohere":
             model_type = "command-r"
+        if model_type in ("deepseek_v3", "deepseek_v2"):
+            model_type = "deepseek2"
+            # GGUF layer map assumes that we will have a merged expert weights
+            # so we need to map them manually
+            for idx in range(config.num_hidden_layers):
+                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
+                    f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
+                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
+                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
+                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
+                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
+                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
+                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
+
         arch = None
         for key, value in gguf.MODEL_ARCH_NAMES.items():
             if value == model_type:
@@ -1258,10 +1273,10 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
         num_layers = config.num_hidden_layers
         name_map = gguf.get_tensor_name_map(arch, num_layers)
         with torch.device("meta"):
-            dummy_model = AutoModelForCausalLM.from_config(config)
+            dummy_model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=model_config.trust_remote_code)
         state_dict = dummy_model.state_dict()
 
-        gguf_to_hf_name_map = {}
         for hf_name in state_dict:
             name, suffix = hf_name.rsplit(".", 1)
             gguf_name = name_map.get_name(name)
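
To see why this manual mapping is needed, one can list the merged expert tensors that a DeepSeek GGUF file actually stores using the gguf reader (a sketch; the checkpoint path is a placeholder and the printed names show the expected pattern, not captured output):

```python
import gguf

reader = gguf.GGUFReader("./DeepSeek-V2-Lite-Chat.Q4_K_M.gguf")  # placeholder path
expert_tensors = sorted(t.name for t in reader.tensors if "_exps" in t.name)
print(expert_tensors[:3])
# Expected pattern: ['blk.0.ffn_down_exps.weight', 'blk.0.ffn_gate_exps.weight',
#                    'blk.0.ffn_up_exps.weight']
```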

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 0 additions & 1 deletion

@@ -496,7 +496,6 @@ def gguf_quant_weights_iterator(
         weight = tensor.data
         weight_type = tensor.tensor_type
         name = gguf_to_hf_name_map[tensor.name]
-
         if weight_type.name != "F32":
             name = name.replace("weight", "qweight")
         param = torch.tensor(weight)
