Skip to content

Commit 66d6dd2

Browse files
jeejeelee
authored and tjtanaa committed
[Misc] Add BNB quantization for PaliGemmaForConditionalGeneration (vllm-project#12237)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent be57b24 commit 66d6dd2

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

vllm/model_executor/models/paligemma.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,18 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
136136
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
137137
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
138138
SupportsPP):
139-
139+
packed_modules_mapping = {
140+
"qkv_proj": [
141+
"q_proj",
142+
"k_proj",
143+
"v_proj",
144+
],
145+
"gate_up_proj": [
146+
"gate_proj",
147+
"up_proj",
148+
],
149+
}
150+
140151
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
141152
super().__init__()
142153
config = vllm_config.model_config.hf_config

vllm/model_executor/models/siglip.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,10 +344,16 @@ def __init__(
344344

345345
self.config = config
346346
self.activation_fn = get_act_fn(config.hidden_act)
347-
348-
# For quantization, we require the hidden size to be a multiple of 64
349-
quantizable = (config.hidden_size % 64 == 0
350-
and config.intermediate_size % 64 == 0)
347+
# Special handling for BNB quantization
348+
if quant_config and quant_config.get_name() == "bitsandbytes":
349+
quantizable = True
350+
else:
351+
# For other quantization, we require the hidden size to be a
352+
# multiple of 64
353+
quantizable = (
354+
config.hidden_size % 64 == 0
355+
and config.intermediate_size % 64 == 0
356+
)
351357
self.fc1 = ColumnParallelLinear(
352358
config.hidden_size,
353359
config.intermediate_size,

0 commit comments

Comments
 (0)