diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e365cf6fc86..a5cfeaa07b9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -554,114 +554,135 @@ def do_log_stats(self) -> None: def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: - """Get Stats to be Logged to Prometheus.""" + """Get Stats to be Logged to StdOut and Prometheus.""" now = time.time() - # KV Cache Usage in %. + # System State (_sys suffix). + # Scheduler State + num_running_sys = len(self.scheduler.running) + num_swapped_sys = len(self.scheduler.swapped) + num_waiting_sys = len(self.scheduler.waiting) + + # KV Cache Usage in %. num_total_gpu = self.cache_config.num_gpu_blocks num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu) - + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage = 0. + cpu_cache_usage_sys = 0. if num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( ) - cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu) - - # Scheduler State - num_running = len(self.scheduler.running) - num_swapped = len(self.scheduler.swapped) - num_waiting = len(self.scheduler.waiting) - - # Iteration stats if we have scheduler output. - num_prompt_tokens = 0 - num_generation_tokens = 0 - num_prompt_tokens_lst = [] - num_generation_tokens_lst = [] - request_n = [] - request_best_of = [] - time_to_first_tokens = [] - time_per_output_tokens = [] + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Iteration Stats (_iter suffix). + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + time_to_first_tokens_iter = [] + time_per_output_tokens_iter = [] + + # Request Stats (_requests suffix). + # Latency Timings. time_e2e_requests = [] - finished_reason_counter = CollectionsCounter() + time_queue_requests = [] + time_inference_requests = [] + time_prefill_requests = [] + time_decode_requests = [] + # Metadata. + n_requests = [] + best_of_requests = [] + num_prompt_tokens_requests = [] + num_generation_tokens_requests = [] + max_num_generation_tokens_requests = [] + finished_reason_requests = CollectionsCounter() + if scheduler_outputs is not None: - prompt_run = scheduler_outputs.num_prefill_groups > 0 - - # Number of Tokens - if prompt_run: - num_prompt_tokens_lst = [ - len(scheduled_seq_group.seq_group.prompt_token_ids) - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups - ] - num_prompt_tokens = sum(num_prompt_tokens_lst) - num_generation_tokens = sum( - scheduled_seq_group.seq_group.num_seqs() - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups) - else: - num_generation_tokens = scheduler_outputs.num_batched_tokens - num_generation_tokens_lst = [ - seq.get_output_len() for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups for seq in - scheduled_seq_group.seq_group.get_finished_seqs() - ] - - # Sampling Params - if prompt_run: - request_n = [ - scheduled_seq_group.seq_group.sampling_params.n - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups - ] - request_best_of = [ - scheduled_seq_group.seq_group.sampling_params.best_of - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups - ] - - # Latency Timings - time_last_iters = [] for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - seq_group = scheduled_seq_group.seq_group - # Time since last token. + # Iteration time. # (n.b. updates seq_group.metrics.last_token_time) - time_last_iters.append(seq_group.get_last_latency(now)) - # Time since arrival for all finished requests. - if seq_group.is_finished(): - time_e2e_requests.append(now - - seq_group.metrics.arrival_time) - - time_to_first_tokens = time_last_iters if prompt_run else [] - time_per_output_tokens = [] if prompt_run else time_last_iters + latency = scheduled_seq_group.seq_group.maybe_get_last_latency(now) + + # Number of tokens (for throughput calculations). + # Because of chunked prefill, we can have prefills and decodes in the same batch. + # If token_chunk_size > 1, this seq_group is a chunk of a prefill. + # If token_chunk_size = 1, this seq_group is a decode. + if scheduled_seq_group.token_chunk_size > 1: + num_prompt_tokens_iter += scheduled_seq_group.token_chunk_size + if latency is not None: + time_to_first_tokens_iter.append(latency) + else: + num_generation_tokens_iter += 1 + assert latency is not None + time_per_output_tokens_iter.append(latency) - # Finished Requests - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - if not scheduled_seq_group.seq_group.is_finished(): - continue - finished_reason_counter += CollectionsCounter([ - SequenceStatus.get_finished_reason(seq.status) for seq in - scheduled_seq_group.seq_group.get_finished_seqs() - ]) + seq_group = scheduled_seq_group.seq_group + + # Because of chunked prefill, we can have a single sequence group + # that does multiple prompt_runs. To prevent logging the same metadata + # more than once per request, we standardize on logging request level + # information for finished requests, which can only happen once. + if seq_group.is_finished(): + # Latnecy timings. + time_e2e_requests.append( + now - + seq_group.metrics.arrival_time) + time_queue_requests.append( + seq_group.metrics.first_scheduled_time - + seq_group.metrics.arrival_time) + time_prefill_requests.append( + seq_group.metrics.first_token_time - + seq_group.metrics.first_scheduled_time + ) + time_decode_requests.append( + now - + seq_group.metrics.first_token_time + ) + time_inference_requests.append( + now - seq_group.metrics.first_scheduled_time) + + # Metadata. + n_requests.append(seq_group.sampling_params.n) + best_of_requests.append(seq_group.sampling_params.best_of) + num_prompt_tokens_requests.append( + seq_group.metrics.num_prompt_tokens) + num_generation_tokens_requests.append( + seq_group.metrics.num_generated_tokens) + max_num_generation_tokens_requests.append( + seq_group.metrics.max_num_generated_tokens) + finished_reason_requests += CollectionsCounter([ + SequenceStatus.get_finished_reason(seq.status) for seq in + scheduled_seq_group.seq_group.get_finished_seqs() + ]) return Stats( now=now, - num_running=num_running, - num_swapped=num_swapped, - num_waiting=num_waiting, - gpu_cache_usage=gpu_cache_usage, - cpu_cache_usage=cpu_cache_usage, - finished_reason_counter=finished_reason_counter, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=num_generation_tokens, - num_prompt_tokens_lst=num_prompt_tokens_lst, - num_generation_tokens_lst=num_generation_tokens_lst, - request_n=request_n, - request_best_of=request_best_of, - time_to_first_tokens=time_to_first_tokens, - time_per_output_tokens=time_per_output_tokens, + prompt_run=num_prompt_tokens_iter > 0, + decode_run=(num_prompt_tokens_iter == 0 and + num_generation_tokens_iter > 0), + + # System stats. + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + + # Iteration stats. + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_prompt_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_to_first_tokens_iter, + + # Request stats. time_e2e_requests=time_e2e_requests, + time_inference_requests=time_inference_requests, + time_queue_requests=time_queue_requests, + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + max_num_generation_tokens_requests= + max_num_generation_tokens_requests, + n_requests=n_requests, + best_of_requests=best_of_requests, + finished_reasons_requests=finished_reason_requests ) def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index ee11856e2a2..111de8c4b9d 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,7 +1,7 @@ import time from dataclasses import dataclass from typing import Counter as CollectionsCounter -from typing import Dict, List, Protocol +from typing import Dict, List, Protocol, Union import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -54,23 +54,30 @@ def __init__(self, labelnames: List[str], max_model_len: int): name="vllm:cpu_cache_usage_perc", documentation="CPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames) - - # Raw stats from last model iteration - self.counter_request_success = Counter( - name="vllm:request_success", - documentation="Count of successfully processed requests.", - labelnames=labelnames + [LABEL_NAME_FINISHED_REASON]) - self.histogram_request_prompt_tokens = Histogram( - name="vllm:request_prompt_tokens", + + # Iteration-level stats. + self.counter_prompt_tokens = Counter( + name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.histogram_request_generation_tokens = Histogram( - name="vllm:request_generation_tokens", + labelnames=labelnames) + self.counter_generation_tokens = Counter( + name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", + labelnames=labelnames) + self.counter_num_prefill_iterations = Counter( + name="vllm:prefill_iterations_total", + documentation="Number of prefill iterations. " + "Iterations with chunked prefill are counted here.", + labelnames=labelnames), + self.counter_num_decode_iterations = Counter( + name="vllm:decode_iterations_total", + documentation="Number of decode iterations.", + labelnames=labelnames), + self.histogram_iteration_num_tokens = Histogram( + name="vllm:iteration_num_tokens_total", + documentation="Histogram of number of total tokens per iteration.", labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), + buckets=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # get batched_max_num_tokens ) self.histogram_time_to_first_token = Histogram( name="vllm:time_to_first_token_seconds", @@ -88,24 +95,71 @@ def __init__(self, labelnames: List[str], max_model_len: int): 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5 ]) - self.histogram_e2e_request_latency = Histogram( + + # Request-level data. + # Latency breakdown + self.histogram_e2e_time_request = Histogram( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) - self.histogram_request_n = Histogram( + self.histogram_queue_time_request = Histogram( + name="vllm:request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + labelnames=labelnames, + buckets=[0.1, 1.0, 2.5, 5.0, 10.0, 20.0, 50.0, 100]) + self.histogram_inference_time_request = Histogram( + name="vllm:request_inference_time_seconds", + documentation= + "Histogram of time spent in RUNNING phase for request.", + labelnames=labelnames, + buckets=[0.1, 1.0, 2.5, 5.0, 10.0, 20.0, 50.0, 100]) + self.histogram_decode_time_request = Histogram( + name="vllm:request_decode_time_seconds", + documentation= + "Histogram of time spent in DECODE phase for request.", + labelnames=labelnames, + buckets=[0.1, 1.0, 2.5, 5.0, 10.0, 20.0, 50.0, 100]) + self.histogram_prefill_time_request = Histogram( + name="vllm:request_prefill_time_seconds", + documentation= + "Histogram of time spent in PREFILL phase for request.", + labelnames=labelnames, + buckets=[0.1, 1.0, 2.5, 5.0, 10.0, 20.0, 50.0, 100]) + # Metadata. + self.counter_request_success = Counter( + name="vllm:request_success", + documentation="Count of successfully processed requests.", + labelnames=labelnames + [LABEL_NAME_FINISHED_REASON]) + self.histogram_num_prompt_tokens_request = Histogram( + name="vllm:request_num_prompt_tokens", + documentation="Histogram of number of prompt tokens for requests.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len)) + self.histogram_num_generation_tokens_request = Histogram( + name="vllm:request_num_generation_tokens", + documentation= + "Histogram of number of generation tokens for requests.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len)) + self.histogram_max_num_generation_tokens_request = Histogram( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len)) + self.histogram_n_request = Histogram( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", labelnames=labelnames, - buckets=[1, 2, 5, 10, 20], - ) - self.histogram_request_best_of = Histogram( + buckets=[1, 2, 5, 10, 20]) + self.histogram_best_of_request = Histogram( name="vllm:request_params_best_of", documentation="Histogram of the best_of request parameter.", labelnames=labelnames, - buckets=[1, 2, 5, 10, 20], - ) - + buckets=[1, 2, 5, 10, 20]) + # Legacy metrics self.gauge_avg_prompt_throughput = Gauge( name="vllm:avg_prompt_throughput_toks_per_s", @@ -117,17 +171,7 @@ def __init__(self, labelnames: List[str], max_model_len: int): documentation="Average generation throughput in tokens/s.", labelnames=labelnames, ) - # Deprecated in favor of vllm:request_prompt_tokens_sum - self.counter_prompt_tokens = Counter( - name="vllm:prompt_tokens_total", - documentation="Number of prefill tokens processed.", - labelnames=labelnames) - # Deprecated in favor of vllm:request_generation_tokens_sum - self.counter_generation_tokens = Counter( - name="vllm:generation_tokens_total", - documentation="Number of generation tokens processed.", - labelnames=labelnames) - + # end-metrics-definitions @@ -158,25 +202,36 @@ def build_1_2_5_buckets(max_value: int): class Stats: """Created by LLMEngine for use by StatLogger.""" now: float - - # System stats. - num_running: int - num_waiting: int - num_swapped: int - gpu_cache_usage: float - cpu_cache_usage: float - - # Raw stats from last model iteration. - finished_reason_counter: CollectionsCounter[str] - num_prompt_tokens: int - num_generation_tokens: int - num_prompt_tokens_lst: List[int] - num_generation_tokens_lst: List[int] - request_n: List[int] - request_best_of: List[int] - time_to_first_tokens: List[float] - time_per_output_tokens: List[float] + prompt_run: bool + decode_run: bool + + # System stats (should have _sys suffix). + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + + # Iteration stats (should have _iter suffix). + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + + # Request Stats (should have _requests suffix). + # Latency time_e2e_requests: List[float] + time_queue_requests: List[float] + time_inference_requests: List[float] + time_prefill_requests: List[float] + time_decode_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + max_num_generation_tokens_requests: List[int] + best_of_requests: List[int] + n_requests: List[int] + finished_reasons_requests: CollectionsCounter[str] class SupportsMetricsInfo(Protocol): @@ -215,57 +270,84 @@ def _local_interval_elapsed(self, now: float) -> bool: return elapsed_time > self.local_interval def _log_prometheus(self, stats: Stats) -> None: - # Set system stat gauges. - self.metrics.gauge_scheduler_running.labels(**self.labels).set( - stats.num_running) - self.metrics.gauge_scheduler_swapped.labels(**self.labels).set( - stats.num_swapped) - self.metrics.gauge_scheduler_waiting.labels(**self.labels).set( - stats.num_waiting) - self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set( - stats.gpu_cache_usage) - self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set( - stats.cpu_cache_usage) - - # Add to token counters. - self.metrics.counter_prompt_tokens.labels(**self.labels).inc( - stats.num_prompt_tokens) - self.metrics.counter_generation_tokens.labels(**self.labels).inc( - stats.num_generation_tokens) - - # Add to request counters. - for finished_reason, count in stats.finished_reason_counter.items(): - self.metrics.counter_request_success.labels( - **{ - **self.labels, - LABEL_NAME_FINISHED_REASON: finished_reason, - }).inc(count) - - # Observe number of tokens in histograms. - for val in stats.num_prompt_tokens_lst: - self.metrics.histogram_request_prompt_tokens.labels( - **self.labels).observe(val) - for val in stats.num_generation_tokens_lst: - self.metrics.histogram_request_generation_tokens.labels( - **self.labels).observe(val) - - # Observe sampling params in histograms. - for n in stats.request_n: - self.metrics.histogram_request_n.labels(**self.labels).observe(n) - for best_of in stats.request_best_of: - self.metrics.histogram_request_best_of.labels( - **self.labels).observe(best_of) - - # Observe request level latencies in histograms. - for ttft in stats.time_to_first_tokens: - self.metrics.histogram_time_to_first_token.labels( - **self.labels).observe(ttft) - for tpot in stats.time_per_output_tokens: - self.metrics.histogram_time_per_output_token.labels( - **self.labels).observe(tpot) - for e2e in stats.time_e2e_requests: - self.metrics.histogram_e2e_request_latency.labels( - **self.labels).observe(e2e) + # System state data. + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_swapped, + stats.num_swapped_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_cache_usage, + stats.cpu_cache_usage_sys) + + # Iteration level data. + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_counter(self.metrics.counter_num_prefill_iterations, + (1 if stats.prompt_run else 0)) + self._log_counter(self.metrics.counter_num_decode_iterations, + (1 if stats.decode_run else 0)) + self._log_histogram(self.metrics.histogram_iteration_num_tokens, + [stats.num_prompt_tokens_iter + stats.num_generation_tokens_iter]) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data. + # Latency. + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + self._log_histogram(self.metrics.histogram_queue_time_request, + stats.time_queue_requests) + self._log_histogram(self.metrics.histogram_inference_time_request, + stats.time_inference_requests) + self._log_histogram(self.metrics.histogram_prefill_time_request, + stats.time_decode_requests) + self._log_histogram(self.metrics.histogram_decode_time_request, + stats.time_prefill_requests) + + # Metadata. + self._log_counter_labels(self.metrics.counter_request_success, + stats.finished_reasons_requests, + LABEL_NAME_FINISHED_REASON) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram(self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_max_num_generation_tokens_request, + stats.max_num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_n_request, + stats.n_requests) + self._log_histogram(self.metrics.histogram_best_of_request, + stats.best_of_requests) + + + + def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter: Counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter: Counter, + data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for collection counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram: Histogram, + data: Union[List[int], List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) def _log_prometheus_interval(self, prompt_throughput: float, generation_throughput: float) -> None: diff --git a/vllm/sequence.py b/vllm/sequence.py index 92362a9a5d2..2680fe95c97 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -80,13 +80,14 @@ class RequestMetrics: Attributes: arrival_time: The time when the request arrived. + last_token_time: The time when the last token was generated. first_scheduled_time: The time when the request was first scheduled. first_token_time: The time when the first token was generated. time_in_queue: The time the request spent in the queue. finished_time: The time when the request was finished. """ arrival_time: float - last_token_time: float + last_token_time: Optional[float] first_scheduled_time: Optional[float] first_token_time: Optional[float] time_in_queue: Optional[float] @@ -442,15 +443,29 @@ def prompt_token_ids(self) -> List[int]: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - def get_last_latency(self, now: float) -> float: - """Gets last token latency for Request level timings.""" + def maybe_get_last_latency(self, now: float) -> Optional[float]: + """Sets the last token time for Request level timings.""" + # If prompt not done, return None. + if self.get_num_uncomputed_tokens() > 0: + return None + + # Otherwise return token latency. latency = now - self.metrics.last_token_time self.metrics.last_token_time = now return latency def maybe_set_first_token_time(self, time: float) -> None: """Sets the first token time for Request level timings.""" - if self.metrics.first_token_time is None: + # For the case of chunked_prefill, get_num_uncomputed_tokens=0 + # implies that we have computed the entire prompt and have + # therefore generated the first token. + # + # Note: in a case where a sequence_group is swapped and + # recomputed, the time between iterations is counted + # in TPOT, rather than recalculating TTFT (since from the ) + # POV of the user, there is simply a long generation delay. + if (self.metrics.first_token_time is None and + self.get_num_uncomputed_tokens() == 0): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: