Commit f8ec7d8

youkaichao authored and weilong.yu committed
[bugfix] fix aria model and add torch.compile (vllm-project#10645)
Signed-off-by: youkaichao <[email protected]>
1 parent 366438e · commit f8ec7d8

File tree

2 files changed: +14 -28 lines changed


vllm/model_executor/models/aria.py

Lines changed: 4 additions & 22 deletions

@@ -29,7 +29,7 @@
                                               LlamaModel)
 from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                               is_pp_missing_parameter,
-                                              make_layers, maybe_prefix,
+                                              maybe_prefix,
                                               merge_multimodal_embeddings)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
     """

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-
-        # FIXME: this is a hack to disable the compilation of the model
-        self.do_not_compile = True
-
-        self.layers = None
-
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: MoEDecoderLayer(
-                config=config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            ),
-            prefix=f"{prefix}.layers",
-        )
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         layer_type=MoEDecoderLayer)

     # Adapted from LlamaModel.load_weights with the modification of adding
     # the expert weights mapping to `stacked_params_mapping`

vllm/model_executor/models/llama.py

Lines changed: 10 additions & 6 deletions

@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import torch
 from torch import nn
@@ -273,7 +273,11 @@ def forward(
 @support_torch_compile
 class LlamaModel(nn.Module):

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         super().__init__()

         config = vllm_config.model_config.hf_config
@@ -299,10 +303,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: LlamaDecoderLayer(config=config,
-                                             cache_config=cache_config,
-                                             quant_config=quant_config,
-                                             prefix=prefix),
+            lambda prefix: layer_type(config=config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config,
+                                      prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         if get_pp_group().is_last_rank:
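
Together, the two diffs replace Aria's hand-rolled layer construction with dependency injection: LlamaModel.__init__ gains a layer_type parameter, and AriaMoELMModel now simply passes MoEDecoderLayer through to the parent. The sketch below illustrates that pattern in isolation; the class names mirror the diff, but the bodies are simplified stand-ins, not vLLM's actual implementation (the real classes take a VllmConfig, and LlamaModel is decorated with @support_torch_compile).

# A minimal, self-contained sketch of the layer_type injection pattern
# introduced by this commit. DecoderLayer/BaseModel/MoEModel are simplified
# stand-ins for vLLM's LlamaDecoderLayer/LlamaModel/AriaMoELMModel.
from typing import Type

import torch
from torch import nn


class DecoderLayer(nn.Module):
    """Stand-in for LlamaDecoderLayer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class MoEDecoderLayer(DecoderLayer):
    """Stand-in for Aria's MoE layer; a real one would route across experts."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) + x


class BaseModel(nn.Module):
    """Stand-in for LlamaModel: owns layer construction, takes the layer class."""

    def __init__(self,
                 hidden_size: int = 32,
                 num_layers: int = 2,
                 layer_type: Type[DecoderLayer] = DecoderLayer):
        super().__init__()
        # Because the base class builds self.layers from the injected type,
        # subclasses no longer rebuild the layer stack themselves; that is
        # the code this commit deletes from aria.py.
        self.layers = nn.ModuleList(
            [layer_type(hidden_size) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


class MoEModel(BaseModel):
    """Stand-in for AriaMoELMModel: a one-line constructor, as in the diff."""

    def __init__(self, hidden_size: int = 32, num_layers: int = 2):
        super().__init__(hidden_size=hidden_size,
                         num_layers=num_layers,
                         layer_type=MoEDecoderLayer)


model = MoEModel()
print(model(torch.randn(4, 32)).shape)  # torch.Size([4, 32])

Since AriaMoELMModel no longer overrides layer construction, it inherits LlamaModel's @support_torch_compile decoration unchanged, which is what allows the FIXME'd do_not_compile hack in aria.py to be deleted.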
