@@ -87,16 +87,29 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
         # on-device sampling
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
+            output = output.hidden_states
         else:
-            return output.logits[:, -1, :]
+            output = output.logits[:, -1, :]
+
+        restored_indices = torch.argsort(sorted_indices)
+        if input_block_ids.shape[0] != 1:
+            output = torch.index_select(output, 0, restored_indices)
+
+        return output

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
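Both `forward` hunks rely on the same sort/restore round-trip: `torch.argsort` of the permutation returned by `torch.sort` is its inverse, so indexing with it restores the original batch order after the model has run on sorted `seq_ids`. A minimal standalone sketch with toy tensors (shapes and values are illustrative, not the real batch):

```python
import torch

# Toy stand-ins for the per-sequence batch in forward().
input_block_ids = torch.tensor([3, 1, 2])
logits = torch.tensor([[0.3], [0.1], [0.2]])  # one row per sequence

# Sort block ids and permute the per-sequence rows to match.
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
sorted_logits = torch.index_select(logits, 0, sorted_indices)

# argsort of a permutation is its inverse, so selecting with it
# restores the original batch order.
restored_indices = torch.argsort(sorted_indices)
restored = torch.index_select(sorted_logits, 0, restored_indices)
assert torch.equal(restored, logits)
```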
@@ -340,14 +353,26 @@ def forward(
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
+        restored_indices = torch.argsort(sorted_indices)
+
         # CTX encoding
         if (positions[:, 0]).sum().item() == 0:
-            return output.fused_outputs[0][:, 0:1]
+            output = output.fused_outputs[0][:, 0:1]
+            if input_block_ids.shape[0] != 1:
+                output = torch.index_select(output, 0, restored_indices)
+            return output

         # Fused Spec (Generation)
         accepted_tokens_with_padding = output.fused_outputs[0]
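The CTX-encoding branch keys off `(positions[:, 0]).sum().item() == 0`: during prefill every sequence's first position id is 0, so the sum over the batch is 0, while during generation each row continues from its prompt length. A toy illustration of that check (hypothetical position tensors):

```python
import torch

# Prefill: each row of positions starts at 0.
prefill_positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
# Generation: each row continues from its prompt length.
decode_positions = torch.tensor([[3], [7]])

def is_context_encoding(positions: torch.Tensor) -> bool:
    # Sum of every sequence's first position is 0 only in prefill.
    return (positions[:, 0]).sum().item() == 0

assert is_context_encoding(prefill_positions)
assert not is_context_encoding(decode_positions)
```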
@@ -362,6 +387,10 @@ def forward(
                 - 1) >= generated_token_counts
         accepted_tokens_with_padding[mask] = -1

+        if input_block_ids.shape[0] != 1:
+            accepted_tokens_with_padding = torch.index_select(
+                accepted_tokens_with_padding, 0, restored_indices)
+
         return accepted_tokens_with_padding

     def sample(
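For context on the hunk above: slots past each sequence's generated-token count are marked `-1` as padding before the batch order is restored. The exact index arithmetic (`- 1) >= generated_token_counts`) is only partially visible in the diff, so this is a simplified sketch of the masking idea with assumed shapes:

```python
import torch

# Assumed shapes: [batch, speculation_len] accepted token ids and a
# [batch, 1] count of tokens actually generated per sequence.
accepted_tokens_with_padding = torch.tensor([[11, 12, 13], [21, 22, 23]])
generated_token_counts = torch.tensor([[2], [1]])

# Slots at or beyond each sequence's generated count are padding;
# broadcasting the comparison builds a [batch, speculation_len] mask.
positions = torch.arange(accepted_tokens_with_padding.shape[-1])
mask = positions >= generated_token_counts
accepted_tokens_with_padding[mask] = -1
# tensor([[11, 12, -1],
#         [21, -1, -1]])
```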
@@ -416,6 +445,10 @@ def load_weights(self, model_name_or_path: str,
         draft_neuron_config.speculation_length = 0
         draft_neuron_config.trace_tokengen_model = True
         draft_neuron_config.enable_fused_speculation = False
+        if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
+                   None):
+            draft_neuron_config.modules_to_not_convert = (
+                draft_neuron_config.draft_model_modules_to_not_convert)
         if config.neuron_config.enable_eagle_speculation:
             draft_neuron_config.is_eagle_draft = True
             draft_neuron_config.sequence_parallel_enabled = False
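The `getattr(..., None)` guard makes the draft-model override opt-in: when `draft_model_modules_to_not_convert` is absent or empty, the copy is skipped and older configs load unchanged. A minimal sketch with stand-in objects (not the real NeuronConfig API):

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the config objects in load_weights().
neuron_config = SimpleNamespace(
    draft_model_modules_to_not_convert=["lm_head"])
config = SimpleNamespace(neuron_config=neuron_config)
draft_neuron_config = SimpleNamespace(
    draft_model_modules_to_not_convert=["lm_head"],
    modules_to_not_convert=None)

# getattr with a None default keeps this a no-op for configs that
# predate the field, so older setups load unchanged.
if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
           None):
    draft_neuron_config.modules_to_not_convert = (
        draft_neuron_config.draft_model_modules_to_not_convert)

print(draft_neuron_config.modules_to_not_convert)  # ['lm_head']
```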