|
16 | 16 | from typing import AsyncIterator, Optional, Set, Tuple
|
17 | 17 |
|
18 | 18 | import uvloop
|
19 |
| -from fastapi import APIRouter, FastAPI, Request |
| 19 | +from fastapi import APIRouter, FastAPI, HTTPException, Request |
20 | 20 | from fastapi.exceptions import RequestValidationError
|
21 | 21 | from fastapi.middleware.cors import CORSMiddleware
|
22 | 22 | from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|
44 | 44 | CompletionResponse,
|
45 | 45 | DetokenizeRequest,
|
46 | 46 | DetokenizeResponse,
|
| 47 | + EmbeddingChatRequest, |
| 48 | + EmbeddingCompletionRequest, |
47 | 49 | EmbeddingRequest,
|
48 | 50 | EmbeddingResponse,
|
49 | 51 | EmbeddingResponseData,
|
50 | 52 | ErrorResponse,
|
51 | 53 | LoadLoraAdapterRequest,
|
| 54 | + PoolingChatRequest, |
| 55 | + PoolingCompletionRequest, |
52 | 56 | PoolingRequest, PoolingResponse,
|
53 | 57 | ScoreRequest, ScoreResponse,
|
54 | 58 | TokenizeRequest,
|
@@ -310,6 +314,12 @@ async def health(raw_request: Request) -> Response:
|
310 | 314 | return Response(status_code=200)
|
311 | 315 |
|
312 | 316 |
|
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """Liveness probe required by SageMaker (accepts GET and POST).

    Delegates to the standard ``/health`` handler so that both endpoints
    always report exactly the same status.
    """
    health_response = await health(raw_request)
    return health_response


313 | 323 | @router.post("/tokenize")
|
314 | 324 | @with_cancellation
|
315 | 325 | async def tokenize(request: TokenizeRequest, raw_request: Request):
|
@@ -483,6 +493,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
483 | 493 | return await create_score(request, raw_request)
|
484 | 494 |
|
485 | 495 |
|
# Maps a model task to its (request model, handler) pairs for /invocations.
# The "messages" entry is used for chat-style bodies (those carrying a
# `messages` field); "default" covers plain prompt/input payloads.
TASK_HANDLERS = {
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),
        "default": (CompletionRequest, create_completion),
    },
    "embed": {
        "messages": (EmbeddingChatRequest, create_embedding),
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    "score": {
        "default": (ScoreRequest, create_score),
    },
}
# "reward" and "classify" use identical pooling request models and handler;
# each loop iteration builds a fresh inner dict, so no entries are aliased.
TASK_HANDLERS.update({
    pooling_task: {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    }
    for pooling_task in ("reward", "classify")
})


| 518 | + |
@router.post("/invocations")
async def invocations(raw_request: Request):
    """
    For SageMaker, routes requests to other handlers based on model `task`.

    The task is read from ``raw_request.app.state.task`` (set during app
    initialization). Bodies containing a ``messages`` field are dispatched
    to the task's chat-style handler when one exists; everything else goes
    to the task's default handler.

    Raises:
        HTTPException: 400 if the body is not a JSON object or the
            configured task has no '/invocations' handler.
    """
    try:
        body = await raw_request.json()
    except ValueError as e:
        # Starlette raises json.JSONDecodeError (a ValueError subclass) on
        # malformed bodies; surface it as a client error instead of a 500.
        raise HTTPException(status_code=400,
                            detail=f"Invalid JSON body: {e}") from e

    # `"messages" in body` below requires a mapping; reject bare lists,
    # numbers, etc. up front rather than raising TypeError (a 500).
    if not isinstance(body, dict):
        raise HTTPException(status_code=400,
                            detail="Request body must be a JSON object")

    task = raw_request.app.state.task

    if task not in TASK_HANDLERS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported task: '{task}' for '/invocations'. "
            f"Expected one of {set(TASK_HANDLERS.keys())}")

    handler_config = TASK_HANDLERS[task]
    # Fall back to the default handler when the task has no chat variant
    # (e.g. "score"), instead of raising KeyError on a `messages` body.
    if "messages" in body and "messages" in handler_config:
        request_model, handler = handler_config["messages"]
    else:
        request_model, handler = handler_config["default"]

    # Required because we bypass FastAPI's automatic request-body casting.
    # NOTE(review): a pydantic ValidationError here still surfaces as a 500;
    # consider mapping it to 422 to match FastAPI's native behavior.
    request = request_model.model_validate(body)
    return await handler(request, raw_request)


| 543 | + |
486 | 544 | if envs.VLLM_TORCH_PROFILER_DIR:
|
487 | 545 | logger.warning(
|
488 | 546 | "Torch Profiler is enabled in the API server. This should ONLY be "
|
@@ -687,6 +745,7 @@ def init_app_state(
|
687 | 745 | chat_template=resolved_chat_template,
|
688 | 746 | chat_template_content_format=args.chat_template_content_format,
|
689 | 747 | )
|
| 748 | + state.task = model_config.task |
690 | 749 |
|
691 | 750 |
|
692 | 751 | def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
|
|
0 commit comments