@@ -815,7 +815,9 @@ def set_finished_time(self, time: Optional[float]) -> None:
815
815
def get_max_num_running_seqs (self ) -> int :
816
816
"""The maximum number of sequences running in parallel in the remaining
817
817
lifetime of the request."""
818
- return 0 if self .first_seq .is_finished () else 1
818
+ if self .is_single_seq :
819
+ return 0 if self .first_seq .is_finished () else 1
820
+ return self .num_seqs () - self .num_finished_seqs ()
819
821
820
822
def get_seqs (
821
823
self ,
@@ -824,7 +826,10 @@ def get_seqs(
824
826
if status is None :
825
827
return self .seqs
826
828
827
- return self .seqs if self .first_seq .status == status else []
829
+ if self .is_single_seq :
830
+ return self .seqs if self .first_seq .status == status else []
831
+
832
+ return [seq for seq in self .seqs if seq .status == status ]
828
833
829
834
def is_encoder_decoder (self ) -> bool :
830
835
return self .encoder_seq is not None
@@ -833,19 +838,22 @@ def get_encoder_seq(self) -> Optional[Sequence]:
833
838
return self .encoder_seq
834
839
835
840
def get_finished_seqs (self ) -> List [Sequence ]:
836
- return self .seqs if self .first_seq .is_finished () else []
841
+ if self .is_single_seq :
842
+ return self .seqs if self .first_seq .is_finished () else []
843
+
844
+ return [seq for seq in self .seqs if seq .is_finished ()]
837
845
838
846
def update_num_computed_tokens (self , num_new_computed_tokens : int ):
839
847
"""Update number of tokens computed so far."""
840
- seq = self .first_seq
841
- if not seq .is_finished ():
842
- seq .data .update_num_computed_tokens (num_new_computed_tokens )
848
+ for seq in self .seqs :
849
+ if not seq .is_finished ():
850
+ seq .data .update_num_computed_tokens (num_new_computed_tokens )
843
851
844
852
def get_num_uncomputed_tokens (self ) -> int :
845
853
num_uncomputed_tokens = 0
846
- seq = self .first_seq
847
- if not seq .is_finished ():
848
- num_uncomputed_tokens += seq .data .get_num_uncomputed_tokens ()
854
+ for seq in self .seqs :
855
+ if not seq .is_finished ():
856
+ num_uncomputed_tokens += seq .data .get_num_uncomputed_tokens ()
849
857
return num_uncomputed_tokens
850
858
851
859
def num_seqs (self , status : Optional [SequenceStatus ] = None ) -> int :
@@ -860,10 +868,14 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
860
868
return len (self .get_seqs (status ))
861
869
862
870
def num_finished_seqs (self ) -> int :
863
- return 1 if self .first_seq .is_finished () else 0
871
+ if self .is_single_seq :
872
+ return 1 if self .seqs [0 ].is_finished () else 0
873
+ return len (self .get_finished_seqs ())
864
874
865
875
def is_finished (self ) -> bool :
866
- return self .first_seq .is_finished ()
876
+ if self .is_single_seq :
877
+ return self .first_seq .is_finished ()
878
+ return all (seq .is_finished () for seq in self .seqs )
867
879
868
880
def is_prefill (self ) -> bool :
869
881
return self .first_seq .is_prefill ()
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
1391
1403
@staticmethod
1392
1404
def add_request (request_id : str , engine , params , ** kwargs ):
1393
1405
original_params = params
1394
- params = original_params .clone ()
1395
- params .n = 1
1396
1406
group = ParallelSampleSequenceGroup (request_id )
1397
1407
seqs = []
1398
1408
for i in range (original_params .n ):
1399
1409
request_id_i = f"{ request_id } _parallel_sample_{ i } "
1400
1410
group .seq_id_to_index [request_id_i ] = i
1411
+ params = copy .deepcopy (original_params )
1412
+ params .n = 1
1413
+ if params .seed is not None :
1414
+ params .seed += i
1401
1415
seq_group = engine ._add_processed_request (
1402
1416
request_id_i ,
1403
1417
params = params ,
@@ -1432,33 +1446,34 @@ def maybe_assemble_group(
1432
1446
self , seq_group : SequenceGroup ) -> Optional [SequenceGroup ]:
1433
1447
1434
1448
# in the streaming mode, we will return the assembled sequence
1435
- # for the first sequence, and then return None for the rest of
1436
- # sequences
1449
+ # for the first remaining sequence, and then return None for the
1450
+ # rest of sequences
1437
1451
if self .streaming :
1438
- if self .seq_id_to_index [seq_group .request_id ] == 0 :
1452
+ first_remaining_id = next (iter (self .to_be_finished ))
1453
+ if seq_group .request_id == first_remaining_id :
1439
1454
return self .assembled_seq_group
1440
1455
return None
1441
1456
1442
1457
# in the non-streaming mode, we will return the assembled sequence
1443
- # once after all sequences finish , and then return None for the
1458
+ # when the last sequences finishes , and then return None for the
1444
1459
# rest of the time
1445
-
1446
- if len ( self . to_be_finished ) > 0 :
1447
- return None
1448
-
1449
- assert self . assembled_seq_group is not None
1450
- params = self . assembled_seq_group . sampling_params
1451
- assert isinstance ( params , SamplingParams )
1452
- if not self .output_produced :
1453
- self . output_produced = True
1454
- if params . _real_n is not None :
1455
- # Get the top-n sequences.
1456
- n = params . _real_n or params . n
1457
- seqs = self . assembled_seq_group . seqs
1458
- sorting_key = lambda seq : seq . get_cumulative_logprob ( )
1459
- sorted_seqs = sorted ( seqs , key = sorting_key , reverse = True )
1460
- top_n_seqs = sorted_seqs [: n ]
1461
- self .assembled_seq_group . seqs = top_n_seqs
1462
- return self .assembled_seq_group
1463
- if self . output_produced :
1464
- return None
1460
+ if ( len ( self . to_be_finished ) == 1
1461
+ and seq_group . request_id in self . to_be_finished
1462
+ and seq_group . is_finished ()):
1463
+ assert self . assembled_seq_group is not None
1464
+ params = self . assembled_seq_group . sampling_params
1465
+ assert isinstance ( params , SamplingParams )
1466
+ if not self . output_produced :
1467
+ self .output_produced = True
1468
+ if params . _real_n is not None :
1469
+ # Get the top-n sequences.
1470
+ n = params . _real_n or params . n
1471
+ seqs = self . assembled_seq_group . seqs
1472
+ sorting_key = lambda seq : seq . get_cumulative_logprob ()
1473
+ sorted_seqs = sorted ( seqs , key = sorting_key , reverse = True )
1474
+ top_n_seqs = sorted_seqs [: n ]
1475
+ self . assembled_seq_group . seqs = top_n_seqs
1476
+ return self .assembled_seq_group
1477
+ if self .output_produced :
1478
+ return None
1479
+ return None
0 commit comments