diff --git a/examples/llm-api/llm_inference_kv_events.py b/examples/llm-api/llm_inference_kv_events.py
index d07ad75be6..69b9dc95a2 100644
--- a/examples/llm-api/llm_inference_kv_events.py
+++ b/examples/llm-api/llm_inference_kv_events.py
@@ -1,19 +1,20 @@
 ### Get KV Cache Events
-import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
+from tensorrt_llm.llmapi import KvCacheConfig


 def main():
     pytorch_config = PyTorchConfig(enable_overlap_scheduler=True,
+                                   autotuner_enabled=False,
                                    kv_cache_dtype='auto')
     llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
               tensor_parallel_size=2,
               pytorch_backend_config=pytorch_config,
-              kv_cache_config=trtllm.KvCacheConfig(enable_block_reuse=True,
-                                                   event_buffer_max_size=1024),
+              kv_cache_config=KvCacheConfig(enable_block_reuse=True,
+                                            event_buffer_max_size=1024),
               backend="pytorch")

     # Sample prompts having a common prefix.
@@ -42,7 +43,8 @@ def main():
     print(kv_events)

     # Got output like follows:
-    # {'event_id': 0, 'data': {'type': 'created', 'num_blocks_per_cache_level': [101230, 0]}}, {'event_id': 1, 'data': {'type': 'stored', 'parent_hash': None, 'blocks': [{'type': 'stored_block', 'block_hash': 4203099703668305365, 'tokens': [{'type': 'unique_token', 'token_id': 1, 'token_extra_id': 0}, ...
+    # [{'event_id': 0, 'data': {'type': 'created', 'num_blocks_per_cache_level': [101230, 0]}},
+    #  {'event_id': 1, 'data': {'type': 'stored', 'parent_hash': None, 'blocks': [{'type': 'stored_block', 'block_hash': 4203099703668305365, 'tokens': [{'type': 'unique_token', 'token_id': 1, 'token_extra_id': 0}, ...


 if __name__ == '__main__':
diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py
index 8d981a32de..0e52896355 100755
--- a/scripts/build_wheel.py
+++ b/scripts/build_wheel.py
@@ -59,7 +59,7 @@ def get_build_dir(build_dir, build_type):
 def clear_folder(folder_path):
     for item in os.listdir(folder_path):
         item_path = os.path.join(folder_path, item)
-        if os.path.isdir(item_path):
+        if os.path.isdir(item_path) and not os.path.islink(item_path):
             rmtree(item_path)
         else:
             os.remove(item_path)
diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py
index 816bd23f41..49cdfb5a35 100644
--- a/tensorrt_llm/_utils.py
+++ b/tensorrt_llm/_utils.py
@@ -891,34 +891,33 @@ def convert_to_torch_tensor(

 class KVCacheEventSerializer:

-    def get_event_serialize_func(event_type):
+    @classmethod
+    def get_event_serialize_func(cls, event_type):
         return {
-            "KVCacheCreatedData": KVCacheEventSerializer._created_to_json,
-            "KVCacheStoredData": KVCacheEventSerializer._stored_to_json,
-            "KVCacheStoredBlockData":
-            KVCacheEventSerializer._stored_block_to_json,
-            "KVCacheRemovedData": KVCacheEventSerializer._removed_to_json,
-            "KVCacheUpdatedData": KVCacheEventSerializer._updated_to_json,
+            "KVCacheCreatedData": cls._created_to_json,
+            "KVCacheStoredData": cls._stored_to_json,
+            "KVCacheStoredBlockData": cls._stored_block_to_json,
+            "KVCacheRemovedData": cls._removed_to_json,
+            "KVCacheUpdatedData": cls._updated_to_json,
         }.get(event_type, None)

-    @staticmethod
-    def serialize(events):
+    @classmethod
+    def serialize(cls, events):
         if events is None:
             return None

         if not isinstance(events, list):
-            events = [events]
+            return cls.to_json_str(events)

-        return [KVCacheEventSerializer.to_json_str(event) for event in events]
+        return [cls.to_json_str(event) for event in events]

-    @staticmethod
-    def to_json_str(event):
+    @classmethod
+    def to_json_str(cls, event):
         if event is None:
             return {}

         event_type = type(event.data).__name__
-        event_serialize_func = KVCacheEventSerializer.get_event_serialize_func(
-            event_type)
+        event_serialize_func = cls.get_event_serialize_func(event_type)

         if event_serialize_func is None:
             raise ValueError(f"Unknown KVCache event data type: {event_type}")
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 85c2fe60dd..e86a49a3a2 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -126,6 +126,8 @@ def register_routes(self):
         self.app.add_api_route("/v1/models", self.get_model, methods=["GET"])
         # TODO: the metrics endpoint only reports iteration stats, not the runtime stats for now
         self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"])
+        # TODO: workaround before ETCD support
+        self.app.add_api_route("/kv_cache_events", self.get_kv_cache_events, methods=["POST"])
         self.app.add_api_route("/v1/completions",
                                self.openai_completion,
                                methods=["POST"])
@@ -150,6 +152,16 @@ async def get_iteration_stats(self) -> JSONResponse:
             stats.append(stat)
         return JSONResponse(content=stats)

+    async def get_kv_cache_events(self) -> JSONResponse:
+        events = []
+        try:
+            async for event in self.llm.get_kv_cache_events_async(2):
+                events.append(event)
+        except IndexError:
+            # queue is empty, no more events
+            pass
+        return JSONResponse(content=events)
+
     async def openai_chat(self, request: ChatCompletionRequest,
                           raw_request: Request) -> Response:
         def get_role() -> str:
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml
index 4f67ee52c8..d2009507f6 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml
@@ -6,6 +6,7 @@ backend: "pytorch"
 pytorch_backend_config:
   use_cuda_graph: False
   enable_overlap_scheduler: False
+  autotuner_enabled: False
 context_servers:
   num_instances: 1
   tensor_parallel_size: 1
@@ -13,6 +14,7 @@ context_servers:
   kv_cache_config:
     enable_block_reuse: True
     enable_partial_reuse: True
+    event_buffer_max_size: 1024
   urls:
       - "localhost:8001"
 generation_servers:
@@ -22,5 +24,6 @@ generation_servers:
   kv_cache_config:
     enable_block_reuse: True
     enable_partial_reuse: True
+    event_buffer_max_size: 1024
   urls:
       - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_workers.py b/tests/integration/defs/disaggregated/test_workers.py
index 564a437c47..8c2f16e4f3 100644
--- a/tests/integration/defs/disaggregated/test_workers.py
+++ b/tests/integration/defs/disaggregated/test_workers.py
@@ -9,6 +9,7 @@
 import aiohttp
 import pytest
 import yaml
+from transformers import AutoTokenizer

 logging.basicConfig(level=logging.INFO)
 MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@@ -17,13 +18,13 @@
 def get_ctx_gen_server_urls_from_cfg(config_file: str):
     with open(config_file, 'r') as file:
         config = yaml.safe_load(file)
-        ctx_servers = []
-        gen_servers = []
-        for server in config["context_servers"]["urls"]:
-            ctx_servers.append("http://" + server)
-        for server in config["generation_servers"]["urls"]:
-            gen_servers.append("http://" + server)
-        return ctx_servers, gen_servers
+    ctx_servers = []
+    gen_servers = []
+    for server in config["context_servers"]["urls"]:
+        ctx_servers.append("http://" + server)
+    for server in config["generation_servers"]["urls"]:
+        gen_servers.append("http://" + server)
+    return ctx_servers, gen_servers


 def run_disaggregated_workers(
@@ -45,6 +46,7 @@ def run_disaggregated_workers(
         str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
         config_file
     ]
+    logging.info(f"Running workers with command: {' '.join(workers_cmd)}")
     workers_proc = subprocess.Popen(workers_cmd,
                                     stdout=stdout,
                                     stderr=subprocess.STDOUT,
@@ -105,6 +107,29 @@ async def send_disagg_request(self, session: aiohttp.ClientSession,
         gen_response = await self.send_request(session, gen_url, gen_request)
         return gen_response

+    async def query_kv_cache_events(self, session: aiohttp.ClientSession,
+                                    url: str):
+        async with session.post(url + "/kv_cache_events") as response:
+            events_raw = await response.json()
+
+        events = []
+        for event_raw in events_raw:
+            event = {"id": event_raw["event_id"]} | event_raw["data"]
+            if event["type"] == "stored":
+                for block in event["blocks"]:
+                    block["token_id"] = [
+                        token["token_id"] for token in block["tokens"]
+                    ]
+                    block["token_extra_id"] = [
+                        token["token_extra_id"] for token in block["tokens"]
+                    ]
+                    # TODO: check by BlockKey::usesExtraIds
+                    if not any(block["token_extra_id"]):
+                        del block["token_extra_id"]
+                    del block["tokens"]
+            events.append(event)
+        return events
+
     async def check_server_ready(self, session: aiohttp.ClientSession,
                                  server_url: str) -> bool:
         try:
@@ -195,6 +220,161 @@ async def test_multi_round_request(self,
         await asyncio.gather(*chat_threads)


+class CacheBlockMeta:
+
+    def __init__(self, hash: int, parent_hash: Optional[int] = None):
+        self.hash = hash
+        self.parent_hash = parent_hash
+        # TODO: maintain next_hashes for partial matching
+
+    def __str__(self):
+        if self.parent_hash is None:
+            return f"CacheBlockMeta({self.hash:016x})"
+        else:
+            return f"CacheBlockMeta({self.hash:016x}, {self.parent_hash:016x})"
+
+    def __repr__(self):
+        return self.__str__()
+
+
+# TODO: use pybind-ed BlockKeyHasher
+def block_key_hasher(token_ids: List[int],
+                     parent_hash: Optional[int] = None) -> int:
+    mask32 = 0xffff_ffff
+    mask64 = 0xffff_ffff_ffff_ffff
+    seed = len(token_ids)
+    if parent_hash is not None:
+        seed ^= (parent_hash * 0xbf58476d1ce4e5b9) & mask64
+    for token_id in token_ids:
+        a = token_id & mask32
+        a = (((a >> 16) ^ a) * 0x45d9f3b) & mask32
+        a = (((a >> 16) ^ a) * 0x45d9f3b) & mask32
+        a = (a >> 16) ^ a
+        seed ^= (((a + 0x9e3779b9) & mask32) + ((seed << 6) & mask64) +
+                 (seed >> 2)) & mask64
+    # TODO: handle token_extra_id and lora_task_id
+    return seed & mask64
+
+
+class KvCacheBlockMap:
+
+    def __init__(self):
+        self.kv_blocks: dict[int, CacheBlockMeta] = {}
+
+    def update_with_events(self, events: List[dict]):
+        for event in events:
+            if event["type"] == "stored":
+                parent_hash = event["parent_hash"]
+                for block in event["blocks"]:
+                    block_hash = block["block_hash"]
+                    self.kv_blocks[block_hash] = CacheBlockMeta(
+                        block_hash, parent_hash)
+            elif event["type"] == "removed":
+                block_hashes = event["block_hashes"]
+                for block_hash in block_hashes:
+                    self.kv_blocks.pop(block_hash, None)
+
+    def get_block_match_count(self, block_hashes: List[int]) -> int:
+        count = 0
+        for block_hash in block_hashes:
+            if block_hash in self.kv_blocks:
+                count += 1
+            else:
+                break
+        return count
+
+    def __str__(self):
+        return f"KvCacheBlockMap(kv_blocks={', '.join(str(block) for block in self.kv_blocks.values())})"
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class KvCacheEventWorkerTester(BasicWorkerTester):
+
+    def __init__(self,
+                 ctx_servers: List[str],
+                 gen_servers: List[str],
+                 req_timeout_secs: int = 180,
+                 server_start_timeout_secs: int = 180):
+        super().__init__(ctx_servers, gen_servers, req_timeout_secs,
+                         server_start_timeout_secs)
+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+        self.kv_cache_block_maps = {}
+        for ctx_server in ctx_servers:
+            self.kv_cache_block_maps[ctx_server] = KvCacheBlockMap()
+        for gen_server in gen_servers:
+            if gen_server not in self.kv_cache_block_maps:
+                self.kv_cache_block_maps[gen_server] = KvCacheBlockMap()
+
+    async def send_request(self, session: aiohttp.ClientSession, url: str,
+                           request: dict) -> dict:
+        response = await super().send_request(session, url, request)
+
+        events = await self.query_kv_cache_events(session, url)
+        self.kv_cache_block_maps[url].update_with_events(events)
+        return response
+
+    async def multi_round_request(self,
+                                  session: aiohttp.ClientSession,
+                                  init_prompt: str,
+                                  max_rounds: int,
+                                  check_match_count: bool = True):
+        request = {
+            "model": MODEL_NAME,
+            "prompt": init_prompt,
+            "max_tokens": 64,
+            "temperature": 0.0,
+        }
+        tokens_per_block = 32  # TODO: read from config
+        prev_ctx_match_count = 0
+        prev_gen_match_count = 0
+        for i in range(max_rounds):
+            # split tokens into blocks and check block match count by hash
+            tokens = self.tokenizer(request["prompt"])["input_ids"]
+            block_hashes = []
+            for t in range(0, len(tokens), tokens_per_block):
+                block_hashes.append(
+                    block_key_hasher(tokens[t:t + tokens_per_block],
+                                     None if t == 0 else block_hashes[-1]))
+            ctx_match_count = self.kv_cache_block_maps[
+                self.ctx_servers[0]].get_block_match_count(block_hashes)
+            gen_match_count = self.kv_cache_block_maps[
+                self.gen_servers[0]].get_block_match_count(block_hashes)
+            assert ctx_match_count >= prev_ctx_match_count
+            assert gen_match_count >= prev_gen_match_count
+
+            response = await self.send_disagg_request(session,
+                                                      self.ctx_servers[0],
+                                                      self.gen_servers[0],
+                                                      request)
+            logging.info(
+                f"Received response {i}: {repr(response['choices'][0]['text'])}"
+            )
+            prev_ctx_match_count = ctx_match_count
+            prev_gen_match_count = gen_match_count
+            request["prompt"] += response["choices"][0]["text"]
+
+        if check_match_count:
+            assert ctx_match_count > 0
+            assert gen_match_count >= ctx_match_count
+        return request["prompt"]
+
+    async def test_multi_round_request(self,
+                                       init_prompts: List[str],
+                                       max_rounds: int = 8):
+        async with await self.new_session() as session:
+            chat_threads = [
+                self.multi_round_request(session, prompt, max_rounds, False)
+                for prompt in init_prompts
+            ]
+            prompts = await asyncio.gather(*chat_threads)
+            await asyncio.gather(*[
+                self.multi_round_request(session, prompt, 1, True)
+                for prompt in prompts
+            ])
+
+
 def prepare_model(llama_model_root: str, llm_venv):
     src_dst_dict = {
         llama_model_root:
@@ -214,15 +394,14 @@ def test_workers_conditional_disaggregation(disaggregated_test_root,
     config_file = os.path.join(disaggregated_test_root,
                                'test_configs/disagg_config_cache_reuse.yaml')
     prepare_model(llama_model_root, llm_venv)
+    cwd = llm_venv.get_working_directory()

-    with open(
-            os.path.join(llm_venv.get_working_directory(),
-                         'output_workers.log'), 'w') as log_file:
+    with open(os.path.join(cwd, 'output_workers.log'), 'w') as log_file:
         workers_proc, ctx_servers, gen_servers = run_disaggregated_workers(
             config_file=config_file,
             stdout=log_file,
             env=llm_venv._new_env,
-            cwd=llm_venv.get_working_directory(),
+            cwd=cwd,
             num_ranks=2)
         try:
             tester = ConditionalWorkerTester(ctx_servers, gen_servers)
@@ -236,3 +415,34 @@ def test_workers_conditional_disaggregation(disaggregated_test_root,
         finally:
             workers_proc.terminate()
             workers_proc.wait()
+
+
+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_workers_kv_cache_events(disaggregated_test_root,
+                                 disaggregated_example_root, llm_venv,
+                                 llama_model_root):
+    config_file = os.path.join(disaggregated_test_root,
+                               'test_configs/disagg_config_cache_reuse.yaml')
+    prepare_model(llama_model_root, llm_venv)
+    cwd = llm_venv.get_working_directory()
+
+    with open(os.path.join(cwd, 'output_workers.log'), 'w') as log_file:
+        workers_proc, ctx_servers, gen_servers = run_disaggregated_workers(
+            config_file=config_file,
+            stdout=log_file,
+            env=llm_venv._new_env,
+            cwd=cwd,
+            num_ranks=2)
+        try:
+            tester = KvCacheEventWorkerTester(ctx_servers, gen_servers)
+            prompts_file = os.path.join(disaggregated_example_root,
+                                        'clients/prompts.json')
+            with open(prompts_file, 'r') as f:
+                prompts = json.load(f)
+            asyncio.run(tester.test_multi_round_request(prompts, 6))
+        except Exception as e:
+            raise e
+        finally:
+            workers_proc.terminate()
+            workers_proc.wait()
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 7ab0e981f7..0cfb98a021 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -74,6 +74,7 @@ l0_dgx_h100:
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_overlap_cuda_graph[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_workers.py::test_workers_conditional_disaggregation[TinyLlama-1.1B-Chat-v1.0]
+  - disaggregated/test_workers.py::test_workers_kv_cache_events[TinyLlama-1.1B-Chat-v1.0]
 - condition:
     ranges:
       system_gpu_count:
diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py
index a2a96e0912..a98529af8f 100644
--- a/tests/unittest/llmapi/test_llm_kv_cache_events.py
+++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py
@@ -1,8 +1,6 @@
 import asyncio
 import time

-import pytest
-
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
@@ -115,19 +113,18 @@ def test_expected_kv_cache_events():
     assert events and len(events) >= 2
     for event in events:
         if event:
-            if event[0]["event_id"] == 0:
-                assert event[0]["data"]["type"] == "created"
-            elif event[0]["event_id"] == 1:
-                assert event[0]["data"]["type"] == "stored"
+            if event["event_id"] == 0:
+                assert event["data"]["type"] == "created"
+            elif event["event_id"] == 1:
+                assert event["data"]["type"] == "stored"


-@pytest.mark.skip("https://nvbugs/5150466: flaky fail")
 def test_kv_cache_event_async_api():
     llm = create_llm()
     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
     prompt = "Hello, my name is"

-    async def task0():
+    async def generate():
         async for output in llm.generate_async(prompt,
                                                streaming=True,
                                                sampling_params=sampling_params):
@@ -135,16 +132,16 @@ def test_kv_cache_event_async_api():

     events = []

-    async def task1():
+    async def get_events():
         async for event in llm.get_kv_cache_events_async():
             events.append(event)
         assert events

     async def main():
-        await asyncio.gather(task0(), task1())
-        for i in range(2):
-            await asyncio.gather(task0(), task1())
+        await generate()
+        await asyncio.gather(generate(), get_events())
+        await asyncio.gather(generate(), get_events())

     asyncio.run(main())

@@ -166,9 +163,9 @@ def test_llm_kv_events_api():
     while events1:
         event = events1.pop(0)
         if event:
-            assert event[0]["event_id"] == 1
-            assert event[0]["data"]["type"] == "stored"
-            assert len(event[0]["data"]["blocks"]) == 5
+            assert event["event_id"] == 1
+            assert event["data"]["type"] == "stored"
+            assert len(event["data"]["blocks"]) == 5

     _ = llm.generate(requests[1], sampling_params=sampling_params)
     events2 = llm.get_kv_cache_events(5)
@@ -176,18 +173,18 @@ def test_llm_kv_events_api():
     while events2:
         event = events2.pop(0)
         if event:
-            if event[0]["event_id"] == 2:
+            if event["event_id"] == 2:
                 # 2 removed events needed
                 # should be a removed event to make space for context block
-                assert event[0]["data"]["type"] == "removed"
-                assert event[0]["data"]["block_hashes"]
-            elif event[0]["event_id"] == 3:
-                assert event[0]["data"]["type"] == "removed"
-                assert event[0]["data"]["block_hashes"]
+                assert event["data"]["type"] == "removed"
+                assert event["data"]["block_hashes"]
+            elif event["event_id"] == 3:
+                assert event["data"]["type"] == "removed"
+                assert event["data"]["block_hashes"]
             # stored event for 2nd request
-            elif event[0]["event_id"] == 4:
-                assert event[0]["data"]["type"] == "stored"
-                assert len(event[0]["data"]["blocks"]) == 5
+            elif event["event_id"] == 4:
+                assert event["data"]["type"] == "stored"
+                assert len(event["data"]["blocks"]) == 5

     _ = llm.generate(requests[2], sampling_params=sampling_params)
     events3 = llm.get_kv_cache_events(5)
@@ -195,15 +192,15 @@ def test_llm_kv_events_api():
     while events3:
         event = events3.pop(0)
         if event:
-            if event[0]["event_id"] == 5:
-                assert event[0]["data"]["type"] == "removed"
-                assert event[0]["data"]["block_hashes"]
-            elif event[0]["event_id"] == 6:
-                assert event[0]["data"]["type"] == "removed"
-                assert event[0]["data"]["block_hashes"]
-            elif event[0]["event_id"] == 7:
-                assert event[0]["data"]["type"] == "stored"
-                assert len(event[0]["data"]["blocks"]) == 5
+            if event["event_id"] == 5:
+                assert event["data"]["type"] == "removed"
+                assert event["data"]["block_hashes"]
+            elif event["event_id"] == 6:
+                assert event["data"]["type"] == "removed"
+                assert event["data"]["block_hashes"]
+            elif event["event_id"] == 7:
+                assert event["data"]["type"] == "stored"
+                assert len(event["data"]["blocks"]) == 5

     # no more events after request is finished
     assert not llm.get_kv_cache_events(5)