"""Inference-only Snowflake Arctic model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
@@ -32,6 +32,8 @@
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.arctic import ArcticConfig

+from .utils import is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory
+
logger = init_logger(__name__)

@@ -364,6 +366,7 @@ def __init__(
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
@@ -372,28 +375,35 @@ def __init__(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size)
-        self.layers = nn.ModuleList([
-            ArcticDecoderLayer(config,
-                               layer_idx,
-                               cache_config,
-                               quant_config=quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: ArcticDecoderLayer(config,
+                                              int(prefix.split(".")[-1]),
+                                              cache_config, quant_config),
+            prefix=f"{prefix}.layers")
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        for i in range(len(self.layers)):
+        intermediate_tensors: IntermediateTensors,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
-            hidden_states = layer(positions, hidden_states, kv_caches[i],
+            hidden_states = layer(positions, hidden_states, kv_caches[i - self.start_layer],
                                  attn_metadata)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states

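For context, here is a minimal, self-contained sketch of the contiguous layer partitioning that make_layers is assumed to perform (illustration only, with hypothetical names, not vLLM's implementation). Each pipeline-parallel rank builds just its slice [start_layer, end_layer) of the decoder layers while layer indices stay global, which is why the forward loop above indexes kv_caches with i - self.start_layer.

from typing import Callable, List, Optional, Tuple

def partition_layers(num_hidden_layers: int, pp_rank: int, pp_size: int,
                     make_layer: Callable[[int], object]
                     ) -> Tuple[int, int, List[Optional[object]]]:
    # Evenly sized contiguous slices; the last rank takes whatever remains.
    per_rank = (num_hidden_layers + pp_size - 1) // pp_size
    start_layer = pp_rank * per_rank
    end_layer = min(start_layer + per_rank, num_hidden_layers)
    # Layers outside this rank's slice stay as None placeholders, mirroring
    # the "missing" layers that is_pp_missing_parameter later skips.
    layers = [make_layer(i) if start_layer <= i < end_layer else None
              for i in range(num_hidden_layers)]
    return start_layer, end_layer, layers

# Example: 36 layers over 4 pipeline stages -> rank 1 owns [9, 18).
print(partition_layers(36, 1, 4, lambda i: f"layer_{i}")[:2])  # (9, 18)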
@@ -420,6 +430,7 @@ def __init__(self,
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

    def forward(
        self,
@@ -428,9 +439,9 @@ def forward(
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
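The widened return type changes the caller-side contract: only the last pipeline rank returns a plain hidden-state tensor suitable for compute_logits, while every other rank returns an IntermediateTensors bundle to be shipped to the next stage. A toy, self-contained sketch of that contract (the class and function names below are illustrative stand-ins, not vLLM's runner code):

import torch
from typing import Dict, Union

class ToyIntermediateTensors:
    """Stand-in for vllm.sequence.IntermediateTensors: a named bag of
    tensors handed from one pipeline stage to the next."""

    def __init__(self, tensors: Dict[str, torch.Tensor]) -> None:
        self.tensors = tensors

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]

def run_stage(hidden_states: torch.Tensor, is_last_rank: bool
              ) -> Union[torch.Tensor, ToyIntermediateTensors]:
    hidden_states = hidden_states * 2.0  # stand-in for this rank's layers
    if not is_last_rank:
        # Intermediate ranks hand their hidden states to the next stage.
        return ToyIntermediateTensors({"hidden_states": hidden_states})
    return hidden_states  # last rank: plain tensor, ready for compute_logits

x = torch.ones(1, 4)
stage0_out = run_stage(x, is_last_rank=False)
stage1_out = run_stage(stage0_out["hidden_states"], is_last_rank=True)
print(stage1_out.shape)  # torch.Size([1, 4])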
@@ -498,6 +509,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@@ -507,6 +520,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
+                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
@@ -517,6 +532,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
+                        if is_pp_missing_parameter(name, self):
+                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param,
@@ -527,6 +544,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                    else:
                        if name.endswith(".bias") and name not in params_dict:
                            continue
+                        if is_pp_missing_parameter(name, self):
+                            continue
                        param = params_dict[name]

                        weight_loader = getattr(param, "weight_loader",
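The repeated guard in load_weights follows one idea: with the layers partitioned across pipeline ranks, a checkpoint weight whose layer index falls outside this rank's slice has no local parameter to load into, so it is skipped. A rough, self-contained sketch of such a check (an assumption about what is_pp_missing_parameter does, using hypothetical names, not vLLM's implementation):

import re

def param_is_missing_on_this_rank(name: str, start_layer: int,
                                  end_layer: int) -> bool:
    match = re.search(r"\.layers\.(\d+)\.", name)
    if match is None:
        # Non-layer weights (embeddings, final norm, lm_head): keep them
        # here for simplicity; in practice only the owning ranks load them.
        return False
    layer_idx = int(match.group(1))
    return not (start_layer <= layer_idx < end_layer)

# A rank owning layers [9, 18) skips layer 20's weights but loads layer 10's.
print(param_is_missing_on_this_rank("model.layers.20.self_attn.qkv_proj.weight", 9, 18))  # True
print(param_is_missing_on_this_rank("model.layers.10.mlp.w13.weight", 9, 18))             # False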