
Commit fa72f9a

Order sequence ids + config update to support specifying custom quantization layers (#18279)
Signed-off-by: Elaine Zhao <[email protected]>
Co-authored-by: Tailin Pan <[email protected]>
Co-authored-by: Rishabh Rajesh <[email protected]>
Co-authored-by: Yishan McNabb <[email protected]>
Co-authored-by: Patrick Lange <[email protected]>
Co-authored-by: Maxwell Goldberg <[email protected]>
Co-authored-by: Aakash Shetty <[email protected]>
Parent: ebed81f

File tree: 2 files changed, +73 −10 lines

tests/neuron/2_core/test_mistral.py

Lines changed: 35 additions & 5 deletions
@@ -7,24 +7,54 @@ def test_mistral():
     llm = LLM(model="mistralai/Mistral-7B-v0.1",
               tensor_parallel_size=2,
               max_num_seqs=4,
-              max_model_len=512,
+              max_model_len=128,
               use_v2_block_manager=True,
               override_neuron_config={
                   "sequence_parallel_enabled": False,
                   "skip_warmup": True
               },
               device="neuron")

+    # Send more prompts than the compiled batch size (4) and request
+    # varying generation lengths to test accuracy related to Neuron
+    # specific sequence id sorting.
     prompts = [
         "The president of the United States is",
         "The capital of France is",
+        "What is Annapurna labs?",
+        "I believe the meaning of life is",
+        "Tell me a story about a brave knight",
+        "Hello, my name is Llama",
     ]
-    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+
+    sampling_params = [
+        SamplingParams(top_k=1, max_tokens=10),
+        SamplingParams(top_k=1, max_tokens=20),
+        SamplingParams(top_k=1, max_tokens=30),
+        SamplingParams(top_k=1, max_tokens=40),
+        SamplingParams(top_k=1, max_tokens=50),
+        SamplingParams(top_k=1, max_tokens=60)
+    ]
+
+    outputs = llm.generate(prompts, sampling_params)

     expected_outputs = [
-        " the most powerful person in the world. He is the head of state "
-        "and head",
-        " a city of many faces. It is a city of history, culture, art"
+        " the most powerful person in the world. He is",
+        " a city of many faces. It is a city of history, culture, art, "
+        "fashion, and",
+        "\n\nAnnapurna Labs is a semiconductor company that was founded "
+        "in 2013 by Amazon. The company is",
+        " to be happy.\n\nI believe that happiness is a choice.\n\nI "
+        "believe that happiness is a state of mind.\n\nI believe that "
+        "happiness is a journey.\n\nI believe",
+        " who rescued a princess from a dragon.\n\nTell me a story about"
+        " a princess who rescued herself from a dragon.\n\nTell me a "
+        "story about a princess who rescued herself from a dragon and "
+        "then rescued a knight from",
+        " and I am a 10 year old male. I am a very friendly and "
+        "affectionate boy who loves to be around people. I am a very "
+        "active boy who loves to play and run around. I am a very smart "
+        "boy who loves to learn new things. I am a very loyal boy"
     ]

     for expected_output, output in zip(expected_outputs, outputs):
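The test leans on LLM.generate accepting a list of SamplingParams paired
one-to-one with the prompts, and on outputs coming back in the original
prompt order even though the Neuron backend reorders sequence ids
internally. A minimal consumption sketch of that pairing (hypothetical;
this is not the test's actual assertion loop, which the hunk truncates):

    # Hypothetical usage sketch, not the test's assertion body.
    for prompt, params, output in zip(prompts, sampling_params, outputs):
        generated = output.outputs[0].text
        print(f"{prompt!r} (max_tokens={params.max_tokens}) -> {generated!r}")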

vllm/model_executor/model_loader/neuronx_distributed.py

Lines changed: 38 additions & 5 deletions
@@ -87,16 +87,29 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
         # on-device sampling
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
+            output = output.hidden_states
         else:
-            return output.logits[:, -1, :]
+            output = output.logits[:, -1, :]
+
+        restored_indices = torch.argsort(sorted_indices)
+        if input_block_ids.shape[0] != 1:
+            output = torch.index_select(output, 0, restored_indices)
+
+        return output

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
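The restore step above works because argsort of a permutation is its
inverse. A minimal standalone sketch of the round trip, using synthetic
tensors rather than the model's actual inputs:

    import torch

    # Synthetic example: requests arrive with unsorted sequence ids.
    input_block_ids = torch.tensor([2, 0, 3, 1])
    inputs = torch.tensor([200, 0, 300, 100])  # one row per request

    # Sort by sequence id before invoking the compiled Neuron model;
    # sorted_block_ids is what forward() passes as seq_ids.
    sorted_block_ids, sorted_indices = torch.sort(input_block_ids)
    sorted_inputs = torch.index_select(inputs, 0, sorted_indices)

    # The model would run on the sorted batch; pass through here.
    sorted_outputs = sorted_inputs

    # argsort(sorted_indices) inverts the permutation, mapping rows of
    # the sorted batch back to the original request order.
    restored_indices = torch.argsort(sorted_indices)
    outputs = torch.index_select(sorted_outputs, 0, restored_indices)

    assert torch.equal(outputs, inputs)  # request order preserved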
@@ -340,14 +353,26 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
+        restored_indices = torch.argsort(sorted_indices)
+
         # CTX encoding
         if (positions[:, 0]).sum().item() == 0:
-            return output.fused_outputs[0][:, 0:1]
+            output = output.fused_outputs[0][:, 0:1]
+            if input_block_ids.shape[0] != 1:
+                output = torch.index_select(output, 0, restored_indices)
+            return output

         # Fused Spec (Generation)
         accepted_tokens_with_padding = output.fused_outputs[0]
@@ -362,6 +387,10 @@ def forward(
                         -1) >= generated_token_counts
         accepted_tokens_with_padding[mask] = -1

+        if input_block_ids.shape[0] != 1:
+            accepted_tokens_with_padding = torch.index_select(
+                accepted_tokens_with_padding, 0, restored_indices)
+
         return accepted_tokens_with_padding

     def sample(
@@ -416,6 +445,10 @@ def load_weights(self, model_name_or_path: str,
        draft_neuron_config.speculation_length = 0
        draft_neuron_config.trace_tokengen_model = True
        draft_neuron_config.enable_fused_speculation = False
+       if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
+                  None):
+           draft_neuron_config.modules_to_not_convert = (
+               draft_neuron_config.draft_model_modules_to_not_convert)
        if config.neuron_config.enable_eagle_speculation:
            draft_neuron_config.is_eagle_draft = True
            draft_neuron_config.sequence_parallel_enabled = False
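The new draft_model_modules_to_not_convert key sits alongside the existing
modules_to_not_convert attribute, so a draft model derived for speculative
decoding can keep its own set of layers unquantized. A minimal sketch of how
the keys might be supplied via override_neuron_config (the two key names come
from this diff; the model, layer names, and surrounding arguments are
illustrative assumptions, not taken from this commit):

    from vllm import LLM

    llm = LLM(model="mistralai/Mistral-7B-v0.1",  # placeholder model
              device="neuron",
              override_neuron_config={
                  # target-model layers to keep unquantized (existing key)
                  "modules_to_not_convert": ["lm_head"],
                  # draft-model layers to keep unquantized (key added here)
                  "draft_model_modules_to_not_convert": ["lm_head"],
              })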
