Skip to content

Commit 9587115

Browse files
kylesayrs and nishith-fujitsu
authored and committed
[SupportsQuant] Bert, Blip, Blip2, Bloom (vllm-project#15573)
Signed-off-by: Kyle Sayers <[email protected]>
1 parent ebd7c7d commit 9587115

File tree

4 files changed

+16
-9
lines changed

4 files changed

+16
-9
lines changed

vllm/model_executor/models/bert.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from vllm.transformers_utils.config import (
2727
get_cross_encoder_activation_function)
2828

29-
from .interfaces import SupportsCrossEncoding, SupportsV0Only
29+
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
3030
from .utils import WeightsMapper, maybe_prefix
3131

3232

@@ -313,7 +313,8 @@ def forward(self, hidden_states: torch.Tensor,
313313
return hidden_states
314314

315315

316-
class BertModel(nn.Module):
316+
class BertModel(nn.Module, SupportsQuant):
317+
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
317318

318319
def __init__(self,
319320
*,
@@ -385,7 +386,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
385386
return loaded_params
386387

387388

388-
class BertEmbeddingModel(nn.Module, SupportsV0Only):
389+
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
389390
"""A model that uses Bert to provide embedding functionalities.
390391
391392
This class encapsulates the BertModel and provides an interface for
@@ -443,7 +444,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
443444
softmax=False)
444445

445446

446-
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
447+
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
448+
SupportsQuant):
447449
"""A model that uses Bert to provide embedding functionalities.
448450
449451
This class encapsulates the BertModel and provides an interface for

vllm/model_executor/models/blip.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from vllm.model_executor.layers.quantization import QuantizationConfig
1717
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
1818

19+
from .interfaces import SupportsQuant
20+
1921

2022
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
2123
assert image_size % patch_size == 0
@@ -243,9 +245,10 @@ def forward(self, inputs_embeds: torch.Tensor):
243245
return hidden_states
244246

245247

246-
class BlipVisionModel(nn.Module):
248+
class BlipVisionModel(nn.Module, SupportsQuant):
247249
config_class = BlipVisionConfig
248250
main_input_name = "pixel_values"
251+
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
249252

250253
def __init__(
251254
self,

vllm/model_executor/models/blip2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from vllm.sequence import IntermediateTensors
2525

2626
from .blip import BlipVisionModel
27-
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
27+
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
28+
SupportsQuant)
2829
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
2930
maybe_prefix, merge_multimodal_embeddings)
3031

@@ -498,7 +499,8 @@ def _get_prompt_updates(
498499
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
499500
info=Blip2ProcessingInfo,
500501
dummy_inputs=Blip2DummyInputsBuilder)
501-
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
502+
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
503+
SupportsQuant):
502504

503505
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
504506

vllm/model_executor/models/bloom.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from vllm.model_executor.sampling_metadata import SamplingMetadata
4343
from vllm.sequence import IntermediateTensors
4444

45-
from .interfaces import SupportsPP, SupportsV0Only
45+
from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
4646
from .utils import (is_pp_missing_parameter,
4747
make_empty_intermediate_tensors_factory, make_layers,
4848
maybe_prefix)
@@ -279,7 +279,7 @@ def forward(
279279
return hidden_states
280280

281281

282-
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only):
282+
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
283283

284284
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
285285
super().__init__()

0 commit comments

Comments (0)