Commit e3e1adc

lulmer authored and mgoin committed
Expand MLA to support most types of quantization (vllm-project#13181)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 767c34e commit e3e1adc

3 files changed, +61 -132 lines changed

vllm/attention/backends/mla/utils.py

+26 -45
@@ -26,7 +26,7 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    scaled_dequantize, scaled_quantize)
+    scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
@@ -220,16 +220,6 @@ def _q_proj_and_k_up_proj(self, x):
                 .view(-1, self.num_heads, self.kv_lora_rank)
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-
-        def is_layer_fp8(layer: LinearBase) -> bool:
-            return isinstance(layer.quant_method, Fp8LinearMethod) or\
-                (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
-                and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
-
-        def quantization_scheme_supported(layer: LinearBase) -> bool:
-            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
-                is_layer_fp8(layer)
-
         # TODO(lucas) This is very gross, we need a more wide scale refactor of
         # all the FP8 code with a more standard way of
         # defining schemes/group-shapes, we should also potentially force
@@ -239,7 +229,7 @@ def quantization_scheme_supported(layer: LinearBase) -> bool:
         def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
             Tuple[Tuple[int, int], Tuple[int, int]]:
             if isinstance(layer.quant_method, Fp8LinearMethod):
-                if layer.quant_method.block_quant is not None:
+                if layer.quant_method.block_quant:
                     weight_block_size = \
                         layer.quant_method.quant_config.weight_block_size
                     # per-token-group (1, X), block-quantized (X, Y)
@@ -267,41 +257,32 @@ def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
                 f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
             )
 
-        def get_scales(layer: LinearBase) -> torch.Tensor:
-            if hasattr(layer, "weight_scale_inv"):
-                return layer.weight_scale_inv
-            return layer.weight_scale
-
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if is_layer_fp8(layer):
-                if isinstance(layer.quant_method, \
-                    CompressedTensorsLinearMethod) and \
-                    isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
-                    # NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
-                    # seems to store weights as (input, output) instead of
-                    # (output, input) so we need to transpose
-                    weight = layer.weight.T # standardize to (output, input)
-                else:
-                    weight = layer.weight
-                _, weight_scale_group_shape = \
-                    get_scale_group_shapes_for_fp8(layer)
-                scales = get_scales(layer)
-
-                return scaled_dequantize(weight, scales,
-                                         weight_scale_group_shape)
-            else:
+        def get_layer_weight(layer):
+            if hasattr(layer, "weight"):
                 return layer.weight
+            elif hasattr(layer, "qweight"):
+                return layer.qweight
+            else:
+                raise AttributeError(
+                    f"Layer '{layer}' has neither weight nor qweight")
 
-        if not (quantization_scheme_supported(self.kv_b_proj) and\
-            quantization_scheme_supported(self.q_proj) and\
-            quantization_scheme_supported(self.o_proj)):
-            raise NotImplementedError(
-                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
-                ", please run with VLLM_MLA_DISABLE=1")
-
-        weight_dtype = self.kv_b_proj.weight.dtype
-        assert self.o_proj.weight.dtype == weight_dtype
-        assert self.q_proj.weight.dtype == weight_dtype
+        def get_and_maybe_dequant_weights(layer: LinearBase):
+            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+                # NOTE: This should only be used offline, since it's O(N^3)
+                eye = torch.eye(layer.input_size_per_partition,
+                                dtype=act_dtype,
+                                device=get_layer_weight(layer).device)
+                dequant_weights = layer.quant_method.apply(layer,
+                                                           eye,
+                                                           bias=None)
+                del eye
+                # standardize to (output, input)
+                return dequant_weights.T
+            return layer.weight
+
+        weight_dtype = get_layer_weight(self.kv_b_proj).dtype
+        assert get_layer_weight(self.o_proj).dtype == weight_dtype
+        assert get_layer_weight(self.q_proj).dtype == weight_dtype
 
         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (

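The core change in utils.py: instead of FP8-specific dequantization (scaled_dequantize driven by get_scales and the FP8 scale group shapes), get_and_maybe_dequant_weights now recovers a dense weight matrix for any quantization scheme by running the layer's own quant_method.apply on an identity matrix. A toy, self-contained sketch of that identity trick follows; ToyInt8Linear is made up for illustration and is not vLLM code.

import torch

class ToyInt8Linear:
    """Stand-in for a quantized linear layer: int8 weight plus a per-row scale."""

    def __init__(self, out_features: int, in_features: int):
        w = torch.randn(out_features, in_features)
        self.scale = w.abs().amax(dim=1, keepdim=True) / 127.0
        self.qweight = torch.round(w / self.scale).to(torch.int8)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        # Analogue of quant_method.apply(layer, x): computes x @ W_dequant.T
        w_dequant = self.qweight.float() * self.scale
        return x @ w_dequant.T

layer = ToyInt8Linear(out_features=8, in_features=16)
eye = torch.eye(16)               # identity over the input dimension
dense_w = layer.apply(eye).T      # eye @ W.T == W.T; transpose to (output, input)
assert dense_w.shape == (8, 16)

As the NOTE in the new code says, this costs O(N^3) in the layer's input size, so it is only intended for the one-off process_weights_after_loading step, never a runtime path.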
vllm/config.py

+1 -31
@@ -991,37 +991,7 @@ def is_cross_encoder(self) -> bool:
 
     @property
     def use_mla(self) -> bool:
-        if not self.is_deepseek_mla or envs.VLLM_MLA_DISABLE:
-            return False
-
-        if self.quantization is not None and self.quantization not in [\
-                "fp8", "compressed-tensors"]:
-            logger.warning(
-                "MLA is not supported with %s quantization. "
-                "Disabling MLA.", self.quantization)
-            return False
-
-        # If using a "compressed-tensors" checkpoint, check that all groups
-        # have fp8 for both weights and activations.
-        if self.quantization == "compressed-tensors":
-            quant_config = self._parse_quant_hf_config()
-            for group_name, cfg in quant_config.get("config_groups", {
-                    "": {}
-            }).items():
-                act_cfg = cfg.get("input_activations", {})
-                act_type = None if act_cfg is None else act_cfg.get("type", "")
-                w_cfg = cfg.get("weights", {})
-                w_type = None if w_cfg is None else w_cfg.get("type", "")
-                if act_type != "fp8" or w_type != "fp8":
-                    logger.warning(
-                        "compressed-tensors MLA support requires fp8 "
-                        "activations and weights in group '%s', but got "
-                        "activations type '%s' and weights type '%s'.\n "
-                        "Full config: %s", group_name, act_type, w_type,
-                        quant_config)
-                    return False
-
-        return True
+        return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
 
     @property
     def supported_runner_types(self) -> Set[RunnerType]:

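With the quantization allow-list removed, use_mla now depends only on the model architecture and the VLLM_MLA_DISABLE environment variable, which the removed warnings (and the error message in utils.py) already point to as the escape hatch. A minimal sketch of opting back out if a particular quantized checkpoint misbehaves; the model name is only an example:

import os

# Must be set before vLLM reads its environment configuration.
os.environ["VLLM_MLA_DISABLE"] = "1"

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")  # example DeepSeek MLA model

Equivalently, export VLLM_MLA_DISABLE=1 in the shell before launching the server.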
vllm/model_executor/model_loader/loader.py

+34 -56
@@ -153,6 +153,30 @@ def _initialize_model(
     return model_class(**kwargs)
 
 
+def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
+                                   target_device: torch.device) -> None:
+    for _, module in model.named_modules():
+        quant_method = getattr(module, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens after other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+            hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
 
@@ -376,7 +400,6 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -394,23 +417,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if isinstance(quant_method, QuantizeMethodBase):
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    # TODO(lucas): see if there is a way to unify the signatures
-                    # of process_weights_after_loading
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
+
         return model.eval()
 
 
@@ -429,29 +437,15 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(
-                            module, torch.device(device_config.device)):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()
 
 
@@ -632,6 +626,7 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
        from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
@@ -640,18 +635,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                                           model_config.revision)
 
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
-                    if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                        # When attention modules need to process weights after
-                        # currently only used by MLA
-                        module.process_weights_after_loading(
-                            model_config.dtype)
+                _process_weights_after_loading(model, model_config,
+                                               target_device)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
@@ -1401,16 +1388,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                 self._get_weights_iterator(model_weights,
                                            model_config.revision))
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()
 

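The four load_model implementations touched here now end with the same shared _process_weights_after_loading call, which also makes the ordering constraint explicit: quant methods repack or re-quantize first, and Attention modules (currently only MLA) run afterwards so they can decompress the already-finalized projection weights. A simplified, illustrative restatement of that two-pass structure (it omits the cpu-offload device_loading_context handling and is not a drop-in for the real helper):

import torch
from torch import nn
from vllm.attention import Attention  # assumes a vLLM installation

def finalize_weights_after_loading(model: nn.Module,
                                   act_dtype: torch.dtype) -> None:
    # Pass 1: every quant method repacks / re-quantizes its module's weights.
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)

    # Pass 2: attention modules (MLA) run afterwards, so they can dequantize
    # the finalized weights, e.g. via the identity-matrix trick shown above.
    for module in model.modules():
        if isinstance(module, Attention) and hasattr(
                module, "process_weights_after_loading"):
            module.process_weights_after_loading(act_dtype)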