[Sampler] Vectorized sampler #820
@@ -22,15 +22,10 @@ def sample_requests(
     with open(dataset_path) as f:
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
-    dataset = [
-        data for data in dataset
-        if len(data["conversations"]) >= 2
-    ]
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]

     # Tokenize the prompts and completions.
     prompts = [prompt for prompt, _ in dataset]

@@ -78,15 +73,19 @@ def run_vllm(
     )

     # Add the requests to the engine.
+    do_sample = False
     for prompt, _, output_len in requests:
         sampling_params = SamplingParams(
             n=n,
-            temperature=0.0 if use_beam_search else 1.0,
-            top_p=1.0,
+            temperature=0.0 if use_beam_search else
+            (0.0 if not do_sample else 0.1),
+            top_p=0.9 if do_sample else 1.0,
+            presence_penalty=1.0,
             use_beam_search=use_beam_search,
             ignore_eos=True,
             max_tokens=output_len,
         )
+        do_sample = not do_sample
         # FIXME(woosuk): Do not use internal method.
         llm._add_request(
             prompt=prompt,

@@ -96,8 +95,11 @@ def run_vllm(

     start = time.time()
     # FIXME(woosuk): Do use internal method.
-    llm._run_engine(use_tqdm=True)
+    outputs = llm._run_engine(use_tqdm=True)
     end = time.time()
+    with open("output.txt", "w") as f:
+        for output in outputs:
+            f.write(output.__repr__() + "\n")
Comment on lines +98 to +102:

Reviewer: I guess this is to check if vLLM generated valid outputs, right? If so, I believe this ought to be done by our test code, rather than in the benchmarking code.

Author: Yeah, I will revert the changes to the benchmarking script later.
     return end - start


@@ -111,8 +113,8 @@ def run_hf(
     trust_remote_code: bool,
 ) -> float:
     assert not use_beam_search
-    llm = AutoModelForCausalLM.from_pretrained(model,
-        torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
+    llm = AutoModelForCausalLM.from_pretrained(
+        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
     if llm.config.model_type == "llama":
         # To enable padding in the HF backend.
         tokenizer.pad_token = tokenizer.eos_token

@@ -132,13 +134,14 @@ def run_hf(
         if len(batch) < max_batch_size and i != len(requests) - 1:
             # Check if we can add more requests to the batch.
             _, next_prompt_len, next_output_len = requests[i + 1]
-            if (max(max_prompt_len, next_prompt_len) + max(
-                max_output_len, next_output_len)) <= 2048:
+            if (max(max_prompt_len, next_prompt_len) +
+                    max(max_output_len, next_output_len)) <= 2048:
                 # We can add more requests to the batch.
                 continue

         # Generate the sequences.
-        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
+        input_ids = tokenizer(batch, return_tensors="pt",
+                              padding=True).input_ids
         llm_outputs = llm.generate(
             input_ids=input_ids.cuda(),
             do_sample=not use_beam_search,

@@ -165,44 +168,53 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)

     # Sample the requests.
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    tokenizer = get_tokenizer(args.tokenizer,
+                              trust_remote_code=args.trust_remote_code)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     if args.backend == "vllm":
-        elapsed_time = run_vllm(
-            requests, args.model, args.tokenizer, args.tensor_parallel_size,
-            args.seed, args.n, args.use_beam_search, args.trust_remote_code)
+        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
+                                args.tensor_parallel_size, args.seed, args.n,
+                                args.use_beam_search, args.trust_remote_code)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
-        elapsed_time = run_hf(
-            requests, args.model, tokenizer, args.n, args.use_beam_search,
-            args.hf_max_batch_size, args.trust_remote_code)
+        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
+                              args.use_beam_search, args.hf_max_batch_size,
+                              args.trust_remote_code)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
-    total_num_tokens = sum(
-        prompt_len + output_len
-        for _, prompt_len, output_len in requests
-    )
+    total_num_tokens = sum(prompt_len + output_len
+                           for _, prompt_len, output_len in requests)
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} tokens/s")


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
-    parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
+    parser.add_argument("--backend",
+                        type=str,
+                        choices=["vllm", "hf"],
                         default="vllm")
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
-    parser.add_argument("--n", type=int, default=1,
+    parser.add_argument("--n",
+                        type=int,
+                        default=1,
                         help="Number of generated sequences per prompt.")
     parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                         help="Number of prompts to process.")
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--hf-max-batch-size", type=int, default=None,
+    parser.add_argument("--hf-max-batch-size",
+                        type=int,
+                        default=None,
                         help="Maximum batch size for HF backend.")
     parser.add_argument('--trust-remote-code',
                         action='store_true',
New test file:

@@ -0,0 +1,186 @@
import random
from typing import Tuple
from unittest.mock import patch

import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def _get_logits(self, *args, **kwargs) -> torch.Tensor:
        return self.fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x):
            return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(32000, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker


def _test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0, ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == expected[i].item()


def _test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == i


def _test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
    # no assertion here as I am not sure how to determine whether
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.


def _test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output:
            assert nth_output.output_token in expected_tokens


def test_sampler():
    for i in range(128):
        print(f"Testing seed {i}...")
        _test_sampler_all_greedy(i)
        _test_sampler_all_random(i)
        _test_sampler_all_beam(i)
        _test_sampler_mixed(i)
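For intuition on the fake-logits trick the tests rely on: spiking a single logit to 1e2 while every other logit sits at 1e-2 makes that token's post-softmax probability essentially 1, which is why even the temperature=1.0 random-sampling case can assert an exact token id. A quick standalone check of that assumption in plain PyTorch, independent of vLLM:

import torch

vocab_size = 32000
logits = torch.full((vocab_size,), 1e-2)
logits[123] = 1e2  # spike one token, as the tests do with fake_logits[i, i]
probs = torch.softmax(logits, dim=-1)
print(probs[123].item())  # ~1.0, so sampled tokens should match the spiked index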
Reviewer: A dumb question: What is this variable used for?

Author: Allows us to alternate between greedy and stochastic sampling for requests.
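For readers skimming the diff, the variable in question appears to be the do_sample flag added to run_vllm. A minimal sketch of that alternation in isolation, using vLLM's SamplingParams with the same 0.1 temperature and 0.9 top_p values as the diff (build_sampling_params is a hypothetical helper, not part of the PR):

from vllm import SamplingParams


def build_sampling_params(requests):
    # Alternate between greedy decoding (even-indexed requests) and
    # stochastic sampling (odd-indexed requests), as in the benchmark diff.
    params = []
    do_sample = False
    for _, _, output_len in requests:
        params.append(
            SamplingParams(
                n=1,
                temperature=0.1 if do_sample else 0.0,
                top_p=0.9 if do_sample else 1.0,
                presence_penalty=1.0,
                ignore_eos=True,
                max_tokens=output_len,
            ))
        do_sample = not do_sample
    return params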