
Commit 940edc4

ErezSC42 authored and mzusman committed
[MODEL] LoRA support for Jamba model (vllm-project#11209)
Signed-off-by: Erez Schwartz <[email protected]>
1 parent a66dabf commit 940edc4
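
For context, this commit lets a LoRA adapter be applied to the Jamba model at inference time. The sketch below is not part of the diff; it mirrors the new tests/lora/test_jamba.py, with the adapter path and request name being illustrative placeholders.

import vllm
from vllm.lora.request import LoRARequest

# Model path and sampling settings follow the new test; the adapter path is a placeholder.
llm = vllm.LLM(
    "ai21labs/AI21-Jamba-1.5-Mini",
    enable_lora=True,
    max_loras=4,
    tensor_parallel_size=4,
    distributed_executor_backend="ray",
)
outputs = llm.generate(
    ["Write a story about a sheep and a goat."],
    vllm.SamplingParams(temperature=0, max_tokens=40),
    # LoRARequest(name, id, path) -- same argument order used in the new test
    lora_request=LoRARequest("jamba-lora", 1, "/path/to/jamba-adapter"),
)
print(outputs[0].outputs[0].text)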

5 files changed: +132 −32 lines changed


tests/lora/conftest.py

Lines changed: 24 additions & 0 deletions
@@ -4,6 +4,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+import safetensors
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
@@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules():
     return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
 
 
+@pytest.fixture(scope="session")
+def jamba_lora_files():
+    # some of the adapters have unnecessary weights for serving,
+    # hence we remove them
+    def remove_unnecessary_weights(path):
+        lora_path = f"{adapter_path}/adapter_model.safetensors"
+        tensors = safetensors.torch.load_file(lora_path)
+        nonlora_keys = []
+        for k in list(tensors.keys()):
+            if "lora" not in k:
+                nonlora_keys.append(k)
+        for k in nonlora_keys:
+            del tensors[k]
+        safetensors.torch.save_file(tensors, lora_path)
+
+    adapter_path = snapshot_download(
+        repo_id=
+        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
+
+    remove_unnecessary_weights(adapter_path)
+    return adapter_path
+
+
 @pytest.fixture(scope="session")
 def gemma_lora_files():
     return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")

tests/lora/test_jamba.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+from typing import List
+
+import pytest
+import torch
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
+
+MAX_TOKENS = 40
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
+              prompts: List[str]) -> List[str]:
+
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts: List[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+@pytest.mark.parametrize("tp_size", [4])
+def test_jamba_lora(jamba_lora_files, tp_size):
+    """Original test, the LoRA model has the common target modules, not all"""
+    if torch.cuda.device_count() < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+
+    prompts = ["Write a story about a sheep and a goat."]
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=4,
+        distributed_executor_backend="ray",
+        tensor_parallel_size=tp_size,
+    )
+
+    expected_jamba_output = [
+        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
+    ]
+    assert do_sample(llm, jamba_lora_files, lora_id=1,
+                     prompts=prompts) == expected_jamba_output
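
Because do_sample passes lora_request only when lora_id is non-zero, the same helper can be reused to contrast base-model and adapter outputs. A hedged variation on the test above, assuming llm, jamba_lora_files, and prompts are set up exactly as in test_jamba_lora:

# Sketch: compare outputs with and without the adapter using the helper above.
base_output = do_sample(llm, jamba_lora_files, lora_id=0, prompts=prompts)  # no LoRARequest
lora_output = do_sample(llm, jamba_lora_files, lora_id=1, prompts=prompts)  # adapter applied
# The adapter is expected to change the generation; the exact text depends on the checkpoint.
print(base_output[0] == lora_output[0])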

vllm/model_executor/layers/mamba/mamba_mixer.py

Lines changed: 18 additions & 4 deletions
@@ -42,12 +42,14 @@ def __init__(self,
                  use_rms_norm: bool,
                  rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
-                 activation="silu"):
+                 activation="silu",
+                 is_lora_enabled: bool = False):
         super().__init__()
         self.time_step_rank = time_step_rank
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
+        self.is_lora_enabled = is_lora_enabled
 
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
@@ -63,6 +65,7 @@ def __init__(self,
         self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                   [intermediate_size] * 2,
                                                   bias=use_bias)
+
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             intermediate_size,
@@ -170,7 +173,13 @@ def forward_cuda(self, hidden_states: torch.Tensor,
 
         # 3. State Space Model sequence transformation
         # 3.a. input varying initialization of time_step, B and C
-        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
+
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            ssm_parameters = self.x_proj(
+                hidden_states.transpose(-2, -1).contiguous())[0]
+        else:
+            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
 
         time_step, B, C = torch.split(
             ssm_parameters,
@@ -222,6 +231,11 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-                                                                     -1))[0]
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1).contiguous())[0]
+        else:
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1))[0]
         return contextualized_states
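
The new .contiguous() calls above exist because transpose returns a strided view, and the LoRA kernels (per the inline comments) require contiguous inputs. A standalone PyTorch illustration of that behavior, not vLLM code:

import torch

x = torch.randn(2, 4, 8)
t = x.transpose(-2, -1)                 # a view with swapped strides; no data is copied
print(t.is_contiguous())                # False
print(t.contiguous().is_contiguous())   # True: .contiguous() materializes a compact copy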

vllm/model_executor/models/jamba.py

Lines changed: 26 additions & 24 deletions
@@ -107,9 +107,11 @@ def __init__(self,
                  layer_idx: int,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.is_lora_enabled = is_lora_enabled
         self.mamba = MambaMixer(hidden_size= config.hidden_size,
                                 ssm_state_size = config.mamba_d_state,
                                 conv_kernel_size = config.mamba_d_conv,
@@ -120,7 +122,9 @@ def __init__(self,
                                 use_bias = config.mamba_proj_bias,
                                 use_rms_norm=True,
                                 rms_norm_eps=config.rms_norm_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled = self.is_lora_enabled
+                                )
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -156,14 +160,13 @@ def forward(
 
 class JambaAttentionDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: JambaConfig,
-        layer_idx: int,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: JambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "",
+                 **kwargs) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -287,17 +290,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )
 
+        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
+
         def get_layer(prefix: str):
             layer_idx = int(prefix.rsplit(".", 1)[1])
             layer_class = ALL_DECODER_LAYER_TYPES[
                 config.layers_block_type[layer_idx]]
-            return layer_class(
-                config,
-                layer_idx,
-                cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
+            return layer_class(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config,
+                               prefix=prefix,
+                               **extra_kwargs)
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
+        "in_proj": ["in_proj"],
     }
 
     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
+        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
+        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
     ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
@@ -423,9 +426,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
         if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
+                not self.model_config.enforce_eager:
             if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
+                    vllm_config.compilation_config.max_capture_size:
                 self.max_batch_size = \
                     vllm_config.compilation_config.max_capture_size
             else:
@@ -446,7 +449,6 @@ def forward(self,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
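
With in_proj, x_proj, out_proj and the feed-forward projections now in supported_lora_modules (and q/k/v handled through the qkv_proj packing above), a Jamba adapter trained against those module names can be served. A hedged sketch of such a PEFT training config; the rank and alpha values are illustrative, only the target module names come from this diff:

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                       # illustrative rank
    lora_alpha=32,              # illustrative scaling
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention (packed into qkv_proj in vLLM)
        "in_proj", "x_proj", "out_proj",          # Mamba mixer projections
        "gate_proj", "up_proj", "down_proj",      # feed-forward projections
    ],
    task_type="CAUSAL_LM",
)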

vllm/model_executor/models/mamba.py

Lines changed: 10 additions & 4 deletions
@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: MambaConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 is_lora_enabled: Optional[bool] = False) -> None:
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
+        self.is_lora_enabled = is_lora_enabled
         mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
         self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
@@ -53,7 +55,8 @@ def __init__(self,
                                 use_rms_norm=self.is_falcon_mamba,
                                 rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled=self.is_lora_enabled)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)
 
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -101,8 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: MambaDecoderLayer(
-                config, cache_config=cache_config, quant_config=quant_config),
+            lambda prefix: MambaDecoderLayer(config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config,
+                                             is_lora_enabled=is_lora_enabled),
             prefix=f"{prefix}.layers")
 
         self.norm_f = RMSNorm(config.hidden_size,
