
test: add kv cache event tests for disagg workers #3602


Merged: 3 commits, Apr 18, 2025
10 changes: 6 additions & 4 deletions examples/llm-api/llm_inference_kv_events.py
@@ -1,19 +1,20 @@
 ### Get KV Cache Events

-import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
+from tensorrt_llm.llmapi import KvCacheConfig


 def main():
     pytorch_config = PyTorchConfig(enable_overlap_scheduler=True,
+                                   autotuner_enabled=False,
                                    kv_cache_dtype='auto')

     llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
               tensor_parallel_size=2,
               pytorch_backend_config=pytorch_config,
-              kv_cache_config=trtllm.KvCacheConfig(enable_block_reuse=True,
-                                                   event_buffer_max_size=1024),
+              kv_cache_config=KvCacheConfig(enable_block_reuse=True,
+                                            event_buffer_max_size=1024),
               backend="pytorch")

     # Sample prompts having a common prefix.
@@ -42,7 +43,8 @@ def main():
     print(kv_events)

     # Got output like follows:
-    # {'event_id': 0, 'data': {'type': 'created', 'num_blocks_per_cache_level': [101230, 0]}}, {'event_id': 1, 'data': {'type': 'stored', 'parent_hash': None, 'blocks': [{'type': 'stored_block', 'block_hash': 4203099703668305365, 'tokens': [{'type': 'unique_token', 'token_id': 1, 'token_extra_id': 0}, ...
+    # [{'event_id': 0, 'data': {'type': 'created', 'num_blocks_per_cache_level': [101230, 0]}},
+    #  {'event_id': 1, 'data': {'type': 'stored', 'parent_hash': None, 'blocks': [{'type': 'stored_block', 'block_hash': 4203099703668305365, 'tokens': [{'type': 'unique_token', 'token_id': 1, 'token_extra_id': 0}, ...


 if __name__ == '__main__':
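For reference, the example drains buffered events from the LLM instance after generation. A minimal sketch of doing the same asynchronously, assuming the get_kv_cache_events_async API that the server change below relies on (the 2-second timeout mirrors that code):

    import asyncio

    async def dump_kv_events(llm):
        # Drain whatever events the KV cache manager has buffered so far.
        # An IndexError signals an empty event queue, as in openai_server.py.
        try:
            async for event in llm.get_kv_cache_events_async(2):
                print(event)
        except IndexError:
            pass

    # Usage after llm.generate(...): asyncio.run(dump_kv_events(llm))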
2 changes: 1 addition & 1 deletion scripts/build_wheel.py
@@ -59,7 +59,7 @@ def get_build_dir(build_dir, build_type):
 def clear_folder(folder_path):
     for item in os.listdir(folder_path):
         item_path = os.path.join(folder_path, item)
-        if os.path.isdir(item_path):
+        if os.path.isdir(item_path) and not os.path.islink(item_path):
             rmtree(item_path)
         else:
             os.remove(item_path)
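This guard matters because shutil.rmtree() refuses to operate on a symbolic link to a directory (it raises OSError), and os.path.isdir() follows links, so a symlinked directory previously crashed the clean step. A small self-contained demonstration of the patched branch logic (temporary paths, purely illustrative):

    import os
    import tempfile
    from shutil import rmtree

    build = tempfile.mkdtemp()
    target = tempfile.mkdtemp()
    link = os.path.join(build, "link")
    os.symlink(target, link)

    # isdir() follows the link, so the islink() check is what routes the
    # entry to os.remove(); rmtree(link) would raise OSError instead.
    if os.path.isdir(link) and not os.path.islink(link):
        rmtree(link)
    else:
        os.remove(link)  # unlinks the symlink only

    assert os.path.isdir(target)  # the link's target is untouched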
29 changes: 14 additions & 15 deletions tensorrt_llm/_utils.py
@@ -891,34 +891,33 @@ def convert_to_torch_tensor(

 class KVCacheEventSerializer:

-    def get_event_serialize_func(event_type):
+    @classmethod
+    def get_event_serialize_func(cls, event_type):
         return {
-            "KVCacheCreatedData": KVCacheEventSerializer._created_to_json,
-            "KVCacheStoredData": KVCacheEventSerializer._stored_to_json,
-            "KVCacheStoredBlockData":
-            KVCacheEventSerializer._stored_block_to_json,
-            "KVCacheRemovedData": KVCacheEventSerializer._removed_to_json,
-            "KVCacheUpdatedData": KVCacheEventSerializer._updated_to_json,
+            "KVCacheCreatedData": cls._created_to_json,
+            "KVCacheStoredData": cls._stored_to_json,
+            "KVCacheStoredBlockData": cls._stored_block_to_json,
+            "KVCacheRemovedData": cls._removed_to_json,
+            "KVCacheUpdatedData": cls._updated_to_json,
         }.get(event_type, None)

-    @staticmethod
-    def serialize(events):
+    @classmethod
+    def serialize(cls, events):
         if events is None:
             return None

         if not isinstance(events, list):
-            events = [events]
+            return cls.to_json_str(events)

-        return [KVCacheEventSerializer.to_json_str(event) for event in events]
+        return [cls.to_json_str(event) for event in events]

-    @staticmethod
-    def to_json_str(event):
+    @classmethod
+    def to_json_str(cls, event):
         if event is None:
             return {}

         event_type = type(event.data).__name__
-        event_serialize_func = KVCacheEventSerializer.get_event_serialize_func(
-            event_type)
+        event_serialize_func = cls.get_event_serialize_func(event_type)
         if event_serialize_func is None:
             raise ValueError(f"Unknown KVCache event data type: {event_type}")
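Note that beyond the classmethod cleanup, this also changes serialize() behavior: a single event now comes back as one serialized dict rather than a one-element list. A short sketch of the resulting contract (calls exactly as in the diff; no new API assumed):

    from tensorrt_llm._utils import KVCacheEventSerializer

    # Dispatch is keyed on the type name of event.data:
    func = KVCacheEventSerializer.get_event_serialize_func("KVCacheCreatedData")
    # -> KVCacheEventSerializer._created_to_json, or None for unknown names

    # serialize(None)      -> None
    # serialize(event)     -> a single JSON-serializable dict
    # serialize([e1, e2])  -> a list of dicts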
12 changes: 12 additions & 0 deletions tensorrt_llm/serve/openai_server.py
@@ -126,6 +126,8 @@ def register_routes(self):
self.app.add_api_route("/v1/models", self.get_model, methods=["GET"])
# TODO: the metrics endpoint only reports iteration stats, not the runtime stats for now
self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"])
# TODO: workaround before ETCD support
self.app.add_api_route("/kv_cache_events", self.get_kv_cache_events, methods=["POST"])
self.app.add_api_route("/v1/completions",
self.openai_completion,
methods=["POST"])
@@ -150,6 +152,16 @@ async def get_iteration_stats(self) -> JSONResponse:
             stats.append(stat)
         return JSONResponse(content=stats)

+    async def get_kv_cache_events(self) -> JSONResponse:
+        events = []
+        try:
+            async for event in self.llm.get_kv_cache_events_async(2):
+                events.append(event)
+        except IndexError:
+            # queue is empty, no more events
+            pass
+        return JSONResponse(content=events)
+
     async def openai_chat(self, request: ChatCompletionRequest, raw_request: Request) -> Response:

         def get_role() -> str:
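To poke the new route by hand once a server is up, something like the following should suffice (the host and port are assumptions for a locally launched server):

    import requests

    # POST with an empty body: the handler takes no parameters and returns
    # whatever events are currently buffered (an empty list once the queue
    # is drained, since IndexError is swallowed server-side).
    events = requests.post("http://localhost:8001/kv_cache_events").json()
    for event in events:
        print(event)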
3 changes: 3 additions & 0 deletions (disaggregated-serving example config; file path not captured in this view)
@@ -6,13 +6,15 @@ backend: "pytorch"
 pytorch_backend_config:
   use_cuda_graph: False
   enable_overlap_scheduler: False
+  autotuner_enabled: False
 context_servers:
   num_instances: 1
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   kv_cache_config:
     enable_block_reuse: True
     enable_partial_reuse: True
+    event_buffer_max_size: 1024
   urls:
     - "localhost:8001"
 generation_servers:
@@ -22,5 +24,6 @@ generation_servers:
   kv_cache_config:
     enable_block_reuse: True
     enable_partial_reuse: True
+    event_buffer_max_size: 1024
   urls:
     - "localhost:8002"
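With event_buffer_max_size set on both workers, each one buffers its own KV cache events, which is what lets the new test observe cache activity on the context and generation servers independently. A sketch of polling both (ports from the urls entries above; expecting 'stored' events on the context side is an assumption about typical prefill behavior, not something this diff asserts):

    import requests

    for name, url in [("context", "http://localhost:8001"),
                      ("generation", "http://localhost:8002")]:
        events = requests.post(f"{url}/kv_cache_events").json()
        # Each serialized event carries an 'event_id' and a 'data' dict
        # with a 'type' field ('created', 'stored', 'removed', ...).
        print(name, [e["data"]["type"] for e in events])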