Skip to content

Commit ff03178

Browse files
andylolu2Isotr0py
authored andcommitted
[Bugfix] Multi-sequence broken (vllm-project#11898)
Signed-off-by: Andy Lo <[email protected]> Signed-off-by: Isotr0py <[email protected]>
1 parent ede847b commit ff03178

File tree

3 files changed

+59
-39
lines changed

3 files changed

+59
-39
lines changed

tests/samplers/test_seeded_generate.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
3131

3232
sampling_params = SamplingParams(
3333
# Parameters to ensure sufficient randomness
34-
temperature=2.0,
34+
temperature=3.0,
3535
top_p=min(random.random() + 0.3, 1),
3636
top_k=random.randint(5, 20),
3737
n=random.randint(1, 10),
@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
7575
# verify requests with the same seed match
7676
assert outputs[1] == outputs[4]
7777
assert outputs[2] == outputs[5]
78+
79+
# verify generations within the same parallel sampling group differ
80+
for output in outputs:
81+
for sub_output_a, sub_output_b in combinations(output, 2):
82+
assert sub_output_a != sub_output_b

vllm/outputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ def from_seq_group(
172172
if seq_group.request_id in seq_id_to_seq_group:
173173
group: SequenceGroupBase = seq_id_to_seq_group[
174174
seq_group.request_id]
175+
assembled_seq_group = group.maybe_assemble_group(seq_group)
175176
if finished:
176177
group.finish_seq(seq_group)
177-
assembled_seq_group = group.maybe_assemble_group(seq_group)
178178
if assembled_seq_group is None:
179179
return None
180180
return cls.from_seq_group(assembled_seq_group, use_cache,

vllm/sequence.py

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,9 @@ def set_finished_time(self, time: Optional[float]) -> None:
815815
def get_max_num_running_seqs(self) -> int:
816816
"""The maximum number of sequences running in parallel in the remaining
817817
lifetime of the request."""
818-
return 0 if self.first_seq.is_finished() else 1
818+
if self.is_single_seq:
819+
return 0 if self.first_seq.is_finished() else 1
820+
return self.num_seqs() - self.num_finished_seqs()
819821

820822
def get_seqs(
821823
self,
@@ -824,7 +826,10 @@ def get_seqs(
824826
if status is None:
825827
return self.seqs
826828

827-
return self.seqs if self.first_seq.status == status else []
829+
if self.is_single_seq:
830+
return self.seqs if self.first_seq.status == status else []
831+
832+
return [seq for seq in self.seqs if seq.status == status]
828833

829834
def is_encoder_decoder(self) -> bool:
830835
return self.encoder_seq is not None
@@ -833,19 +838,22 @@ def get_encoder_seq(self) -> Optional[Sequence]:
833838
return self.encoder_seq
834839

835840
def get_finished_seqs(self) -> List[Sequence]:
836-
return self.seqs if self.first_seq.is_finished() else []
841+
if self.is_single_seq:
842+
return self.seqs if self.first_seq.is_finished() else []
843+
844+
return [seq for seq in self.seqs if seq.is_finished()]
837845

838846
def update_num_computed_tokens(self, num_new_computed_tokens: int):
839847
"""Update number of tokens computed so far."""
840-
seq = self.first_seq
841-
if not seq.is_finished():
842-
seq.data.update_num_computed_tokens(num_new_computed_tokens)
848+
for seq in self.seqs:
849+
if not seq.is_finished():
850+
seq.data.update_num_computed_tokens(num_new_computed_tokens)
843851

844852
def get_num_uncomputed_tokens(self) -> int:
845853
num_uncomputed_tokens = 0
846-
seq = self.first_seq
847-
if not seq.is_finished():
848-
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
854+
for seq in self.seqs:
855+
if not seq.is_finished():
856+
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
849857
return num_uncomputed_tokens
850858

851859
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
@@ -860,10 +868,14 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
860868
return len(self.get_seqs(status))
861869

862870
def num_finished_seqs(self) -> int:
863-
return 1 if self.first_seq.is_finished() else 0
871+
if self.is_single_seq:
872+
return 1 if self.seqs[0].is_finished() else 0
873+
return len(self.get_finished_seqs())
864874

865875
def is_finished(self) -> bool:
866-
return self.first_seq.is_finished()
876+
if self.is_single_seq:
877+
return self.first_seq.is_finished()
878+
return all(seq.is_finished() for seq in self.seqs)
867879

868880
def is_prefill(self) -> bool:
869881
return self.first_seq.is_prefill()
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
13911403
@staticmethod
13921404
def add_request(request_id: str, engine, params, **kwargs):
13931405
original_params = params
1394-
params = original_params.clone()
1395-
params.n = 1
13961406
group = ParallelSampleSequenceGroup(request_id)
13971407
seqs = []
13981408
for i in range(original_params.n):
13991409
request_id_i = f"{request_id}_parallel_sample_{i}"
14001410
group.seq_id_to_index[request_id_i] = i
1411+
params = copy.deepcopy(original_params)
1412+
params.n = 1
1413+
if params.seed is not None:
1414+
params.seed += i
14011415
seq_group = engine._add_processed_request(
14021416
request_id_i,
14031417
params=params,
@@ -1432,33 +1446,34 @@ def maybe_assemble_group(
14321446
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
14331447

14341448
# in the streaming mode, we will return the assembled sequence
1435-
# for the first sequence, and then return None for the rest of
1436-
# sequences
1449+
# for the first remaining sequence, and then return None for the
1450+
# rest of sequences
14371451
if self.streaming:
1438-
if self.seq_id_to_index[seq_group.request_id] == 0:
1452+
first_remaining_id = next(iter(self.to_be_finished))
1453+
if seq_group.request_id == first_remaining_id:
14391454
return self.assembled_seq_group
14401455
return None
14411456

14421457
# in the non-streaming mode, we will return the assembled sequence
1443-
# once after all sequences finish, and then return None for the
1458+
# when the last sequences finishes, and then return None for the
14441459
# rest of the time
1445-
1446-
if len(self.to_be_finished) > 0:
1447-
return None
1448-
1449-
assert self.assembled_seq_group is not None
1450-
params = self.assembled_seq_group.sampling_params
1451-
assert isinstance(params, SamplingParams)
1452-
if not self.output_produced:
1453-
self.output_produced = True
1454-
if params._real_n is not None:
1455-
# Get the top-n sequences.
1456-
n = params._real_n or params.n
1457-
seqs = self.assembled_seq_group.seqs
1458-
sorting_key = lambda seq: seq.get_cumulative_logprob()
1459-
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
1460-
top_n_seqs = sorted_seqs[:n]
1461-
self.assembled_seq_group.seqs = top_n_seqs
1462-
return self.assembled_seq_group
1463-
if self.output_produced:
1464-
return None
1460+
if (len(self.to_be_finished) == 1
1461+
and seq_group.request_id in self.to_be_finished
1462+
and seq_group.is_finished()):
1463+
assert self.assembled_seq_group is not None
1464+
params = self.assembled_seq_group.sampling_params
1465+
assert isinstance(params, SamplingParams)
1466+
if not self.output_produced:
1467+
self.output_produced = True
1468+
if params._real_n is not None:
1469+
# Get the top-n sequences.
1470+
n = params._real_n or params.n
1471+
seqs = self.assembled_seq_group.seqs
1472+
sorting_key = lambda seq: seq.get_cumulative_logprob()
1473+
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
1474+
top_n_seqs = sorted_seqs[:n]
1475+
self.assembled_seq_group.seqs = top_n_seqs
1476+
return self.assembled_seq_group
1477+
if self.output_produced:
1478+
return None
1479+
return None

0 commit comments

Comments
 (0)