Use RMSNorm in TransformersModel #12776

Closed
wants to merge 5 commits into from
87 changes: 78 additions & 9 deletions vllm/model_executor/models/transformers.py
@@ -15,7 +15,8 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from typing import Iterable, Optional, Union
from math import prod
from typing import Iterable, Literal, Optional, Union

import torch
from torch import nn
@@ -27,6 +28,7 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -72,15 +74,64 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


def replace_rms_norm_class(rms_norm: nn.Module) -> RMSNorm:
"""
Replace Transformers RMS norm class with vLLM's RMSNorm class.

Args:
rms_norm (nn.Module): RMS norm module to be replaced.

Returns:
RMSNorm: The new RMSNorm.
"""
# Get hidden size
parameters = dict(rms_norm.named_parameters())
if len(parameters) != 1:
class_name = rms_norm.__class__.__name__
logger.warning(
"Unable to determine `hidden_size` of %s. "
"This layer will not benefit from vLLM's custom ops.", class_name)
return rms_norm
weight = next(iter(parameters.values()))
if weight is None and isinstance(rms_norm, nn.RMSNorm):
hidden_size = prod(rms_norm.normalized_shape)
else:
hidden_size = weight.numel()

# Get eps
attrs = vars(rms_norm)
condition = lambda k, v: not k.startswith("_") and isinstance(v, float)
attrs = {k: v for k, v in attrs.items() if condition(k, v)}
if len(attrs) != 1:
class_name = rms_norm.__class__.__name__
logger.warning(
"Unable to determine `eps` of %s. "
"This layer will not benefit from vLLM's custom ops.", class_name)
return rms_norm
eps = next(iter(attrs.values()))

return RMSNorm(
hidden_size=hidden_size,
eps=eps,
)


def replace_linear_class(
linear: nn.Linear,
style: str,
style: Literal["colwise", "rowwise"],
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we use it to translate nn.Linear into vllm-style tp Linear.
Replace nn.Linear with one of vLLM's tensor parallel linear classes.

`quant_config` is not yet supported.

Quant config is not supported yet
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear.

Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
"""

if not isinstance(style, str):
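
Editorial aside on the new `replace_rms_norm_class` helper above: it infers `hidden_size` from the module's single parameter (or, for a weightless `nn.RMSNorm`, from `prod(normalized_shape)`) and `eps` from the module's single non-underscore float attribute, falling back to the original module with a warning when either cannot be determined. A minimal standalone sketch of that inference, using a hypothetical Transformers-style norm class rather than anything taken from this diff:

import torch
from torch import nn

class ToyRMSNorm(nn.Module):
    """Hypothetical HF-style RMS norm: one weight parameter, one float eps attribute."""
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

norm = ToyRMSNorm(4096, eps=1e-5)

# The same two lookups the helper performs:
params = dict(norm.named_parameters())                # exactly one entry, so it is usable
hidden_size = next(iter(params.values())).numel()     # 4096
eps = [v for k, v in vars(norm).items()
       if not k.startswith("_") and isinstance(v, float)][0]  # 1e-5
# replace_rms_norm_class(norm) would therefore return RMSNorm(hidden_size=4096, eps=1e-5).
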
@@ -93,7 +144,10 @@ def replace_linear_class(
}.get(style)

if vllm_linear_cls is None:
raise ValueError(f"Unsupported parallel style value: {style}")
logger.warning(
"Unsupported parallel style value: %s. "
"This layer will not be tensor parallelized.", style)
return linear

class HFCompatibleLinear(vllm_linear_cls):
"""
@@ -137,7 +191,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
prefix = self.model.base_model_prefix

# MLP modifications
self.tensor_parallelize(self.model)
self.apply_base_model_tp_plan(self.model)

# Attention modifications (assumes 1 attention op per hidden layer)
tp_size = get_tensor_model_parallel_world_size()
@@ -155,6 +209,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
]

# Model modifications
self.replace_rms_norm_class(self.model)
self.replace_vocab_embed_class(self.model)

# ForCausalLM modifications
@@ -174,7 +229,21 @@ def log_replacement(self, name: str, old_module: nn.Module,
new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)

def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
def replace_rms_norm_class(self, module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
if "RMSNorm" in child_module.__class__.__name__:
new_module = replace_rms_norm_class(child_module)
setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module)
else:
self.replace_rms_norm_class(child_module, prefix=qual_name)

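The method above applies a recursive swap: walk `named_children`, replace any child whose class name contains "RMSNorm", and recurse into everything else. A generic, self-contained sketch of that pattern (illustrative only, not part of this diff):

from torch import nn

def replace_matching_children(module: nn.Module, match: str, factory) -> None:
    """Replace every child whose class name contains `match` with factory(child);
    recurse into children that do not match."""
    for name, child in module.named_children():
        if match in child.__class__.__name__:
            setattr(module, name, factory(child))
        else:
            replace_matching_children(child, match, factory)

# Toy usage: swap every *RMSNorm for a LayerNorm of the same width.
# replace_matching_children(model, "RMSNorm", lambda m: nn.LayerNorm(m.weight.numel()))
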
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
"""
Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
"""
if (self.config.base_model_tp_plan is None
and self.vllm_config.parallel_config.tensor_parallel_size > 1):
raise ValueError(
@@ -191,7 +260,7 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module)
else:
self.tensor_parallelize(child_module, prefix=qual_name)
self.apply_base_model_tp_plan(child_module, prefix=qual_name)

def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
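
For reference, `apply_base_model_tp_plan` recurses the same way but keys its replacements off `config.base_model_tp_plan`, whose wildcard keys map qualified module names to parallel styles; matching linear children are swapped via `replace_linear_class`. A rough sketch of such a plan and of how one name resolves (the plan below is hypothetical, merely in the shape Transformers configs expose):

import re

base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

qual_name = "layers.0.mlp.down_proj"
style = next((s for pattern, s in base_model_tp_plan.items()
              if re.match(pattern, qual_name)), None)
print(style)  # "rowwise"; None would mean the layer is left unparallelized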