@@ -84,16 +84,29 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
         # on-device sampling
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
+            output = output.hidden_states
         else:
-            return output.logits[:, -1, :]
+            output = output.logits[:, -1, :]
+
+        restored_indices = torch.argsort(sorted_indices)
+        if input_block_ids.shape[0] != 1:
+            output = torch.index_select(output, 0, restored_indices)
+
+        return output
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
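Reviewer note (not part of the diff): the hunk above relies on torch.argsort of the sort permutation being its inverse, so gathering the outputs with it restores the caller's original batch order. A minimal standalone sketch of that round-trip, with made-up block ids and values:

import torch

block_ids = torch.tensor([3, 0, 2, 1])
values = torch.tensor([[30.0], [0.0], [20.0], [10.0]])  # rows aligned with block_ids

# forward: reorder rows so block ids are ascending before the model call
sorted_ids, sorted_indices = torch.sort(block_ids)
sorted_values = torch.index_select(values, 0, sorted_indices)

# inverse: argsort of the permutation maps sorted order back to input order
restored_indices = torch.argsort(sorted_indices)
assert torch.equal(
    torch.index_select(sorted_values, 0, restored_indices), values)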
@@ -337,14 +350,26 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
+        restored_indices = torch.argsort(sorted_indices)
+
         # CTX encoding
         if (positions[:, 0]).sum().item() == 0:
-            return output.fused_outputs[0][:, 0:1]
+            output = output.fused_outputs[0][:, 0:1]
+            if input_block_ids.shape[0] != 1:
+                output = torch.index_select(output, 0, restored_indices)
+            return output
 
         # Fused Spec (Generation)
         accepted_tokens_with_padding = output.fused_outputs[0]
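A quick sanity check on the CTX-encoding test in this hunk (toy tensors, assuming positions has shape (batch, seq_len)): during prefill every sequence starts at position 0, so the first column sums to zero; during decode it does not.

import torch

prefill_positions = torch.tensor([[0, 1, 2], [0, 1, 2]])  # context encoding
decode_positions = torch.tensor([[3], [7]])               # token generation

assert (prefill_positions[:, 0]).sum().item() == 0   # takes the CTX branch
assert (decode_positions[:, 0]).sum().item() != 0    # takes the fused-spec branch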
@@ -359,6 +384,10 @@ def forward(
                         - 1) >= generated_token_counts
         accepted_tokens_with_padding[mask] = -1
 
+        if input_block_ids.shape[0] != 1:
+            accepted_tokens_with_padding = torch.index_select(
+                accepted_tokens_with_padding, 0, restored_indices)
+
         return accepted_tokens_with_padding
 
     def sample(
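The mask construction sits mostly above this hunk, but the -1 padding convention it feeds can be sketched with hypothetical shapes: slots at or past each sequence's accepted-token count are overwritten with -1 so downstream code can separate real tokens from padding.

import torch

accepted = torch.tensor([[11, 12, 13, 14],   # hypothetical accepted tokens,
                         [21, 22, 23, 24]])  # shape (batch, speculation_length)
generated_token_counts = torch.tensor([[2], [3]])

mask = torch.arange(accepted.shape[1]) >= generated_token_counts  # (batch, k)
accepted[mask] = -1
# accepted is now [[11, 12, -1, -1], [21, 22, 23, -1]]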
@@ -413,6 +442,10 @@ def load_weights(self, model_name_or_path: str,
         draft_neuron_config.speculation_length = 0
         draft_neuron_config.trace_tokengen_model = True
         draft_neuron_config.enable_fused_speculation = False
+        if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
+                   None):
+            draft_neuron_config.modules_to_not_convert = (
+                draft_neuron_config.draft_model_modules_to_not_convert)
         if config.neuron_config.enable_eagle_speculation:
             draft_neuron_config.is_eagle_draft = True
             draft_neuron_config.sequence_parallel_enabled = False
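The new draft-config override uses the common getattr-guard pattern: apply the optional setting only when the attribute exists and is truthy, so configs without it keep the default behavior. A toy sketch with SimpleNamespace stand-ins for the real config objects:

from types import SimpleNamespace

neuron_config = SimpleNamespace(
    draft_model_modules_to_not_convert=["lm_head"])  # optional attribute
draft_neuron_config = SimpleNamespace(
    modules_to_not_convert=None,
    draft_model_modules_to_not_convert=["lm_head"])

# absent or falsy attribute -> the default is kept, so older configs still load
if getattr(neuron_config, "draft_model_modules_to_not_convert", None):
    draft_neuron_config.modules_to_not_convert = (
        draft_neuron_config.draft_model_modules_to_not_convert)

assert draft_neuron_config.modules_to_not_convert == ["lm_head"]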