
Commit c3852ee

[Model] Adding Granite MoE. (vllm-project#8206)

shawntannjhill authored and committed
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>

1 parent 5510852

File tree

4 files changed: +492 -3 lines changed
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
+
+Run `pytest tests/models/test_granite.py`.
+"""
+import pytest
+
+from ...utils import check_logprobs_close
+
+MODELS = [
+    "ibm/PowerMoE-3b",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
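For context, `check_logprobs_close` (imported above via the relative `...utils` module) compares the two runners within top-k logprobs rather than demanding exact token-for-token equality, which tolerates small numerical differences between the HF and vLLM implementations. A minimal sketch of that style of comparison, using hypothetical simplified types (the real helper handles more edge cases):

from typing import Dict, List, Tuple

# Hypothetical simplified shape: each output is
# (token_ids, text, per-step {token_id: logprob} dicts).
Output = Tuple[List[int], str, List[Dict[int, float]]]

def logprobs_close(outputs_0: List[Output], outputs_1: List[Output]) -> None:
    for (ids_0, _, lp_0), (ids_1, _, lp_1) in zip(outputs_0, outputs_1):
        for step, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue  # exact match at this step
            # On divergence, each model's chosen token must still appear
            # in the other model's top-k logprobs for the same step.
            assert tok_0 in lp_1[step] and tok_1 in lp_0[step]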

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
     "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
+    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),

vllm/model_executor/models/granite.py

Lines changed: 4 additions & 3 deletions
@@ -404,9 +404,12 @@ def __init__(
             self.lm_head.weight = self.model.embed_tokens.weight
 
             logit_scale = getattr(config, "logit_scale", 1.0)
+
+            if hasattr(config, "logits_scaling"):
+                logit_scale /= config.logits_scaling
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
-                                                    logit_scale)
+                                                    scale=logit_scale)
             self.sampler = Sampler()
         else:
             self.lm_head = PPMissingLayer()
@@ -428,8 +431,6 @@ def compute_logits(
                        sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
-        if logits is not None:
-            logits /= self.config.logits_scaling
         return logits
 
     def sample(
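The net effect of the `granite.py` change: the division by `config.logits_scaling`, previously applied in `compute_logits` after the `LogitsProcessor` ran, is now folded into the processor's `scale` argument at construction time. Assuming `LogitsProcessor` multiplies logits by `scale` (as the diff implies), the two paths are numerically equivalent; a small illustrative check with made-up values:

import torch

logits = torch.tensor([2.0, -1.0, 0.5])   # hypothetical raw logits
logit_scale, logits_scaling = 1.0, 8.0    # hypothetical config values

# Old path: scale inside the processor, then divide in compute_logits.
old = (logits * logit_scale) / logits_scaling
# New path: fold the division into the processor's scale up front.
new = logits * (logit_scale / logits_scaling)

assert torch.allclose(old, new)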
