Commit 7f4a37d

njhill authored and Isotr0py committed
[Benchmark] More accurate TPOT calc in benchmark_serving.py (vllm-project#12288)
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent: 0b95111

2 files changed: +66 −46 lines

benchmarks/backend_request_func.py

Lines changed: 27 additions & 16 deletions
@@ -35,6 +35,7 @@ class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
     latency: float = 0.0
+    output_tokens: int = 0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
@@ -156,7 +157,7 @@ async def async_request_trt_llm(
                         timestamp = time.perf_counter()
                         # First token
                         if ttft == 0.0:
-                            ttft = time.perf_counter() - st
+                            ttft = timestamp - st
                             output.ttft = ttft

                         # Decoding phase
@@ -245,6 +246,9 @@ async def async_request_openai_completions(
             "logprobs": request_func_input.logprobs,
             "stream": True,
             "ignore_eos": request_func_input.ignore_eos,
+            "stream_options": {
+                "include_usage": True,
+            },
         }
         if request_func_input.extra_body:
             payload.update(request_func_input.extra_body)
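
Note: with "stream_options": {"include_usage": True}, OpenAI-compatible streaming endpoints send one final data chunk that carries no choices but does carry a usage object with server-side token counts; that is what the new output_tokens field is populated from. A minimal sketch of such a trailing chunk and how the completion token count can be read from it (the JSON values are illustrative, not taken from this commit):

    import json

    # Illustrative final SSE chunk emitted when "include_usage" is enabled;
    # the numbers are made up for the example.
    last_chunk = ('data: {"choices": [], '
                  '"usage": {"prompt_tokens": 13, "completion_tokens": 42, "total_tokens": 55}}')

    data = json.loads(last_chunk.removeprefix("data: "))
    if not data.get("choices") and (usage := data.get("usage")):
        output_tokens = usage.get("completion_tokens")  # -> 42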
@@ -256,7 +260,6 @@ async def async_request_openai_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
@@ -271,15 +274,16 @@ async def async_request_openai_completions(

                         chunk = chunk_bytes.decode("utf-8").removeprefix(
                             "data: ")
-                        if chunk == "[DONE]":
-                            latency = time.perf_counter() - st
-                        else:
+                        if chunk != "[DONE]":
                             data = json.loads(chunk)

                             # NOTE: Some completion API might have a last
                             # usage summary response without a token so we
                             # want to check a token was generated
-                            if data["choices"][0]["text"]:
+                            if choices := data.get("choices"):
+                                # Note that text could be empty here
+                                # e.g. for special tokens
+                                text = choices[0].get("text")
                                 timestamp = time.perf_counter()
                                 # First token
                                 if not first_chunk_received:
@@ -293,7 +297,10 @@ async def async_request_openai_completions(
                                         most_recent_timestamp)

                                 most_recent_timestamp = timestamp
-                                generated_text += data["choices"][0]["text"]
+                                generated_text += text
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")
                     if first_chunk_received:
                         output.success = True
                     else:
@@ -302,7 +309,7 @@ async def async_request_openai_completions(
                             "Never received a valid chunk to calculate TTFT."
                             "This response will be marked as failed!")
                     output.generated_text = generated_text
-                    output.latency = latency
+                    output.latency = most_recent_timestamp - st
                 else:
                     output.error = response.reason or ""
                     output.success = False
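
Note: the latency change above means the request is now timed to the timestamp of the last received content chunk rather than to the arrival of the "[DONE]" sentinel, which excludes the trailing sentinel (and the usage-only chunk) from the measured decode time. A toy sketch of the difference, using a stand-in token stream rather than the benchmark's real HTTP transport:

    import time

    st = time.perf_counter()
    most_recent_timestamp = st
    for chunk in ("Hello", " world", "[DONE]"):  # stand-in token stream
        if chunk != "[DONE]":
            most_recent_timestamp = time.perf_counter()

    latency_new = most_recent_timestamp - st  # stops at the last token chunk
    latency_old = time.perf_counter() - st    # old approach also counted the "[DONE]" handling
    assert latency_old >= latency_new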
@@ -342,6 +349,9 @@ async def async_request_openai_chat_completions(
             "max_completion_tokens": request_func_input.output_len,
             "stream": True,
             "ignore_eos": request_func_input.ignore_eos,
+            "stream_options": {
+                "include_usage": True,
+            },
         }
         if request_func_input.extra_body:
             payload.update(request_func_input.extra_body)
@@ -368,31 +378,32 @@ async def async_request_openai_chat_completions(

                         chunk = chunk_bytes.decode("utf-8").removeprefix(
                             "data: ")
-                        if chunk == "[DONE]":
-                            latency = time.perf_counter() - st
-                        else:
+                        if chunk != "[DONE]":
                             timestamp = time.perf_counter()
                             data = json.loads(chunk)

-                            delta = data["choices"][0]["delta"]
-                            if delta.get("content", None):
+                            if choices := data.get("choices"):
+                                content = choices[0]["delta"].get("content")
                                 # First token
                                 if ttft == 0.0:
-                                    ttft = time.perf_counter() - st
+                                    ttft = timestamp - st
                                     output.ttft = ttft

                                 # Decoding phase
                                 else:
                                     output.itl.append(timestamp -
                                                       most_recent_timestamp)

-                                generated_text += delta["content"]
+                                generated_text += content
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")

                             most_recent_timestamp = timestamp

                     output.generated_text = generated_text
                     output.success = True
-                    output.latency = latency
+                    output.latency = most_recent_timestamp - st
                 else:
                     output.error = response.reason or ""
                     output.success = False

benchmarks/benchmark_serving.py

Lines changed: 39 additions & 30 deletions
@@ -25,6 +25,7 @@
 import argparse
 import asyncio
 import base64
+import gc
 import io
 import json
 import os
@@ -423,7 +424,7 @@ def calculate_metrics(
     tokenizer: PreTrainedTokenizerBase,
     selected_percentile_metrics: List[str],
     selected_percentiles: List[float],
-    gootput_config_dict: Dict[str, float],
+    goodput_config_dict: Dict[str, float],
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     actual_output_lens: List[int] = []
     total_input = 0
@@ -436,19 +437,23 @@
     e2els: List[float] = []
     for i in range(len(outputs)):
         if outputs[i].success:
-            # We use the tokenizer to count the number of output tokens for all
-            # serving backends instead of looking at len(outputs[i].itl) since
-            # multiple output tokens may be bundled together
-            # Note : this may inflate the output token count slightly
-            output_len = len(
-                tokenizer(outputs[i].generated_text,
-                          add_special_tokens=False).input_ids)
+            output_len = outputs[i].output_tokens
+
+            if output_len is None:
+                # We use the tokenizer to count the number of output tokens
+                # for some serving backends instead of looking at
+                # len(outputs[i].itl) since multiple output tokens may be
+                # bundled together
+                # Note : this may inflate the output token count slightly
+                output_len = len(
+                    tokenizer(outputs[i].generated_text,
+                              add_special_tokens=False).input_ids)
             actual_output_lens.append(output_len)
             total_input += input_requests[i][1]
             tpot = 0
             if output_len > 1:
-                tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
-                                                                 1)
+                latency_minus_ttft = outputs[i].latency - outputs[i].ttft
+                tpot = latency_minus_ttft / (output_len - 1)
             tpots.append(tpot)
             # Note: if output_len <= 1, we regard tpot as 0 for goodput
             all_tpots.append(tpot)
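
Note: TPOT here is the mean inter-token time over the decode phase, (latency - ttft) / (output_len - 1), so any inflation of output_len directly understates TPOT. Preferring the server-reported completion_tokens over re-tokenizing the detokenized text (which can merge or split tokens) is what makes the calculation more accurate. A small worked illustration with made-up numbers:

    # Made-up numbers for illustration only.
    latency, ttft = 2.10, 0.10        # seconds
    reported_tokens = 101             # completion_tokens reported by the server
    retokenized_tokens = 105          # re-tokenizing the generated text can over-count

    tpot_reported = (latency - ttft) / (reported_tokens - 1)        # 0.0200 s/token
    tpot_retokenized = (latency - ttft) / (retokenized_tokens - 1)  # ~0.0192 s/token (understated)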
@@ -459,21 +464,21 @@
         else:
             actual_output_lens.append(0)

-    if gootput_config_dict:
+    if goodput_config_dict:
         valid_metrics = []
         slo_values = []

-        if "ttft" in gootput_config_dict:
+        if "ttft" in goodput_config_dict:
             valid_metrics.append(ttfts)
-            slo_values.append(gootput_config_dict["ttft"] /
+            slo_values.append(goodput_config_dict["ttft"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)
-        if "tpot" in gootput_config_dict:
+        if "tpot" in goodput_config_dict:
             valid_metrics.append(all_tpots)
-            slo_values.append(gootput_config_dict["tpot"] /
+            slo_values.append(goodput_config_dict["tpot"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)
-        if "e2el" in gootput_config_dict:
+        if "e2el" in goodput_config_dict:
             valid_metrics.append(e2els)
-            slo_values.append(gootput_config_dict["e2el"] /
+            slo_values.append(goodput_config_dict["e2el"] /
                               MILLISECONDS_TO_SECONDS_CONVERSION)

         for req_metric in zip(*valid_metrics):
@@ -537,7 +542,7 @@ async def benchmark(
     selected_percentile_metrics: List[str],
     selected_percentiles: List[str],
     ignore_eos: bool,
-    gootput_config_dict: Dict[str, float],
+    goodput_config_dict: Dict[str, float],
     max_concurrency: Optional[int],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
@@ -661,7 +666,7 @@ async def limited_request_func(request_func_input, pbar):
         tokenizer=tokenizer,
         selected_percentile_metrics=selected_percentile_metrics,
         selected_percentiles=selected_percentiles,
-        gootput_config_dict=gootput_config_dict,
+        goodput_config_dict=goodput_config_dict,
     )

     print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
@@ -673,7 +678,7 @@ async def limited_request_func(request_func_input, pbar):
                                     metrics.total_output))
     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                     metrics.request_throughput))
-    if gootput_config_dict:
+    if goodput_config_dict:
         print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
                                         metrics.request_goodput))
     print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
@@ -688,7 +693,7 @@ async def limited_request_func(request_func_input, pbar):
         "total_output_tokens": metrics.total_output,
         "request_throughput": metrics.request_throughput,
         "request_goodput:":
-        metrics.request_goodput if gootput_config_dict else None,
+        metrics.request_goodput if goodput_config_dict else None,
         "output_throughput": metrics.output_throughput,
         "total_token_throughput": metrics.total_token_throughput,
         "input_lens": [output.prompt_len for output in outputs],
@@ -744,11 +749,11 @@ def process_one_metric(

 def check_goodput_args(args):
     # Check and parse goodput arguments
-    gootput_config_dict = {}
+    goodput_config_dict = {}
     VALID_NAMES = ["ttft", "tpot", "e2el"]
     if args.goodput:
-        gootput_config_dict = parse_goodput(args.goodput)
-        for slo_name, slo_val in gootput_config_dict.items():
+        goodput_config_dict = parse_goodput(args.goodput)
+        for slo_name, slo_val in goodput_config_dict.items():
             if slo_name not in VALID_NAMES:
                 raise ValueError(
                     f"Invalid metric name found, {slo_name}: {slo_val}. "
@@ -759,22 +764,22 @@ def check_goodput_args(args):
                     f"Invalid value found, {slo_name}: {slo_val}. "
                     "The service level objective value should be "
                     "non-negative.")
-    return gootput_config_dict
+    return goodput_config_dict


 def parse_goodput(slo_pairs):
-    gootput_config_dict = {}
+    goodput_config_dict = {}
     try:
         for slo_pair in slo_pairs:
             slo_name, slo_val = slo_pair.split(":")
-            gootput_config_dict[slo_name] = float(slo_val)
+            goodput_config_dict[slo_name] = float(slo_val)
     except ValueError as err:
         raise argparse.ArgumentTypeError(
             "Invalid format found for service level objectives. "
             "Specify service level objectives for goodput as \"KEY:VALUE\" "
             "pairs, where the key is a metric name, and the value is a "
             "number in milliseconds.") from err
-    return gootput_config_dict
+    return goodput_config_dict


 def main(args: argparse.Namespace):
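
Note: the goodput SLOs are supplied on the command line as KEY:VALUE pairs with values in milliseconds. For illustration (the threshold values below are hypothetical), an invocation along the lines of

    # python benchmarks/benchmark_serving.py ... --goodput ttft:500 tpot:50 e2el:4000

would be parsed by parse_goodput into

    goodput_config_dict = {"ttft": 500.0, "tpot": 50.0, "e2el": 4000.0}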
@@ -874,7 +879,11 @@ def main(args: argparse.Namespace):
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")

-    gootput_config_dict = check_goodput_args(args)
+    goodput_config_dict = check_goodput_args(args)
+
+    # Avoid GC processing "static" data - reduce pause times.
+    gc.collect()
+    gc.freeze()

     benchmark_result = asyncio.run(
         benchmark(
@@ -896,7 +905,7 @@
                 float(p) for p in args.metric_percentiles.split(",")
             ],
             ignore_eos=args.ignore_eos,
-            gootput_config_dict=gootput_config_dict,
+            goodput_config_dict=goodput_config_dict,
             max_concurrency=args.max_concurrency,
         ))

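
Note: the gc.collect() / gc.freeze() pair added in main() above is a standard CPython technique for cutting garbage-collector pause times: collect once so existing garbage is gone, then freeze so every object that survives (here, the benchmark's large and effectively static request data) is moved to the permanent generation that later collections skip. A minimal sketch of the pattern, with placeholder data:

    import gc

    # Placeholder for the benchmark's long-lived request data.
    requests = [(f"prompt {i}", 128, 256) for i in range(10_000)]

    gc.collect()  # clear out existing garbage first
    gc.freeze()   # survivors go to the permanent generation; future collections skip them

    # ... run the latency-sensitive benchmark loop here ...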
