[Sampler] Vectorized sampler #820

Closed
wants to merge 33 commits
33 commits
afb9a6e
Update benchmark code
Yard1 Aug 19, 2023
349597d
WIP
Yard1 Aug 20, 2023
ed5af82
Fix benchmark
Yard1 Aug 20, 2023
b99d749
WIP
Yard1 Aug 20, 2023
e0f47fe
Beam search
Yard1 Aug 21, 2023
c2d365d
WIP
Yard1 Aug 22, 2023
caeaad3
Add comments
Yard1 Aug 22, 2023
e1f1e4f
WIP
Yard1 Aug 23, 2023
2eeac87
WIP
Yard1 Aug 23, 2023
8d81f87
Update vllm/model_executor/layers/sampler.py
Yard1 Aug 23, 2023
82458d7
Update vllm/model_executor/layers/sampler.py
Yard1 Aug 23, 2023
0bca9f1
Fix assert
Yard1 Aug 23, 2023
3aaa397
Lint
Yard1 Aug 23, 2023
fae01a3
Set replacement=True in multinomial
Yard1 Aug 25, 2023
35004ea
Merge branch 'main' into sampler_speedup
Yard1 Aug 25, 2023
7b95a7b
WIP
Yard1 Sep 1, 2023
68632bc
Merge branch 'upstream_main' into sampler_speedup
Yard1 Sep 1, 2023
98c0b15
Merge branch 'upstream_main' into sampler_speedup
Yard1 Sep 5, 2023
755040f
Apply feedback
Yard1 Sep 5, 2023
f8d37cc
Lint
Yard1 Sep 5, 2023
a995f06
Apply feedback from code review
Yard1 Sep 6, 2023
6a2b1b5
Merge branch 'upstream_main' into sampler_speedup
Yard1 Sep 6, 2023
e8935e7
Fix all beam case
Yard1 Sep 8, 2023
1b69567
Add test, fix issues
Yard1 Sep 8, 2023
3e51e1f
Lint
Yard1 Sep 8, 2023
a64d8b7
Merge branch 'upstream_main' into sampler_speedup
Yard1 Sep 12, 2023
5779d2d
Tweak
Yard1 Sep 12, 2023
db133d4
Fix
Yard1 Sep 12, 2023
7871574
Update vllm/worker/worker.py
Yard1 Sep 12, 2023
5a2615c
Update tests/samplers/test_sampler.py
Yard1 Sep 12, 2023
471a8bf
Update vllm/model_executor/layers/sampler.py
Yard1 Sep 12, 2023
f7d4c82
Nits
Yard1 Sep 12, 2023
d39834d
Merge branch 'sampler_speedup' of https://github.com/Yard1/vllm into …
Yard1 Sep 12, 2023
76 changes: 44 additions & 32 deletions benchmarks/benchmark_throughput.py
@@ -22,15 +22,10 @@ def sample_requests(
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
@@ -78,15 +73,19 @@ def run_vllm(
)

# Add the requests to the engine.
do_sample = False
Collaborator:
A dumb question: What is this variable used for?

Collaborator Author:
Allows us to alternate between greedy and stochastic sampling for requests.

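A minimal standalone sketch of the alternation this toggle produces, reusing the parameter values from the diff below (it assumes `SamplingParams` is imported from the top-level `vllm` package, as in the benchmark script):

```python
from vllm import SamplingParams

# Alternate between greedy decoding (temperature=0.0) and stochastic
# sampling (temperature=0.1, top_p=0.9) on successive requests so that
# the benchmark exercises both sampling paths.
do_sample = False
params_per_request = []
for _ in range(8):  # stand-in for the benchmark's request loop
    params_per_request.append(
        SamplingParams(
            temperature=0.1 if do_sample else 0.0,
            top_p=0.9 if do_sample else 1.0,
        ))
    do_sample = not do_sample  # flip for the next request
```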
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
temperature=0.0 if use_beam_search else
(0.0 if not do_sample else 0.1),
top_p=0.9 if do_sample else 1.0,
presence_penalty=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
)
do_sample = not do_sample
# FIXME(woosuk): Do not use internal method.
llm._add_request(
prompt=prompt,
@@ -96,8 +95,11 @@

start = time.time()
# FIXME(woosuk): Do not use internal method.
llm._run_engine(use_tqdm=True)
outputs = llm._run_engine(use_tqdm=True)
end = time.time()
with open("output.txt", "w") as f:
for output in outputs:
f.write(output.__repr__() + "\n")
Comment on lines +98 to +102

Collaborator:
I guess this is to check if vLLM generated valid outputs, right? If so, I believe this ought to be done by our test code, rather than in the benchmarking code.

Collaborator Author:
Yeah, I will revert the changes to the benchmarking script later.

return end - start


@@ -111,8 +113,8 @@ def run_hf(
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
@@ -132,13 +134,14 @@ def run_hf(
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) + max(
max_output_len, next_output_len)) <= 2048:
if (max(max_prompt_len, next_prompt_len) +
max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue

# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
@@ -165,44 +168,53 @@ def main(args: argparse.Namespace):
random.seed(args.seed)

# Sample the requests.
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.tensor_parallel_size, args.seed, args.n,
args.use_beam_search, args.trust_remote_code)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(
requests, args.model, tokenizer, args.n, args.use_beam_search,
args.hf_max_batch_size, args.trust_remote_code)
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(
prompt_len + output_len
for _, prompt_len, output_len in requests
)
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf"],
default="vllm")
parser.add_argument("--dataset", type=str, required=True,
parser.add_argument("--dataset",
type=str,
required=True,
help="Path to the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", type=int, default=1,
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None,
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
186 changes: 186 additions & 0 deletions tests/samplers/test_sampler.py
@@ -0,0 +1,186 @@
import random
from typing import Tuple
from unittest.mock import patch

import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size)
self.fake_logits = fake_logits

def _get_logits(self, *args, **kwargs) -> torch.Tensor:
return self.fake_logits

def forward(self, *args, **kwargs):
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x):
return super().forward(*args, **kwargs)


def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device="cuda",
dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
worker = Worker(None, None, None)
worker.block_size = 16
return input_tensor, fake_logits, sampler, worker


def _test_sampler_all_greedy(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0, ),
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == expected[i].item()


def _test_sampler_all_random(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

for i in range(batch_size):
fake_logits[i, i] = 1e2

seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
),
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == i


def _test_sampler_all_beam(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
),
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.


def _test_sampler_mixed(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

seq_group_metadata_list = []
expected_tokens = []
for i in range(batch_size):
n = 1
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
elif sampling_type == 1:
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
top_p=min(random.random() + 0.1, 1),
top_k=random.randint(0, 10) or -1,
n=n,
presence_penalty=random.randint(0, 1),
)
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected_tokens.append(i + idx)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output:
assert nth_output.output_token in expected_tokens


def test_sampler():
for i in range(128):
print(f"Testing seed {i}...")
_test_sampler_all_greedy(i)
_test_sampler_all_random(i)
_test_sampler_all_beam(i)
_test_sampler_mixed(i)
2 changes: 2 additions & 0 deletions vllm/model_executor/input_metadata.py
@@ -29,6 +29,7 @@ def __init__(
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
sampling_type_indices: torch.Tensor,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
@@ -37,6 +38,7 @@
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.sampling_type_indices = sampling_type_indices

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
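
To illustrate why a per-sequence `sampling_type_indices` tensor helps, here is a hypothetical sketch (not the code from this PR) of bucketing rows of a logits tensor by sampling type so that all greedy rows and all random rows are each handled in a single vectorized call; the `GREEDY`/`RANDOM`/`BEAM` codes and the function name are illustrative assumptions:

```python
import torch

# Illustrative type codes; the PR's actual encoding may differ.
GREEDY, RANDOM, BEAM = 0, 1, 2


def sample_vectorized(logits: torch.Tensor,
                      sampling_type_indices: torch.Tensor) -> torch.Tensor:
    """Pick one next token per row, batching rows that share a sampling type."""
    next_tokens = torch.empty(logits.shape[0],
                              dtype=torch.long,
                              device=logits.device)

    greedy = sampling_type_indices == GREEDY
    if greedy.any():
        # All greedy rows are argmax-decoded in one call.
        next_tokens[greedy] = logits[greedy].argmax(dim=-1)

    random_rows = sampling_type_indices == RANDOM
    if random_rows.any():
        # All random rows share one batched multinomial draw instead of
        # a per-sequence Python loop.
        probs = torch.softmax(logits[random_rows], dim=-1)
        next_tokens[random_rows] = torch.multinomial(
            probs, num_samples=1, replacement=True).squeeze(-1)

    # Beam-search rows would instead take top-k log-probs per beam; omitted.
    return next_tokens
```

For example, with `sampling_type_indices = torch.tensor([0, 1, 0])`, rows 0 and 2 are decoded by a single argmax call and row 1 by a single multinomial call.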