
Commit 7b8b270

aws-tailinpa, aws-rishyraj, aws-yishanm, aws-patlange, and aws-mgld authored and committed
Upstream reordering seq-ids + draft model fp8 quantization
Co-authored-by: Rishabh Rajesh <[email protected]>
Co-authored-by: Yishan McNabb <[email protected]>
Co-authored-by: Patrick Lange <[email protected]>
Co-authored-by: Maxwell Goldberg <[email protected]>
Co-authored-by: Aakash Shetty <[email protected]>
Signed-off-by: Elaine Zhao <[email protected]>
1 parent 0189a65 commit 7b8b270

File tree

1 file changed: +38, -5

vllm/model_executor/model_loader/neuronx_distributed.py
@@ -84,16 +84,29 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
         # on-device sampling
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
+            output = output.hidden_states
         else:
-            return output.logits[:, -1, :]
+            output = output.logits[:, -1, :]
+
+        restored_indices = torch.argsort(sorted_indices)
+        if input_block_ids.shape[0] != 1:
+            output = torch.index_select(output, 0, restored_indices)
+
+        return output

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
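The change above is a sort/restore round trip: seq ids are sorted ascending before the Neuron model call (per the in-diff comment, "for perf/neuron support reasons"), every batch-aligned input is gathered into the same order with index_select, and the output is scattered back using the inverse permutation from torch.argsort. A minimal sketch of the pattern, with hypothetical stand-in tensors in place of the real model inputs:

```python
import torch

# Hypothetical stand-ins for the batch tensors passed to forward().
input_block_ids = torch.tensor([2, 0, 3, 1])
input_ids = torch.tensor([[20], [0], [30], [10]])

# Sort seq ids ascending; sorted_indices maps sorted -> original positions.
sorted_block_ids, sorted_indices = torch.sort(input_block_ids)
sorted_input_ids = torch.index_select(input_ids, 0, sorted_indices)

# ... the model would run on the sorted batch; identity here for illustration ...
output = sorted_input_ids

# argsort of a permutation yields its inverse, restoring original batch order.
restored_indices = torch.argsort(sorted_indices)
restored = torch.index_select(output, 0, restored_indices)

assert torch.equal(restored, input_ids)
```

Because argsort of a permutation is its inverse, one extra index_select undoes the reordering; the shape[0] != 1 guard in the diff simply skips this no-op for single-sequence batches.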
@@ -337,14 +350,26 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
+        restored_indices = torch.argsort(sorted_indices)
+
         # CTX encoding
         if (positions[:, 0]).sum().item() == 0:
-            return output.fused_outputs[0][:, 0:1]
+            output = output.fused_outputs[0][:, 0:1]
+            if input_block_ids.shape[0] != 1:
+                output = torch.index_select(output, 0, restored_indices)
+            return output

         # Fused Spec (Generation)
         accepted_tokens_with_padding = output.fused_outputs[0]
@@ -359,6 +384,10 @@ def forward(
                                     -1) >= generated_token_counts
         accepted_tokens_with_padding[mask] = -1

+        if input_block_ids.shape[0] != 1:
+            accepted_tokens_with_padding = torch.index_select(
+                accepted_tokens_with_padding, 0, restored_indices)
+
         return accepted_tokens_with_padding

     def sample(
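For context around this hunk: the fused-speculation generation path pads each sequence's accepted tokens out to the speculation length and writes -1 into slots at or beyond that sequence's accepted count, and the new lines apply the same inverse-permutation restore to the result. A hedged sketch of the masking step alone (tensor names come from the context lines; shapes and the mask construction are assumptions, not taken from the file):

```python
import torch

# Hypothetical batch: 3 sequences, speculation length 4.
accepted_tokens_with_padding = torch.tensor([[11, 12, 13, 14],
                                             [21, 22, 23, 24],
                                             [31, 32, 33, 34]])
generated_token_counts = torch.tensor([[2], [4], [1]])

# Positions at or beyond each row's accepted count become -1 padding,
# mirroring the `>= generated_token_counts` comparison in the context lines.
positions = torch.arange(4).expand(3, -1)
mask = positions >= generated_token_counts
accepted_tokens_with_padding[mask] = -1
# -> [[11, 12, -1, -1], [21, 22, 23, 24], [31, -1, -1, -1]]
```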
@@ -413,6 +442,10 @@ def load_weights(self, model_name_or_path: str,
         draft_neuron_config.speculation_length = 0
         draft_neuron_config.trace_tokengen_model = True
         draft_neuron_config.enable_fused_speculation = False
+        if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
+                   None):
+            draft_neuron_config.modules_to_not_convert = (
+                draft_neuron_config.draft_model_modules_to_not_convert)
         if config.neuron_config.enable_eagle_speculation:
             draft_neuron_config.is_eagle_draft = True
             draft_neuron_config.sequence_parallel_enabled = False
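This last hunk is the fp8-quantization half of the commit: when the serving config names draft_model_modules_to_not_convert, that list is copied onto the draft config's generic modules_to_not_convert field, so those draft-model modules are left unquantized. The draft config is derived from the target config, which is why the attribute can be read back off draft_neuron_config itself. A sketch of this optional-override pattern, with a plain namespace standing in for the real NeuronConfig object (attribute names as in the diff; everything else hypothetical):

```python
import copy
from types import SimpleNamespace

# Hypothetical configs; the real objects are NeuronConfig instances.
neuron_config = SimpleNamespace(
    quantized=True,
    draft_model_modules_to_not_convert=["lm_head", "embed_tokens"],
)

# The draft config starts as a copy of the target-model config,
# so the draft-specific attribute rides along.
draft_neuron_config = copy.deepcopy(neuron_config)

# Copy the draft-specific exclusion list onto the generic field,
# but only if the attribute exists and is non-empty.
if getattr(neuron_config, "draft_model_modules_to_not_convert", None):
    draft_neuron_config.modules_to_not_convert = (
        draft_neuron_config.draft_model_modules_to_not_convert)

print(draft_neuron_config.modules_to_not_convert)
# -> ['lm_head', 'embed_tokens']
```

Using getattr with a None default keeps the override backward compatible: configs that predate the new attribute simply skip the assignment.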
