
Commit ae56716

pcastonguay and Shunkang authored
feat: Disaggregated router class (#3584)
* Add draft scheduler class. Signed-off-by: Shunkang <[email protected]>
* Refactor the design. Signed-off-by: Shunkang <[email protected]>
* feat: Introduce router class for disaggregated server. Signed-off-by: Patrice Castonguay <[email protected]>
* Add unit tests for router class. Signed-off-by: Patrice Castonguay <[email protected]>
* Adding tests for disagg_utils. Signed-off-by: Patrice Castonguay <[email protected]>
* Fixing missing import. Signed-off-by: Patrice Castonguay <[email protected]>
* Fixing disagg integration tests. Signed-off-by: Patrice Castonguay <[email protected]>
* Addressing MR review comments. Signed-off-by: Patrice Castonguay <[email protected]>

---------

Signed-off-by: Shunkang <[email protected]>
Signed-off-by: Patrice Castonguay <[email protected]>
Co-authored-by: Shunkang <[email protected]>
1 parent b9fce42 commit ae56716

File tree

9 files changed: +626 −42 lines changed

tensorrt_llm/commands/serve.py

+3 −1

@@ -253,7 +253,9 @@ def disaggregated(config_file: Optional[str], server_start_timeout: int,
         server = OpenAIDisaggServer(ctx_servers=ctx_server_urls,
                                     gen_servers=gen_server_urls,
                                     req_timeout_secs=request_timeout,
-                                    server_start_timeout_secs=server_start_timeout)
+                                    server_start_timeout_secs=server_start_timeout,
+                                    ctx_router_type=disagg_cfg.ctx_router_type,
+                                    gen_router_type=disagg_cfg.gen_router_type)
 
     asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port))

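For orientation, here is a condensed, self-contained sketch of the call this hunk extends. The endpoint URLs and router types below are placeholders chosen to mirror the example config later in this commit, not values taken from serve.py itself:

import asyncio

from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer

# Placeholder endpoints; in the real command these come from the parsed
# disaggregated config (disagg_cfg) rather than being hard-coded.
server = OpenAIDisaggServer(
    ctx_servers=["http://localhost:8001", "http://localhost:8002"],
    gen_servers=["http://localhost:8003", "http://localhost:8004"],
    req_timeout_secs=180,
    server_start_timeout_secs=180,
    ctx_router_type="tokens_load_balancing",    # new argument in this commit
    gen_router_type="requests_load_balancing")  # new argument in this commit

# The server object is awaited with a hostname and port, as in the hunk above.
asyncio.run(server("localhost", 8000))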
tensorrt_llm/llmapi/disagg_utils.py

+7 −1

@@ -29,6 +29,8 @@ class DisaggServerConfig():
     server_configs: List[CtxGenServerConfig]
     hostname: str = "localhost"
     port: int = 8000
+    ctx_router_type: str = "round_robin"
+    gen_router_type: str = "round_robin"
 
 
 def parse_disagg_config_file(yaml_config_file: str):
@@ -68,7 +70,11 @@ def extract_disagg_cfg(hostname: str = 'localhost',
         type="ctx", **context_servers) + extract_ctx_gen_cfgs(
             type="gen", **generation_servers)
 
-    return DisaggServerConfig(server_configs, hostname, port)
+    ctx_router_type = context_servers.get("router_type", "round_robin")
+    gen_router_type = generation_servers.get("router_type", "round_robin")
+
+    return DisaggServerConfig(server_configs, hostname, port, ctx_router_type,
+                              gen_router_type)
 
 
 def extract_ctx_gen_cfgs(type: Literal['ctx', 'gen'],

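The two .get() calls above are what make router selection optional per pool: a config section that omits router_type falls back to round robin. A minimal sketch of that fallback, using plain dicts shaped like the YAML sections (the dict contents are illustrative only):

# Mirrors extract_disagg_cfg's handling of the optional "router_type" key.
context_servers = {"num_instances": 2, "router_type": "tokens_load_balancing"}
generation_servers = {"num_instances": 2}  # no router_type given

ctx_router_type = context_servers.get("router_type", "round_robin")
gen_router_type = generation_servers.get("router_type", "round_robin")

print(ctx_router_type)  # tokens_load_balancing
print(gen_router_type)  # round_robin (default)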
tensorrt_llm/serve/openai_disagg_server.py

+57 −40

@@ -23,6 +23,7 @@
                                                  CompletionResponse,
                                                  DisaggregatedParams,
                                                  ErrorResponse)
+from tensorrt_llm.serve.router import create_router
 from tensorrt_llm.version import __version__ as VERSION
 
 logging.basicConfig(level=logging.INFO)
@@ -36,11 +37,16 @@ def __init__(self,
                  ctx_servers: List[str] = None,
                  gen_servers: List[str] = None,
                  req_timeout_secs: int = 180,
-                 server_start_timeout_secs: int = 180):
+                 server_start_timeout_secs: int = 180,
+                 ctx_router_type: str = "round_robin",
+                 gen_router_type: str = "round_robin"):
+
         self.ctx_servers = ctx_servers
         self.gen_servers = gen_servers
         self.ctx_server_idx = 0
         self.gen_server_idx = 0
+        self.ctx_router = create_router(ctx_router_type, ctx_servers)
+        self.gen_router = create_router(gen_router_type, gen_servers)
 
         if (len(self.gen_servers) == 0):
             raise ValueError("At least one generation server must be provided")
@@ -97,24 +103,28 @@ async def version(self) -> JSONResponse:
     async def merge_streaming_responses(self, ctx_response,
                                         gen_server: str,
                                         gen_req: Union[CompletionRequest, ChatCompletionRequest]):
-        # First yield the context response if it's not None
-        if ctx_response is not None:
-            # Remove the disaggregated params from the context response
-            data = ctx_response.model_dump()
-            del data['choices'][0]['disaggregated_params']
-            data = json.dumps(data)
-            yield f"data: {data}\n\n".encode('utf-8')
-
-        # Then yield the generation responses
-        if isinstance(gen_req, CompletionRequest):
-            gen_response = await self.send_completion_request(gen_server, gen_req)
-        elif isinstance(gen_req, ChatCompletionRequest):
-            gen_response = await self.send_chat_request(gen_server, gen_req)
-        else:
-            raise TypeError("Invalid request type: {type(gen_req).__name__}")
+        try:
+            # First yield the context response if it's not None
+            if ctx_response is not None:
+                # Remove the disaggregated params from the context response
+                data = ctx_response.model_dump()
+                del data['choices'][0]['disaggregated_params']
+                data = json.dumps(data)
+                yield f"data: {data}\n\n".encode('utf-8')
+
+            # Then yield the generation responses
+            if isinstance(gen_req, CompletionRequest):
+                gen_response = await self.send_completion_request(gen_server, gen_req)
+            elif isinstance(gen_req, ChatCompletionRequest):
+                gen_response = await self.send_chat_request(gen_server, gen_req)
+            else:
+                raise TypeError(f"Invalid request type: {type(gen_req).__name__}")
 
-            async for chunk in gen_response.body_iterator:
-                yield chunk
+            async for chunk in gen_response.body_iterator:
+                yield chunk
+
+        finally:
+            await self.gen_router.finish_request(gen_req)
 
     async def openai_completion(self, req: CompletionRequest) -> Response:
         try:
@@ -158,21 +168,25 @@ async def _process_context_server_request(self, ctx_req, request_type: str):
         if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1":
             return None
 
-        ctx_server = self.get_next_server(self.ctx_servers, "context")
-        logging.info("Sending request to ctx server: %s", ctx_server)
-
-        if request_type == "chat":
-            ctx_req.max_completion_tokens = 1
-        elif request_type == "completion":
-            ctx_req.max_tokens = 1
-        ctx_req.disaggregated_params = DisaggregatedParams(request_type="context_only")
-        ctx_req.stream = False
-        ctx_req.stream_options = None
-
-        if request_type == "chat":
-            return await self.send_chat_request(ctx_server, ctx_req)
-        elif request_type == "completion":
-            return await self.send_completion_request(ctx_server, ctx_req)
+        try:
+            if request_type == "chat":
+                ctx_req.max_completion_tokens = 1
+            elif request_type == "completion":
+                ctx_req.max_tokens = 1
+            ctx_req.disaggregated_params = DisaggregatedParams(request_type="context_only")
+            ctx_req.stream = False
+            ctx_req.stream_options = None
+
+            ctx_server = await self.ctx_router.get_next_server(ctx_req)
+            logging.info("Sending request to ctx server: %s", ctx_server)
+
+            if request_type == "chat":
+                response = await self.send_chat_request(ctx_server, ctx_req)
+            else:
+                response = await self.send_completion_request(ctx_server, ctx_req)
+            # Return the context response to the caller
+            return response
+        finally:
+            await self.ctx_router.finish_request(ctx_req)
 
     async def _process_generation_server_request(self, gen_req, ctx_response):
         if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1":
@@ -192,16 +206,19 @@ async def _process_generation_server_request(self, gen_req, ctx_response):
         gen_req.disaggregated_params.request_type = "generation_only"
 
         # Pick a generation server and send request
-        gen_server = self.get_next_server(self.gen_servers, "generation")
+        gen_server = await self.gen_router.get_next_server(gen_req)
         logging.info("Sending request to gen server: %s", gen_server)
 
         if not gen_req.stream:
-            if isinstance(gen_req, CompletionRequest):
-                gen_response = await self.send_completion_request(gen_server, gen_req)
-            elif isinstance(gen_req, ChatCompletionRequest):
-                gen_response = await self.send_chat_request(gen_server, gen_req)
-
-            return gen_response
+            try:
+                if isinstance(gen_req, CompletionRequest):
+                    gen_response = await self.send_completion_request(gen_server, gen_req)
+                elif isinstance(gen_req, ChatCompletionRequest):
+                    gen_response = await self.send_chat_request(gen_server, gen_req)
+
+                return gen_response
+            finally:
+                await self.gen_router.finish_request(gen_req)
        else:
            # Return a streaming response that combines both context and generation responses
            return StreamingResponse(

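The changes in this file all follow the same pattern: ask the router for a server, forward the request, and release the router's bookkeeping in a finally block so the load-balancing counters stay accurate even when a request fails. A condensed sketch of that pattern, with simplified names rather than the exact server code:

async def route_and_send(router, req, send_fn):
    # send_fn stands in for send_completion_request / send_chat_request.
    server = await router.get_next_server(req)
    try:
        return await send_fn(server, req)
    finally:
        # Always runs, so a LoadBalancingRouter's active-request and
        # active-token counts are decremented even on errors.
        await router.finish_request(req)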
tensorrt_llm/serve/router.py

+177

@@ -0,0 +1,177 @@
+import asyncio
+import heapq
+from abc import ABC, abstractmethod
+from typing import List, Union
+
+from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
+                                                CompletionRequest)
+
+
+def get_request_num_tokens(
+        request: Union[CompletionRequest, ChatCompletionRequest]) -> int:
+    if request.disaggregated_params.request_type == "context_only":
+        if isinstance(request, ChatCompletionRequest):
+            raise ValueError(
+                "LoadBalancing router with tokens doesn't support ChatCompletionRequest yet"
+            )
+
+        if isinstance(request.prompt, str) or \
+            (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
+            prompts = [request.prompt]
+        else:
+            prompts = request.prompt
+
+        num_tokens = sum(len(prompt) for prompt in prompts)
+    elif request.disaggregated_params.request_type == "generation_only":
+        raise ValueError(
+            "LoadBalancing router with tokens doesn't support generation_only requests"
+        )
+    else:
+        raise ValueError(
+            f"Unsupported request type: {request.disaggregated_params.request_type}"
+        )
+
+    return num_tokens
+
+
+class ServerState:
+
+    def __init__(self, server: str, use_tokens: bool = False):
+        self._server = server
+        self._num_active_requests = 0
+        self._num_active_tokens = 0
+        self._use_tokens = use_tokens
+        self._lock = asyncio.Lock()
+
+    async def increment_load(self, request: Union[CompletionRequest,
+                                                  ChatCompletionRequest]):
+        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
+        async with self._lock:
+            self._num_active_requests += 1
+            self._num_active_tokens += num_tokens
+
+    async def decrement_load(self, request: Union[CompletionRequest,
+                                                  ChatCompletionRequest]):
+        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
+        async with self._lock:
+            self._num_active_requests -= 1
+            self._num_active_tokens -= num_tokens
+
+
+class Router(ABC):
+
+    def __init__(self, servers: List[str] = None):
+        self._servers = servers
+
+    @abstractmethod
+    async def get_next_server(
+            self, request: Union[CompletionRequest,
+                                 ChatCompletionRequest]) -> str:
+        pass
+
+    @abstractmethod
+    async def finish_request(self, request: Union[CompletionRequest,
+                                                  ChatCompletionRequest]):
+        pass
+
+
+class RoundRobinRouter(Router):
+
+    def __init__(self, servers: List[str] = None):
+        super().__init__(servers)
+        self._server_idx = 0
+
+    async def get_next_server(
+            self, request: Union[CompletionRequest,
+                                 ChatCompletionRequest]) -> str:
+        server = self._servers[self._server_idx]
+        self._server_idx = (self._server_idx + 1) % len(self._servers)
+        return server
+
+    async def finish_request(self, request: Union[CompletionRequest,
+                                                  ChatCompletionRequest]):
+        pass
+
+
+class LoadBalancingRouter(Router):
+
+    def __init__(self, servers: List[str] = None, use_tokens: bool = False):
+        super().__init__(servers)
+        self._lock = asyncio.Lock()
+        # Load map between servers and their number of tokens processed
+        self._server_state = {}
+        self._server_load_heap = []
+
+        # Routing table to map requests to servers
+        self._req_routing_table = {}
+
+        self._use_tokens = use_tokens
+        self._init_heap()
+
+    def _init_heap(self):
+        for server in self._servers:
+            self._server_state[server] = ServerState(server, self._use_tokens)
+            heapq.heappush(self._server_load_heap,
+                           (self._get_server_load(server), server))
+
+    async def get_next_server(
+            self, request: Union[CompletionRequest,
+                                 ChatCompletionRequest]) -> str:
+        async with self._lock:
+            server = heapq.heappop(self._server_load_heap)[1]
+            await self._server_state[server].increment_load(request)
+            heapq.heappush(self._server_load_heap,
+                           (self._get_server_load(server), server))
+
+            self._req_routing_table[id(request)] = server
+
+        return server
+
+    def _get_server_load(self, server):
+        return self._server_state[server]._num_active_tokens if self._use_tokens \
+            else self._server_state[server]._num_active_requests
+
+    async def finish_request(self, request: Union[CompletionRequest,
+                                                  ChatCompletionRequest]):
+        async with self._lock:
+            server = self._req_routing_table[id(request)]
+            await self._server_state[server].decrement_load(request)
+            heapq.heappush(self._server_load_heap,
+                           (self._get_server_load(server), server))
+            del self._req_routing_table[id(request)]
+
+
+def create_router(router_type: str, servers: List[str]) -> Router:
+    """
+    Factory function to create different types of router instances.
+
+    Args:
+        router_type (str): Type of router to create. Supported values:
+            - "round_robin": Creates a RoundRobinRouter
+            - "requests_load_balancing": Creates a LoadBalancingRouter, which balances requests across instances
+            - "tokens_load_balancing": Creates a LoadBalancingRouter, which balances tokens across instances
+        servers: List of server URLs
+
+    Returns:
+        Router: An instance of the requested router type
+
+    Raises:
+        ValueError: If an unsupported router type is provided
+    """
+
+    router_map = {
+        "round_robin": RoundRobinRouter,
+        "requests_load_balancing": LoadBalancingRouter,
+        "tokens_load_balancing": LoadBalancingRouter
+    }
+
+    router_class = router_map.get(router_type.lower())
+    if router_class is None:
+        raise ValueError(f"Unsupported router type: {router_type}. "
+                         f"Supported types are: {list(router_map.keys())}")
+
+    if router_type.endswith("load_balancing"):
+        use_tokens = True if router_type.startswith("tokens") else False
+        return router_class(servers, use_tokens=use_tokens)
+    else:
+        return router_class(servers)
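A small usage sketch of the factory and the request-count balancer. The DummyRequest stand-in is hypothetical; the disaggregated server passes real CompletionRequest objects, but the request-count mode never inspects the request body, so any object works for illustration:

import asyncio

from tensorrt_llm.serve.router import create_router


class DummyRequest:
    """Hypothetical stand-in for CompletionRequest, used only in this sketch."""
    pass


async def main():
    router = create_router("requests_load_balancing",
                           ["localhost:8001", "localhost:8002"])

    requests = [DummyRequest() for _ in range(4)]

    # get_next_server() pops the least-loaded server from the heap, bumps its
    # active-request count, and records which server handled this request.
    for req in requests:
        server = await router.get_next_server(req)
        print("routed to", server)

    # finish_request() releases the load recorded for each request.
    for req in requests:
        await router.finish_request(req)


asyncio.run(main())

With "tokens_load_balancing" the flow is the same, except get_request_num_tokens() is consulted, which currently supports only context_only CompletionRequests.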
@@ -0,0 +1,39 @@
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+hostname: localhost
+port: 8000
+backend: "pytorch"
+free_gpu_memory_fraction: 0.15
+context_servers:
+  num_instances: 2
+  router_type: tokens_load_balancing
+  max_batch_size: 1
+  max_num_tokens: 3000
+  max_seq_len: 4096
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.15
+    enable_partial_reuse: False
+  pytorch_backend_config:
+    use_cuda_graph: False
+    enable_overlap_scheduler: False
+  urls:
+    - "localhost:8001"
+    - "localhost:8002"
+generation_servers:
+  num_instances: 2
+  router_type: requests_load_balancing
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_batch_size: 256
+  max_num_tokens: 4096
+  max_seq_len: 4096
+  kv_cache_config:
+    free_gpu_memory_fraction: 0.15
+    enable_partial_reuse: False
+  pytorch_backend_config:
+    use_cuda_graph: False
+    enable_overlap_scheduler: True
+  urls:
+    - "localhost:8003"
+    - "localhost:8004"
