@@ -506,22 +506,23 @@ def _sample(
506
506
# sampling_tensors)
507
507
508
508
509
- def _get_ranks (x : torch .Tensor , indices : List [ int ] ) -> torch .Tensor :
509
+ def _get_ranks (x : torch .Tensor , indices : torch . Tensor ) -> torch .Tensor :
510
510
"""
511
511
This function calculates the ranks of the chosen tokens in a logprob tensor.
512
512
513
513
Args:
514
514
x (torch.Tensor): 2D logprob tensor of shape (N, M)
515
515
where N is the no. of tokens and M is the vocab dim.
516
- indices (List[int] ): List of chosen token indices.
516
+ indices (torch.Tensor ): List of chosen token indices.
517
517
518
518
Returns:
519
519
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
520
520
Each element in the returned tensor represents the rank
521
521
of the chosen token in the input logprob tensor.
522
522
"""
523
- vals = x [range (len (x )), indices ]
524
- return (x > vals [:, None ]).long ().sum (1 ) + 1
523
+ vals = x [torch .arange (0 , len (x ), device = x .device , dtype = indices .dtype ),
524
+ indices ]
525
+ return (x > vals [:, None ]).long ().sum (1 ).add_ (1 )
525
526
526
527
527
528
def _get_logprobs (
@@ -561,12 +562,21 @@ def _get_logprobs(
561
562
sample_idx += num_parent_seqs
562
563
assert sample_idx == logprobs .size (0 )
563
564
565
+ batched_logprobs_query_seq_indices_gpu = torch .tensor (
566
+ batched_logprobs_query_seq_indices , device = logprobs .device )
567
+ batched_logprobs_query_token_indices_gpu = torch .tensor (
568
+ batched_logprobs_query_token_indices , device = logprobs .device )
569
+
564
570
# Batched query for logprobs of selected token
565
571
batched_logprobs_query_result = logprobs [[
566
- batched_logprobs_query_seq_indices ,
567
- batched_logprobs_query_token_indices
572
+ batched_logprobs_query_seq_indices_gpu ,
573
+ batched_logprobs_query_token_indices_gpu
568
574
]]
569
575
576
+ batched_ranks_query_result = _get_ranks (
577
+ logprobs [batched_logprobs_query_seq_indices_gpu ],
578
+ batched_logprobs_query_token_indices_gpu )
579
+
570
580
# Batched query for logprobs of topk tokens
571
581
if largest_num_logprobs > 0 :
572
582
top_logprobs , top_token_ids = torch .topk (logprobs ,
@@ -578,10 +588,7 @@ def _get_logprobs(
578
588
top_logprobs , top_token_ids = None , None
579
589
580
590
batched_logprobs_query_result = batched_logprobs_query_result .cpu ()
581
-
582
- batched_ranks_query_result = _get_ranks (
583
- logprobs [batched_logprobs_query_seq_indices ],
584
- batched_logprobs_query_token_indices )
591
+ batched_ranks_query_result = batched_ranks_query_result .cpu ()
585
592
586
593
# Gather results
587
594
result_prompt_logprobs : List [Optional [PromptLogprobs ]] = []
0 commit comments