
[MODEL] Update LoRA modules supported by Jamba #11209

Merged: 11 commits, Dec 27, 2024
36 changes: 26 additions & 10 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -8,7 +8,6 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
@@ -42,12 +41,14 @@ def __init__(self,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu"):
activation="silu",
is_lora_enabled: bool = False):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
self.use_rms_norm = use_rms_norm
self.activation = activation
self.is_lora_enabled = is_lora_enabled

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
@@ -60,9 +61,13 @@ def __init__(self,
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

self.in_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=use_bias)
self.in_proj_lin = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=use_bias)
self.in_proj_gate = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=use_bias)

Collaborator: QQ: Why modify this?

Contributor Author: The LoRA kernels did not work for us with MergedColumnParallelLinear, so I split this param into two ColumnParallelLinear layers, which are supported. This modification also required changes to our weight loading methods.

Collaborator: This layer has 2 LoRAs, and the slicing logic for each LoRA is different from llama's, right?

Contributor Author: Yes, the reason is that for llama the gate_up_proj is a stacked param, while in Jamba/Mamba in_proj is not.

Collaborator: Just add the mapping for in_proj as shown below; there's no need to modify the loading logic:

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"in_proj": ["in_proj"]
}
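The split discussed above is safe because projecting once with the merged [2 * intermediate_size, hidden_size] weight and chunking the result is numerically identical to projecting twice with the two halves of that weight. A minimal sketch of the equivalence, with plain nn.Linear standing in for vLLM's ColumnParallelLinear and arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, intermediate_size = 16, 32

# Merged variant: one [2 * intermediate, hidden] weight, chunked after the matmul.
merged = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)

# Split variant: two independent projections, as in the new in_proj_lin / in_proj_gate.
lin = nn.Linear(hidden_size, intermediate_size, bias=False)
gate = nn.Linear(hidden_size, intermediate_size, bias=False)

# Copy the two halves of the merged weight, mirroring the load_weights change in this PR.
with torch.no_grad():
    lin.weight.copy_(merged.weight[:intermediate_size, :])
    gate.weight.copy_(merged.weight[intermediate_size:, :])

x = torch.randn(4, hidden_size)
hidden_merged, gate_merged = merged(x).chunk(2, dim=-1)
assert torch.allclose(hidden_merged, lin(x))
assert torch.allclose(gate_merged, gate(x))
```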

# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
intermediate_size,
@@ -134,8 +139,8 @@ def forward_cuda(self, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams):

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
gate = self.in_proj_gate(hidden_states)[0].transpose(-2, -1)
hidden_states = self.in_proj_lin(hidden_states)[0].transpose(-2, -1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
@@ -170,7 +175,13 @@ def forward_cuda(self, hidden_states: torch.Tensor,

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

if self.is_lora_enabled:
# lora kernel requires contiguous tensor
ssm_parameters = self.x_proj(
hidden_states.transpose(-2, -1).contiguous())[0]
else:
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

time_step, B, C = torch.split(
ssm_parameters,
@@ -222,6 +233,11 @@ def forward_cuda(self, hidden_states: torch.Tensor,
scan_outputs = scan_outputs.transpose(0, 1)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1).contiguous())[0]
else:
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states
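The is_lora_enabled branches above exist because transpose() in PyTorch only swaps strides, so the resulting view is generally not contiguous in memory, while the LoRA kernels expect contiguous inputs. A standalone illustration (not vLLM code):

```python
import torch

x = torch.randn(8, 4, 16)               # contiguous by construction
y = x.transpose(-2, -1)                  # same storage, swapped strides

print(x.is_contiguous())                 # True
print(y.is_contiguous())                 # False -> why .contiguous() is called when LoRA is enabled
print(y.contiguous().is_contiguous())    # True, at the cost of one copy
```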
86 changes: 57 additions & 29 deletions vllm/model_executor/models/jamba.py
@@ -105,9 +105,11 @@ def __init__(self,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
is_lora_enabled: Optional[bool] = False,
**kwargs) -> None:
super().__init__()
self.config = config
self.is_lora_enabled = is_lora_enabled
self.mamba = MambaMixer(hidden_size= config.hidden_size,
ssm_state_size = config.mamba_d_state,
conv_kernel_size = config.mamba_d_conv,
@@ -118,7 +120,9 @@ def __init__(self,
use_bias = config.mamba_proj_bias,
use_rms_norm=True,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act)
activation=config.hidden_act,
is_lora_enabled = self.is_lora_enabled
)

num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -154,14 +158,13 @@ def forward(

class JambaAttentionDecoderLayer(nn.Module):

def __init__(
self,
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
config: JambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -285,17 +288,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
org_num_embeddings=config.vocab_size,
)

extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}

def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[
config.layers_block_type[layer_idx]]
return layer_class(
config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
)
return layer_class(config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
**extra_kwargs)

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
@@ -373,10 +377,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
"qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
"down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
@@ -421,9 +423,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
@@ -444,7 +446,6 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
@@ -579,11 +580,38 @@ def load_weights(self, weights: Iterable[Tuple[str,
if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if "in_proj" in name:
# To support LoRA, the in_proj weight is split into two
# separate tensors, in_proj_lin and in_proj_gate,
# which are loaded manually here.
name_lin = name.replace("in_proj", "in_proj_lin")
name_gate = name.replace("in_proj", "in_proj_gate")

# need to split the loaded weight of in_proj
param = params_dict[name_lin]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

weight_loader(param,
loaded_weight[:loaded_weight.shape[0] //
2, :]) # the lin split

loaded_params.add(name_lin)
param = params_dict[name_gate]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

weight_loader(param,
loaded_weight[loaded_weight.shape[0] //
2:, :])  # the gate split
loaded_params.add(name_gate)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if "in_proj" not in name:
loaded_params.add(name)
return loaded_params
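With the expanded supported_lora_modules and the in_proj handling above, serving a Jamba LoRA adapter should follow the usual vLLM flow. A hedged usage sketch; the model name, adapter path, rank, and sampling settings are placeholders rather than values from this PR:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder checkpoint and adapter; substitute a real Jamba model and a LoRA trained on it.
llm = LLM(model="ai21labs/AI21-Jamba-1.5-Mini",
          enable_lora=True,
          max_lora_rank=64)

outputs = llm.generate(
    ["Explain state space models in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("jamba-adapter", 1, "/path/to/jamba-lora"),
)
print(outputs[0].outputs[0].text)
```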


49 changes: 40 additions & 9 deletions vllm/model_executor/models/mamba.py
@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
is_lora_enabled: Optional[bool] = False) -> None:
Collaborator: Maybe using vllm_config: VllmConfig would be better, rather than adding another arg.
super().__init__()
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.is_lora_enabled = is_lora_enabled
Collaborator: Will you implement LoRA support for MambaForCausalLM in this PR?

Contributor Author: Since I had to make changes in MambaMixer, which is used by both Jamba and Mamba, I made the required changes for Mamba as well.

Collaborator (@jeejeelee, Dec 27, 2024): @DarkLight1337 I think this PR just wants to update the Jamba LoRA modules.
mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
self.mixer = MambaMixer(hidden_size=config.hidden_size,
ssm_state_size=config.state_size,
@@ -53,7 +55,8 @@ def __init__(self,
use_rms_norm=self.is_falcon_mamba,
rms_norm_has_weight=not self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act)
activation=config.hidden_act,
is_lora_enabled=self.is_lora_enabled)

self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
is_lora_enabled = bool(lora_config)

self.config = config
self.padding_idx = config.pad_token_id
@@ -101,8 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MambaDecoderLayer(
config, cache_config=cache_config, quant_config=quant_config),
lambda prefix: MambaDecoderLayer(config,
cache_config=cache_config,
quant_config=quant_config,
is_lora_enabled=is_lora_enabled),
prefix=f"{prefix}.layers")

self.norm_f = RMSNorm(config.hidden_size,
@@ -288,9 +294,34 @@ def load_weights(self, weights: Iterable[Tuple[str,
if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if "in_proj" in name:
# To support LoRA, the in_proj weight is split into two
# separate tensors, in_proj_lin and in_proj_gate,
# which are loaded manually here.
name_lin = name.replace("in_proj", "in_proj_lin")
name_gate = name.replace("in_proj", "in_proj_gate")

# need to split the loaded weight of in_proj
param = params_dict[name_lin]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

weight_loader(param, loaded_weight[:loaded_weight.shape[0] //
2, :]) # the lin split

loaded_params.add(name_lin)
param = params_dict[name_gate]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

weight_loader(param, loaded_weight[loaded_weight.shape[0] //
2:, :])  # the gate split
loaded_params.add(name_gate)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if "in_proj" not in name:
loaded_params.add(name)
return loaded_params
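Since mamba.py now applies the same in_proj split as jamba.py, one quick sanity check is that the two new tensors are exactly the halves of the original in_proj weight. A small sketch under the assumption that both the original and the split state dicts are at hand (a hypothetical helper, not part of this PR):

```python
import torch

def check_in_proj_split(orig_sd: dict, split_sd: dict) -> None:
    """Verify in_proj_lin / in_proj_gate are the two halves of each original in_proj weight."""
    for name, w in orig_sd.items():
        if "in_proj" not in name or not name.endswith(".weight"):
            continue
        half = w.shape[0] // 2
        lin = split_sd[name.replace("in_proj", "in_proj_lin")]
        gate = split_sd[name.replace("in_proj", "in_proj_gate")]
        assert torch.equal(lin, w[:half, :]), f"lin half mismatch for {name}"
        assert torch.equal(gate, w[half:, :]), f"gate half mismatch for {name}"
```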