
Commit 3aa75eb

garg-amit authored and jeejeelee committed
[Misc] Adjust max_position_embeddings for LoRA compatibility (vllm-project#8957)
Signed-off-by: Amit Garg <[email protected]>
1 parent f41dfa2 commit 3aa75eb

File tree

1 file changed: +9 -2 lines changed


vllm/worker/model_runner.py

Lines changed: 9 additions & 2 deletions
@@ -1069,9 +1069,17 @@ def load_model(self) -> None:
             assert supports_lora(
                 self.model
             ), f"{self.model.__class__.__name__} does not support LoRA yet."
+
             if supports_multimodal(self.model):
                 logger.warning("Regarding multimodal models, vLLM currently "
                                "only supports adding LoRA to language model.")
+            # It's necessary to distinguish between the max_position_embeddings
+            # of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = self.model.config.max_position_embeddings
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -1081,8 +1089,7 @@ def load_model(self) -> None:
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=self.model.config.
-                max_position_embeddings,
+                max_position_embeddings=max_pos_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
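For context, here is a minimal runnable sketch of the fallback this patch introduces. The SimpleNamespace objects below are hypothetical stand-ins for HF-style model configs (a plain LLM config, and a VLM config that nests the language model's settings under text_config); only the attribute layout matters.

# Sketch of the config fallback added in this commit; configs are
# hypothetical stand-ins, not real vLLM or HF objects.
from types import SimpleNamespace

def get_max_pos_embeddings(config):
    # Plain LLM configs expose max_position_embeddings at the top level.
    if hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings
    # VLM configs nest the language model's settings under text_config.
    return config.text_config.max_position_embeddings

llm_config = SimpleNamespace(max_position_embeddings=4096)
vlm_config = SimpleNamespace(
    text_config=SimpleNamespace(max_position_embeddings=8192))

assert get_max_pos_embeddings(llm_config) == 4096
assert get_max_pos_embeddings(vlm_config) == 8192

Either value then feeds the max_position_embeddings argument of LRUCacheWorkerLoRAManager, so the LoRA manager sees the correct context length for both model families instead of failing on VLM configs that lack a top-level max_position_embeddings.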
