Skip to content

Commit 708e6c1

Browse files
authored
[FIX] Fix class naming (#1803)
1 parent b943890 commit 708e6c1

File tree

3 files changed

+17
-17
lines changed

3 files changed

+17
-17
lines changed

vllm/engine/llm_engine.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from vllm.outputs import RequestOutput
1313
from vllm.sampling_params import SamplingParams
1414
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
15-
SequenceGroupMetadata, SequenceGroupOutputs,
16-
SequenceOutputs, SequenceStatus)
15+
SequenceGroupMetadata, SequenceGroupOutput,
16+
SequenceOutput, SequenceStatus)
1717
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
1818
get_tokenizer)
1919
from vllm.utils import Counter
@@ -363,7 +363,7 @@ def _check_beam_search_early_stopping(
363363
return current_worst_score >= highest_attainable_score
364364

365365
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
366-
outputs: SequenceGroupOutputs) -> None:
366+
outputs: SequenceGroupOutput) -> None:
367367
# Process prompt logprobs
368368
prompt_logprobs = outputs.prompt_logprobs
369369
if prompt_logprobs is not None:
@@ -384,7 +384,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
384384

385385
# Process the child samples for each parent sequence
386386
for parent in parent_seqs:
387-
child_samples: List[SequenceOutputs] = parent_child_dict[
387+
child_samples: List[SequenceOutput] = parent_child_dict[
388388
parent.seq_id]
389389
if len(child_samples) == 0:
390390
# This parent sequence has no children samples. Remove

vllm/model_executor/layers/sampler.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
tensor_model_parallel_all_gather)
1010
from vllm.sampling_params import SamplingParams, SamplingType
1111
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
12-
SequenceData, SequenceGroupOutputs, SequenceOutputs)
12+
SequenceData, SequenceGroupOutput, SequenceOutput)
1313

1414
_SAMPLING_EPS = 1e-5
1515

@@ -641,7 +641,7 @@ def _build_sampler_output(
641641
next_token_ids,
642642
group_sample_logprobs):
643643
seq_outputs.append(
644-
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
644+
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
645645
sampler_output.append(
646-
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
646+
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
647647
return sampler_output

vllm/sequence.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def __init__(
352352
self.block_tables = block_tables
353353

354354

355-
class SequenceOutputs:
355+
class SequenceOutput:
356356
"""The model output associated with a sequence.
357357
358358
Args:
@@ -374,40 +374,40 @@ def __init__(
374374
self.logprobs = logprobs
375375

376376
def __repr__(self) -> str:
377-
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
377+
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
378378
f"output_token={self.output_token}, "
379379
f"logprobs={self.logprobs})")
380380

381381
def __eq__(self, other: object) -> bool:
382-
if not isinstance(other, SequenceOutputs):
382+
if not isinstance(other, SequenceOutput):
383383
raise NotImplementedError()
384384
return (self.parent_seq_id == other.parent_seq_id
385385
and self.output_token == other.output_token
386386
and self.logprobs == other.logprobs)
387387

388388

389-
class SequenceGroupOutputs:
390-
"""The model outputs associated with a sequence group."""
389+
class SequenceGroupOutput:
390+
"""The model output associated with a sequence group."""
391391

392392
def __init__(
393393
self,
394-
samples: List[SequenceOutputs],
394+
samples: List[SequenceOutput],
395395
prompt_logprobs: Optional[PromptLogprobs],
396396
) -> None:
397397
self.samples = samples
398398
self.prompt_logprobs = prompt_logprobs
399399

400400
def __repr__(self) -> str:
401-
return (f"SequenceGroupOutputs(samples={self.samples}, "
401+
return (f"SequenceGroupOutput(samples={self.samples}, "
402402
f"prompt_logprobs={self.prompt_logprobs})")
403403

404404
def __eq__(self, other: object) -> bool:
405-
if not isinstance(other, SequenceGroupOutputs):
405+
if not isinstance(other, SequenceGroupOutput):
406406
raise NotImplementedError()
407407
return (self.samples == other.samples
408408
and self.prompt_logprobs == other.prompt_logprobs)
409409

410410

411-
# For each sequence group, we generate a list of SequenceOutputs object,
411+
# For each sequence group, we generate a list of SequenceOutput object,
412412
# each of which contains one possible candidate for the next token.
413-
SamplerOutput = List[SequenceGroupOutputs]
413+
SamplerOutput = List[SequenceGroupOutput]

0 commit comments

Comments
 (0)