
Commit 971936e

shihaobai and shihaobai authored
Benchclient (#740)
Co-authored-by: shihaobai <[email protected]>
1 parent 5c28b33 commit 971936e

File tree

3 files changed: +264 -287 lines changed

Diff for: docs/EN/source/getting_started/quickstart.rst

+7
@@ -81,4 +81,11 @@ In a new terminal, use the following command to test the model service:
$ }'

For a DeepSeek-R1 benchmark, use the following command to test the model service:

.. code-block:: console

    $ cd test
    $ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream
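The script can also persist its metrics: passing --dump_file writes the summary as JSON, and re-running with the same (now existing) file pretty-prints it instead of benchmarking again (see test/benchmark_client.py below). A hypothetical follow-up, with results.json as a placeholder file name:

.. code-block:: console

    $ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --dump_file results.json
    $ python benchmark_client.py --dump_file results.json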

Diff for: test/benchmark_client.py

+257
@@ -0,0 +1,257 @@
import os
import argparse
import requests
import json
import time
import random
import numpy as np
from tqdm import tqdm
from typing import Union, List, Tuple
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast


def seed_all(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)


def get_tokenizer(
    tokenizer_name: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    return tokenizer


def get_output_length(input_num: int, output_len: int) -> List[int]:
    # Sample per-request output lengths from a Gaussian centered on output_len,
    # clamped to [2, 2 * output_len].
    min_len, max_len = 2, output_len * 2
    mean = (min_len + max_len) * 0.5
    std = mean
    output_lens = []
    for _ in range(input_num):
        cur_len = round(random.gauss(mean, std))
        cur_len = max(min_len, min(cur_len, max_len))
        output_lens.append(cur_len)
    return output_lens


def gen_random_input_text(input_len, tokenizer) -> str:
    # Decode input_len random token ids into prompt text.
    random_ids = [random.randint(512, 8192) for _ in range(input_len)]
    random_text = tokenizer.decode(random_ids)
    return random_text


def gen_random_data(
    input_len: int, output_len: int, input_num: int, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Tuple[List[str], List[int], List[int]]:
    prompts = []
    input_lens = []
    output_lens = get_output_length(input_num, output_len)
    for _ in range(input_num):
        prompts.append(gen_random_input_text(input_len, tokenizer))
        input_lens.append(input_len)
    print("Finished generating random data.")
    return prompts, input_lens, output_lens


def post_stream_lightllm(url: str, text_input: str, max_new_tokens: int) -> List[float]:
    data = {
        "inputs": text_input,
        "parameters": {
            "do_sample": False,
            "ignore_eos": True,
            "max_new_tokens": max_new_tokens,
        },
    }
    headers = {"Content-Type": "application/json"}
    used_time = []
    last_time = time.time()
    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
    if response.status_code != 200:
        print(response.json())
    assert response.status_code == 200
    for line in response.iter_lines():
        if line:
            # used_time[0] is the time to first token; later entries are
            # inter-token intervals.
            current_time = time.time()
            used_time.append(current_time - last_time)
            last_time = current_time
    return used_time


# Holder for the model name, filled in main() and read by post_stream_openai.
model_name = []


def post_stream_openai(url: str, text_input: str, max_new_tokens: int) -> List[float]:
    data = {
        "model": model_name[0],
        "prompt": text_input,
        "n": 1,
        "ignore_eos": True,
        "max_tokens": max_new_tokens,
        "stream": True,
    }
    headers = {"Content-Type": "application/json"}
    used_time = []
    last_time = time.time()
    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
    assert response.status_code == 200
    for line in response.iter_lines():
        line = line.strip()
        if not line:
            continue
        line = line.decode("utf-8")
        if not line.startswith("data: "):
            continue
        line = line[len("data: "):]  # strip the SSE prefix
        if line == "[DONE]":
            continue
        data = json.loads(line)
        if not data["choices"][0]["text"]:
            continue
        current_time = time.time()
        used_time.append(current_time - last_time)
        last_time = current_time
    return used_time


def post_stream_triton(url: str, text_input: str, max_new_tokens: int) -> List[float]:
    data = {"text_input": text_input, "max_tokens": max_new_tokens, "stream": True}
    headers = {"Content-Type": "application/json"}
    used_time = []
    last_time = time.time()
    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
    assert response.status_code == 200
    for line in response.iter_lines():
        if line:
            current_time = time.time()
            used_time.append(current_time - last_time)
            last_time = current_time
    return used_time


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://127.0.0.1:8000/generate_stream")
    parser.add_argument("--num_clients", type=int, default=100)
    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--input_num", type=int, default=2000)
    parser.add_argument("--input_len", type=int, default=1024)
    parser.add_argument("--output_len", type=int, default=128)
    parser.add_argument("--server_api", type=str, default="lightllm")
    parser.add_argument("--dump_file", type=str, default="")
    parser.add_argument("--seed", type=int, default=0)

    args = parser.parse_args()
    if args.dump_file and os.path.exists(args.dump_file):
        # Read and print the previously dumped JSON results, then exit.
        with open(args.dump_file, "r") as json_file:
            content = json.load(json_file)
            print(json.dumps(content, indent=4))
        return

    assert args.tokenizer_path is not None
    model_name.append(args.tokenizer_path)
    seed_all(args.seed)
    url = args.url
    tokenizer = get_tokenizer(args.tokenizer_path)
    prompts, input_lens, max_new_tokens = gen_random_data(args.input_len, args.output_len, args.input_num, tokenizer)

    percentiles = [25, 50, 75, 90, 95, 99, 100]
    if args.server_api == "lightllm":
        post_stream = post_stream_lightllm
    elif args.server_api == "openai":
        post_stream = post_stream_openai
    elif args.server_api == "triton":
        post_stream = post_stream_triton
    else:
        raise Exception(f"Unsupported server_api: {args.server_api}")

    dump_dict = {}
    dump_dict["backend"] = args.server_api
    dump_dict["clients"] = args.num_clients
    start_time = time.time()

    with ThreadPoolExecutor(max_workers=args.num_clients) as executor:
        results = list(
            tqdm(
                executor.map(lambda p: post_stream(url, p[0], p[1]), zip(prompts, max_new_tokens)),
                total=len(prompts),
                desc="Running tests",
            )
        )
    end_time = time.time()
    first_token_time = []
    decode_token_time = []
    request_time = []
    final_output_lens = []
    valid_num = 0
    for result in results:
        if len(result) > 1:  # only count requests that decoded at least two tokens
            first_token_time.append(result[0])
            decode_token_time.append(sum(result[1:]) / len(result[1:]))
            request_time.append(sum(result))
            final_output_lens.append(len(result))
            valid_num += 1

    print(
        f"\n\nvalid num = {valid_num}; all data num = {len(results)}; valid ratio = {valid_num * 1.0 / len(results)}\n"
    )
    print(f"Total QPS: {valid_num / (end_time - start_time)}")
    print(f"Avg Input Length: {sum(input_lens) / len(input_lens)}")
    print(f"Avg Output Length: {sum(final_output_lens) / len(final_output_lens)}")
    print(f"Total Throughput: {(sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)} token/s")
    print(f"Input Throughput: {sum(input_lens) / (end_time - start_time)} token/s")
    print(f"Output Throughput: {sum(final_output_lens) / (end_time - start_time)} token/s")
    print("-" * 10)
    dump_dict["request_num"] = valid_num
    dump_dict["Total QPS"] = valid_num / (end_time - start_time)
    dump_dict["Avg Input Length"] = sum(input_lens) / len(input_lens)
    dump_dict["Avg Output Length"] = sum(final_output_lens) / len(final_output_lens)
    dump_dict["Total Throughput"] = (sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)
    dump_dict["Input Throughput"] = sum(input_lens) / (end_time - start_time)
    dump_dict["Output Throughput"] = sum(final_output_lens) / (end_time - start_time)

    values = np.percentile(request_time, percentiles)
    request_time_dict = {}
    for percentile, value in zip(percentiles, values):
        print(f"request_time P{percentile}: {value:.6f}s")
        request_time_dict[f"P{percentile}"] = value
    dump_dict["request_time"] = request_time_dict
    print("-" * 10)

    first_token_time_dict = {}
    values = np.percentile(first_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"first_token_time P{percentile}: {value:.6f}s")
        first_token_time_dict[f"P{percentile}"] = value
    dump_dict["first_token_time_dict"] = first_token_time_dict
    print("-" * 10)

    decode_token_time_dict = {}
    values = np.percentile(decode_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"decode_token_time P{percentile}: {value * 1000:.6f}ms")
        decode_token_time_dict[f"P{percentile}"] = value * 1000
    dump_dict["decode_token_time_dict"] = decode_token_time_dict
    print(dump_dict)

    if args.dump_file:
        with open(args.dump_file, "w") as json_file:
            json.dump(dump_dict, json_file, indent=4)
        print(f"Results have been written to {args.dump_file}")


if __name__ == "__main__":
    main()
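
For a quick single-request sanity check outside the full sweep, the streaming helper can be imported directly; a minimal sketch, assuming a lightllm server is reachable at the placeholder URL:

# Sketch: one request against the lightllm stream endpoint; times[0] is the
# time to first token, later entries are inter-token gaps (URL and prompt
# are placeholders).
from benchmark_client import post_stream_lightllm

times = post_stream_lightllm("http://127.0.0.1:8000/generate_stream", "Hello, world", 16)
print(f"time to first token: {times[0]:.3f}s")
if len(times) > 1:
    print(f"mean inter-token gap: {sum(times[1:]) / len(times[1:]) * 1000:.1f}ms")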
