import os
import argparse
import requests
import json
import time
import random
import numpy as np
from tqdm import tqdm
from typing import Union, List, Tuple
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast


def seed_all(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)


def get_tokenizer(
    tokenizer_name: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Hugging Face."""

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    return tokenizer


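# Output lengths are sampled from a Gaussian with mean (and std) output_len + 1,
# clamped to [2, 2 * output_len], so individual requests vary in length while
# the average stays close to --output_len.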
def get_output_length(input_num: int, output_len: int) -> List[int]:
    min_len, max_len = 2, output_len * 2
    mean = (min_len + max_len) * 0.5
    std = mean
    output_lens = []
    for _ in range(input_num):
        cur_len = random.gauss(mean, std)
        cur_len = round(cur_len)
        if cur_len < min_len:
            cur_len = min_len
        elif cur_len > max_len:
            cur_len = max_len
        output_lens.append(cur_len)
    return output_lens


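# Prompts are built by decoding random mid-vocabulary token ids. Decoding and
# re-tokenizing may not round-trip exactly, so the true prompt length can
# deviate slightly from input_len.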
def gen_random_input_text(input_len, tokenizer) -> str:
    random_ids = [random.randint(512, 8192) for _ in range(input_len)]
    random_text = tokenizer.decode(random_ids)
    return random_text


def gen_random_data(
    input_len: int, output_len: int, input_num: int, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Tuple[List[str], List[int], List[int]]:
    prompts = []
    input_lens = []
    output_lens = get_output_length(input_num, output_len)
    for _ in range(input_num):
        input_text = gen_random_input_text(input_len, tokenizer)
        prompts.append(input_text)
        input_lens.append(input_len)
    print("Finished generating random data.")
    return prompts, input_lens, output_lens


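# Each post_stream_* function sends one streaming request and returns the list
# of inter-chunk latencies: the first entry is the time to first token, the
# rest are per-token decode gaps.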
def post_stream_lightllm(url: str, text_input: str, max_new_tokens: int) -> List[float]:
    data = {
        "inputs": text_input,
        "parameters": {
            "do_sample": False,
            "ignore_eos": True,
            "max_new_tokens": max_new_tokens,
        },
    }
    headers = {"Content-Type": "application/json"}
    used_time = []
    start_time = time.time()
    last_time = start_time
    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
    if response.status_code != 200:
        print(response.json())
    assert response.status_code == 200
    for line in response.iter_lines():
        if line:
            current_time = time.time()
            elapsed_time = current_time - last_time
            used_time.append(elapsed_time)
            # print(line.decode("utf-8"))
            last_time = current_time
    return used_time


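# Module-level holder for the model name required by the OpenAI-style API;
# main() appends the tokenizer path to it before any request is sent.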
model_name = []


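# The OpenAI-compatible endpoint streams server-sent events; each chunk read by
# iter_content below is assumed to hold one "data: {...}" payload.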
def post_stream_openai(url: str, text_input: str, max_new_tokens: int) -> List[float]:
    data = {
        "model": model_name[0],
        "prompt": text_input,
        "n": 1,
        "ignore_eos": True,
        "max_tokens": max_new_tokens,
        "stream": True,
    }
    headers = {"Content-Type": "application/json"}
    used_time = []
    start_time = time.time()
    last_time = start_time
    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
    assert response.status_code == 200
    for line in response.iter_content(chunk_size=8192):
        line = line.strip()
        if line:
            line = line.decode("utf-8")[6:]  # remove "data: "
            if line == "[DONE]":
                continue
            chunk = json.loads(line)
            if not chunk["choices"][0]["text"]:
                continue
            current_time = time.time()
            elapsed_time = current_time - last_time
            used_time.append(elapsed_time)
            last_time = current_time
    return used_time


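# Triton's streaming generate endpoint; timing is measured the same way as for
# LightLLM above.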
def post_stream_triton(url: str, text_input: str, max_new_tokens: int) -> List[float]:
    data = {"text_input": text_input, "max_tokens": max_new_tokens, "stream": True}
    headers = {"Content-Type": "application/json"}
    used_time = []
    start_time = time.time()
    last_time = start_time
    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
    assert response.status_code == 200
    for line in response.iter_lines():
        if line:
            current_time = time.time()
            elapsed_time = current_time - last_time
            used_time.append(elapsed_time)
            last_time = current_time
    return used_time


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://127.0.0.1:8000/generate_stream")
    parser.add_argument("--num_clients", type=int, default=100)
    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--input_num", type=int, default=2000)
    parser.add_argument("--input_len", type=int, default=1024)
    parser.add_argument("--output_len", type=int, default=128)
    parser.add_argument("--server_api", type=str, default="lightllm")
    parser.add_argument("--dump_file", type=str, default="")
    parser.add_argument("--seed", type=int, default=0)

    args = parser.parse_args()
    if args.dump_file and os.path.exists(args.dump_file):
        # Read back and print the JSON contents of an existing dump file.
        with open(args.dump_file, "r") as json_file:
            content = json.load(json_file)
        print(json.dumps(content, indent=4))
        return
| 163 | + |
| 164 | + assert args.tokenizer_path is not None |
| 165 | + model_name.append(args.tokenizer_path) |
| 166 | + seed_all(args.seed) |
| 167 | + url = args.url |
| 168 | + tokenizer = get_tokenizer(args.tokenizer_path) |
| 169 | + prompts, input_lens, max_new_tokens = gen_random_data(args.input_len, args.output_len, args.input_num, tokenizer) |
| 170 | + |
| 171 | + percentiles = [25, 50, 75, 90, 95, 99, 100] |
| 172 | + if args.server_api == "lightllm": |
| 173 | + post_stream = post_stream_lightllm |
| 174 | + elif args.server_api == "openai": |
| 175 | + post_stream = post_stream_openai |
| 176 | + elif args.server_api == "triton": |
| 177 | + post_stream = post_stream_triton |
| 178 | + else: |
| 179 | + raise Exception(f"Not support {args.server_api} server_api.") |
| 180 | + |
| 181 | + dump_dict = {} |
| 182 | + dump_dict["backend"] = args.server_api |
| 183 | + dump_dict["clients"] = args.num_clients |
| 184 | + start_time = time.time() |
| 185 | + |
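    # Each worker thread acts as one client, so up to --num_clients requests
    # are kept in flight at a time; executor.map preserves input order.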
    with ThreadPoolExecutor(max_workers=args.num_clients) as executor:
        results = list(
            tqdm(
                executor.map(lambda p: post_stream(url, p[0], p[1]), zip(prompts, max_new_tokens)),
                total=len(prompts),
                desc="Running tests",
            )
        )
    end_time = time.time()
    first_token_time = []
    decode_token_time = []
    request_time = []
    final_output_lens = []
    valid_num = 0
    for result in results:
        if len(result) > 1:  # only count requests that decoded at least two tokens
            first_token_time.append(result[0])
            decode_token_time.append(sum(result[1:]) / len(result[1:]))
            request_time.append(sum(result))
            final_output_lens.append(len(result))
            valid_num += 1

    print(
        f"\n\nvalid num = {valid_num}; all data num = {len(results)}; valid ratio = {valid_num * 1.0 / len(results)}\n"
    )
    print(f"Total QPS: {valid_num / (end_time - start_time)}")
    print(f"Avg Input Length: {sum(input_lens) / len(input_lens)}")
    print(f"Avg Output Length: {sum(final_output_lens) / len(final_output_lens)}")
    print(f"Total Throughput: {(sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)} token/s")
    print(f"Input Throughput: {sum(input_lens) / (end_time - start_time)} token/s")
    print(f"Output Throughput: {sum(final_output_lens) / (end_time - start_time)} token/s")
    print("-" * 10)
    dump_dict["request_num"] = valid_num
    dump_dict["Total QPS"] = valid_num / (end_time - start_time)
    dump_dict["Avg Input Length"] = sum(input_lens) / len(input_lens)
    dump_dict["Avg Output Length"] = sum(final_output_lens) / len(final_output_lens)
    dump_dict["Total Throughput"] = (sum(input_lens) + sum(final_output_lens)) / (end_time - start_time)
    dump_dict["Input Throughput"] = sum(input_lens) / (end_time - start_time)
    dump_dict["Output Throughput"] = sum(final_output_lens) / (end_time - start_time)

    values = np.percentile(request_time, percentiles)
    request_time_dict = {}
    for percentile, value in zip(percentiles, values):
        print(f"request_time P{percentile}: {value:.6f}s")
        request_time_dict[f"P{percentile}"] = value
    dump_dict["request_time"] = request_time_dict
    print("-" * 10)

    first_token_time_dict = {}
    values = np.percentile(first_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"first_token_time P{percentile}: {value:.6f}s")
        first_token_time_dict[f"P{percentile}"] = value
    dump_dict["first_token_time_dict"] = first_token_time_dict
    print("-" * 10)

    decode_token_time_dict = {}
    values = np.percentile(decode_token_time, percentiles)
    for percentile, value in zip(percentiles, values):
        print(f"decode_token_time P{percentile}: {value * 1000:.6f}ms")
        decode_token_time_dict[f"P{percentile}"] = value * 1000
    dump_dict["decode_token_time_dict"] = decode_token_time_dict
    print(dump_dict)

    if args.dump_file:
        with open(args.dump_file, "w") as json_file:
            json.dump(dump_dict, json_file, indent=4)
        print(f"Results have been written to {args.dump_file}")


if __name__ == "__main__":
    main()
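
# Example invocation (script name and model path are illustrative; point --url
# and --tokenizer_path at your own server and tokenizer):
#   python benchmark_stream.py --url http://127.0.0.1:8000/generate_stream \
#       --tokenizer_path /path/to/tokenizer --num_clients 100 --input_num 2000 \
#       --input_len 1024 --output_len 128 --server_api lightllm --dump_file out.json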