
Commit 4d358bb

Author: Muralidhar Andoorveedu (committed)

Model PP support

Signed-off-by: Muralidhar Andoorveedu <[email protected]>

1 parent 4db5176, commit 4d358bb


48 files changed (+1047, -663 lines)
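Before the per-file diffs, here is the hand-off pattern this commit applies in every model file: the first pipeline-parallel rank embeds the tokens, later ranks pick up the previous stage's hidden states, each rank runs only its own slice of decoder layers, and every rank except the last returns an IntermediateTensors hand-off instead of final hidden states. The snippet below is a minimal, self-contained sketch of that control flow, not vLLM code: ToyPPGroup and ToyModel are illustrative stand-ins for get_pp_group() and the real model classes, and a plain dict stands in for IntermediateTensors.

# A minimal, self-contained sketch (not vLLM code) of the pipeline-parallel
# hand-off pattern applied by this commit.
from typing import Dict, Optional, Union

import torch
from torch import nn


class ToyPPGroup:
    """Stand-in for get_pp_group(): records this rank's position."""

    def __init__(self, rank: int, world_size: int):
        self.is_first_rank = rank == 0
        self.is_last_rank = rank == world_size - 1


class ToyModel(nn.Module):

    def __init__(self, vocab: int, hidden: int, start_layer: int,
                 end_layer: int, pp: ToyPPGroup):
        super().__init__()
        self.pp = pp
        self.embed_tokens = nn.Embedding(vocab, hidden)
        # Only this rank's slice of layers exists (cf. make_layers).
        self.layers = nn.ModuleList(
            [nn.Linear(hidden, hidden) for _ in range(start_layer, end_layer)])
        self.norm = nn.LayerNorm(hidden)

    def forward(
        self,
        input_ids: torch.Tensor,
        intermediate_tensors: Optional[Dict[str, torch.Tensor]],
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if self.pp.is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        if not self.pp.is_last_rank:
            # Hand the activations to the next pipeline stage.
            return {"hidden_states": hidden_states}
        return self.norm(hidden_states)


# Two-stage toy run in one process, passing the hand-off dict by hand.
stage0 = ToyModel(100, 16, 0, 2, ToyPPGroup(rank=0, world_size=2))
stage1 = ToyModel(100, 16, 2, 4, ToyPPGroup(rank=1, world_size=2))
input_ids = torch.randint(0, 100, (1, 8))
handoff = stage0(input_ids, None)
final_hidden = stage1(input_ids, handoff)
print(final_hidden.shape)  # torch.Size([1, 8, 16])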

vllm/model_executor/models/arctic.py

Lines changed: 34 additions & 15 deletions
@@ -1,12 +1,12 @@
 """Inference-only Snowflake Arctic model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
@@ -32,6 +32,8 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.transformers_utils.configs.arctic import ArcticConfig

+from .utils import is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory
+
 logger = init_logger(__name__)


@@ -364,6 +366,7 @@ def __init__(
         config: ArcticConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -372,28 +375,35 @@ def __init__(
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=self.vocab_size)
-        self.layers = nn.ModuleList([
-            ArcticDecoderLayer(config,
-                               layer_idx,
-                               cache_config,
-                               quant_config=quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: ArcticDecoderLayer(config,
+                                              int(prefix.split(".")[-1]),
+                                              cache_config, quant_config),
+            prefix=f"{prefix}.layers")
         self._attn_implementation = config._attn_implementation
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        for i in range(len(self.layers)):
+        intermediate_tensors: IntermediateTensors,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states = layer(positions, hidden_states, kv_caches[i],
+            hidden_states = layer(positions, hidden_states, kv_caches[i - self.start_layer],
                                   attn_metadata)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.norm(hidden_states)
         return hidden_states

@@ -420,6 +430,7 @@ def __init__(self,
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

     def forward(
         self,
@@ -428,9 +439,9 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
@@ -498,6 +509,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -507,6 +520,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param, loaded_weight, shard_id)
@@ -517,6 +532,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         if weight_name not in name:
                             continue
                         name = name.replace(weight_name, param_name)
+                        if is_pp_missing_parameter(name, self):
+                            continue
                         param = params_dict[name]
                         weight_loader = param.weight_loader
                         weight_loader(param,
@@ -527,6 +544,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     else:
                         if name.endswith(".bias") and name not in params_dict:
                             continue
+                        if is_pp_missing_parameter(name, self):
+                            continue
                         param = params_dict[name]

                         weight_loader = getattr(param, "weight_loader",
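Every load_weights change in this commit follows the same rule: skip any checkpoint tensor whose module does not exist on this pipeline rank, which is what the is_pp_missing_parameter(name, self) guard expresses. The helper's implementation lives in .utils and is not part of this diff, so the following is only a hypothetical sketch of the idea; toy_is_pp_missing_parameter and the explicit start_layer/end_layer arguments are assumptions for illustration.

# Hypothetical illustration only: the real is_pp_missing_parameter in
# vllm/model_executor/models/utils.py may differ. Conceptually, the loader
# must skip decoder-layer weights outside this rank's [start_layer, end_layer).
import re


def toy_is_pp_missing_parameter(name: str, start_layer: int,
                                end_layer: int) -> bool:
    """Return True if `name` refers to a layer this rank does not own."""
    match = re.search(r"\.layers\.(\d+)\.", name)
    if match is None:
        # Non-layer weights (embeddings, norm, lm_head) need their own
        # placement rule; treating them as present here is just a placeholder.
        return False
    layer_idx = int(match.group(1))
    return not start_layer <= layer_idx < end_layer


# A rank owning layers [16, 32) skips layer 3 but loads layer 20.
assert toy_is_pp_missing_parameter("model.layers.3.mlp.w1.weight", 16, 32)
assert not toy_is_pp_missing_parameter("model.layers.20.mlp.w1.weight", 16, 32)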

vllm/model_executor/models/baichuan.py

Lines changed: 35 additions & 15 deletions
@@ -19,7 +19,7 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -28,7 +28,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              get_pp_group)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -46,6 +47,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput

 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -255,7 +257,8 @@ def __init__(self,
                  config: PretrainedConfig,
                  position_embedding: str,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -265,31 +268,43 @@ def __init__(self,
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config, position_embedding, cache_config,
-                                 quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: BaiChuanDecoderLayer(config, position_embedding, cache_config, quant_config),
+            prefix=f"{prefix}.layers",
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(["hidden_states", "residual"], config.hidden_size)

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        intermediate_tensors: IntermediateTensors,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual,
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

@@ -333,6 +348,7 @@ def __init__(
                              quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

     def forward(
         self,
@@ -341,9 +357,9 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
@@ -389,6 +405,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -397,6 +415,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

vllm/model_executor/models/blip2.py

Lines changed: 8 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .utils import merge_vision_embeddings, is_pp_missing_parameter, make_empty_intermediate_tensors_factory

 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -558,7 +558,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
-    ) -> SamplerOutput:
+    ) -> Union[SamplerOutput, IntermediateTensors]:
         """Run forward pass for BLIP-2.

         One key thing to understand is the `input_ids` already accounts for the
@@ -607,6 +607,7 @@ def forward(
                                                   positions,
                                                   kv_caches,
                                                   attn_metadata,
+                                                  intermediate_tensors=intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)

         return hidden_states
@@ -656,13 +657,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                if is_pp_missing_parameter(name.replace(weight_name, param_name), self):
+                    continue
                 param = params_dict[name.replace(weight_name, param_name)]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 use_default_weight_loading = True
             if use_default_weight_loading:
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
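Each converted model also exposes self.make_empty_intermediate_tensors, built by make_empty_intermediate_tensors_factory(keys, hidden_size) and re-exported on the top-level model class. The factory itself is defined in vllm/model_executor/models/utils.py and does not appear in this excerpt, so the sketch below is only an assumption about its shape; toy_empty_intermediate_tensors_factory and the zero-filled buffers are illustrative, not the actual implementation.

# A guess at the role of make_empty_intermediate_tensors_factory: given the
# hand-off keys ("hidden_states", optionally "residual") and the hidden size,
# return a callable that allocates placeholder buffers so a receiving pipeline
# stage has correctly shaped tensors before the previous stage's data arrives.
from typing import Callable, Dict, List

import torch


def toy_empty_intermediate_tensors_factory(
    keys: List[str], hidden_size: int
) -> Callable[[int, torch.dtype, torch.device], Dict[str, torch.Tensor]]:

    def make_empty(batch_size: int, dtype: torch.dtype,
                   device: torch.device) -> Dict[str, torch.Tensor]:
        # One zero-filled buffer per hand-off key.
        return {
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        }

    return make_empty


make_empty = toy_empty_intermediate_tensors_factory(
    ["hidden_states", "residual"], hidden_size=16)
buffers = make_empty(8, torch.float32, torch.device("cpu"))
print({key: tensor.shape for key, tensor in buffers.items()})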
