[MODEL] Update LoRA modules supported by Jamba #11209
```diff
@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: MambaConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 is_lora_enabled: Optional[bool] = False) -> None:
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
+        self.is_lora_enabled = is_lora_enabled
         mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
         self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
```

Review comment (on the new `is_lora_enabled` parameter): Maybe using …

Review comment (on `self.is_lora_enabled = is_lora_enabled`): Will you implement LoRA support for Mamba as well?

Reply: Since I had to make changes in `MambaMixer`, which is used by both Jamba and Mamba, I made the required changes for Mamba as well.

Reply: @DarkLight1337 I think this PR just wants to update the Jamba LoRA modules.
```diff
@@ -53,7 +55,8 @@ def __init__(self,
                                 use_rms_norm=self.is_falcon_mamba,
                                 rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled=self.is_lora_enabled)

         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

```
```diff
@@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)

         self.config = config
         self.padding_idx = config.pad_token_id
```
```diff
@@ -101,8 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: MambaDecoderLayer(
-                config, cache_config=cache_config, quant_config=quant_config),
+            lambda prefix: MambaDecoderLayer(config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config,
+                                             is_lora_enabled=is_lora_enabled),
             prefix=f"{prefix}.layers")

         self.norm_f = RMSNorm(config.hidden_size,
```
```diff
@@ -288,9 +294,34 @@ def load_weights(self, weights: Iterable[Tuple[str,
             if is_pp_missing_parameter(name, self):
                 continue

-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
+            if "in_proj" in name:
+                # To support LoRA, the fused in_proj weight needs to be split
+                # into two separate tensors (in_proj_lin and in_proj_gate),
+                # which we load manually here.
+                name_lin = name.replace("in_proj", "in_proj_lin")
+                name_gate = name.replace("in_proj", "in_proj_gate")
+
+                # split the loaded in_proj weight in half along dim 0
+                param = params_dict[name_lin]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+
+                weight_loader(param, loaded_weight[:loaded_weight.shape[0] //
+                                                   2, :])  # the lin split
+
+                loaded_params.add(name_lin)
+                param = params_dict[name_gate]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+
+                weight_loader(param, loaded_weight[loaded_weight.shape[0] //
+                                                   2:, :])  # the gate split
+                loaded_params.add(name_gate)
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            if "in_proj" not in name:
+                loaded_params.add(name)
         return loaded_params
```
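For reference, here is a minimal standalone sketch of the split performed in the new `load_weights` branch above, written in plain PyTorch. The tensor sizes and variable names (`hidden_size`, `intermediate_size`, `in_proj_lin_weight`, `in_proj_gate_weight`) are illustrative assumptions, not code from the PR; the real shapes come from the model config.

```python
import torch

hidden_size, intermediate_size = 8, 16

# The checkpoint stores in_proj as one fused tensor whose output dimension
# holds both halves: [2 * intermediate_size, hidden_size].
in_proj_weight = torch.randn(2 * intermediate_size, hidden_size)

half = in_proj_weight.shape[0] // 2
in_proj_lin_weight = in_proj_weight[:half, :]   # first half  -> in_proj_lin
in_proj_gate_weight = in_proj_weight[half:, :]  # second half -> in_proj_gate

assert in_proj_lin_weight.shape == (intermediate_size, hidden_size)
assert in_proj_gate_weight.shape == (intermediate_size, hidden_size)
```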
Review comment (on the `load_weights` changes): QQ: Why modify this?
Reply: The LoRA kernels did not work for us with `MergedColumnParallelLinear`, so I split this param into two `ColumnParallelLinear` layers, which are supported. This modification also required changes to our weight-loading methods.
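A hedged, plain-PyTorch sketch of the layout change this reply describes. It uses `nn.Linear` in place of vLLM's parallel linear classes, and the class and variable names are made up for illustration; it is not the PR's actual code.

```python
import torch
import torch.nn as nn


class InProjSketch(nn.Module):
    """One fused in_proj vs. two separate projections (the LoRA-friendly form)."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 is_lora_enabled: bool = False) -> None:
        super().__init__()
        self.is_lora_enabled = is_lora_enabled
        if is_lora_enabled:
            # Two separate layers -> two independently targetable LoRA modules.
            self.in_proj_lin = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.in_proj_gate = nn.Linear(hidden_size, intermediate_size, bias=False)
        else:
            # Original layout: one fused projection producing both halves.
            self.in_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)

    def forward(self, x: torch.Tensor):
        if self.is_lora_enabled:
            return self.in_proj_lin(x), self.in_proj_gate(x)
        return self.in_proj(x).chunk(2, dim=-1)


x = torch.randn(2, 8)
h, g = InProjSketch(hidden_size=8, intermediate_size=16, is_lora_enabled=True)(x)
assert h.shape == g.shape == (2, 16)
```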
Review comment: This layer has two LoRAs, and the slicing logic for each LoRA is different from Llama's, right?
Reply: Yes. The reason is that for Llama, `gate_up_proj` is a stacked param, while in Jamba/Mamba `in_proj` is not.
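For context, a hedged illustration of the difference this reply refers to. The mapping entries below are paraphrased from vLLM's `llama.py` loader convention rather than quoted verbatim; check the source for the exact tuples.

```python
# In Llama, separate checkpoint tensors are routed into slices of one stacked
# vLLM parameter via (vllm_param, checkpoint_name, shard_id) entries, e.g.:
stacked_params_mapping = [
    (".gate_up_proj", ".gate_proj", 0),  # checkpoint gate_proj -> shard 0
    (".gate_up_proj", ".up_proj", 1),    # checkpoint up_proj   -> shard 1
]

# In Jamba/Mamba the checkpoint stores in_proj as a single fused tensor, so
# there is no per-shard mapping to reuse; this PR instead slices that one
# tensor in half while loading (see the load_weights diff above).
```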
Review comment: Just add the mapping for `in_proj` as shown below; there's no need to modify the loading logic:

```python
packed_modules_mapping = {
    "qkv_proj": [
        "q_proj",
        "k_proj",
        "v_proj",
    ],
    "in_proj": ["in_proj"],
}
```
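For readers unfamiliar with this attribute, here is a small self-contained illustration of what such a mapping expresses. `packed_name_for` is a hypothetical helper written only for this example; it is not a vLLM API.

```python
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "in_proj": ["in_proj"],  # a one-element pack: the module maps to itself
}


def packed_name_for(module_name: str) -> str:
    """Return the packed parameter name a given adapter module folds into."""
    for packed, members in packed_modules_mapping.items():
        if module_name in members:
            return packed
    return module_name  # not packed; keep the name as-is


assert packed_name_for("q_proj") == "qkv_proj"
assert packed_name_for("in_proj") == "in_proj"
```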