Skip to content

Commit 68d3780

Browse files
authored
[Misc] Minimum requirements for SageMaker compatibility (#11576)
1 parent 5dba257 commit 68d3780

File tree

3 files changed

+95
-3
lines changed

3 files changed

+95
-3
lines changed

Dockerfile

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ RUN mv vllm test_docs/
234234
#################### TEST IMAGE ####################
235235

236236
#################### OPENAI API SERVER ####################
237-
# openai api server alternative
238-
FROM vllm-base AS vllm-openai
237+
# base openai image with additional requirements, for any subsequent openai-style images
238+
FROM vllm-base AS vllm-openai-base
239239

240240
# install additional dependencies for openai api server
241241
RUN --mount=type=cache,target=/root/.cache/pip \
@@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \
247247

248248
ENV VLLM_USAGE_SOURCE production-docker-image
249249

250+
# define sagemaker first, so it is not default from `docker build`
251+
FROM vllm-openai-base AS vllm-sagemaker
252+
253+
COPY examples/sagemaker-entrypoint.sh .
254+
RUN chmod +x sagemaker-entrypoint.sh
255+
ENTRYPOINT ["./sagemaker-entrypoint.sh"]
256+
257+
FROM vllm-openai-base AS vllm-openai
258+
250259
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
251260
#################### OPENAI API SERVER ####################

examples/sagemaker-entrypoint.sh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/bin/bash

# Environment-variable prefix that marks a variable as a vLLM CLI option,
# and the prefix used when turning it into a command-line flag.
PREFIX="SM_VLLM_"
ARG_PREFIX="--"

# Collected arguments for the API server.
# Port 8080 is mandated by SageMaker, see
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
ARGS=(--port 8080)

# Turn every SM_VLLM_* variable into a "--flag [value]" pair:
# strip the prefix, lowercase the name, and map underscores to dashes.
# NOTE(review): assumes values are single-line (env | grep is line-based) —
# multi-line values would be truncated; confirm this is acceptable.
while IFS='=' read -r key value; do
    flag=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
    ARGS+=("${ARG_PREFIX}${flag}")
    # An empty value means a boolean flag: emit the flag with no argument.
    if [ -n "$value" ]; then
        ARGS+=("$value")
    fi
done < <(env | grep "^${PREFIX}")

# Replace this shell with the OpenAI-compatible API server.
exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"

vllm/entrypoints/openai/api_server.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import AsyncIterator, Optional, Set, Tuple
1717

1818
import uvloop
19-
from fastapi import APIRouter, FastAPI, Request
19+
from fastapi import APIRouter, FastAPI, HTTPException, Request
2020
from fastapi.exceptions import RequestValidationError
2121
from fastapi.middleware.cors import CORSMiddleware
2222
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -44,11 +44,15 @@
4444
CompletionResponse,
4545
DetokenizeRequest,
4646
DetokenizeResponse,
47+
EmbeddingChatRequest,
48+
EmbeddingCompletionRequest,
4749
EmbeddingRequest,
4850
EmbeddingResponse,
4951
EmbeddingResponseData,
5052
ErrorResponse,
5153
LoadLoraAdapterRequest,
54+
PoolingChatRequest,
55+
PoolingCompletionRequest,
5256
PoolingRequest, PoolingResponse,
5357
ScoreRequest, ScoreResponse,
5458
TokenizeRequest,
@@ -310,6 +314,12 @@ async def health(raw_request: Request) -> Response:
310314
return Response(status_code=200)
311315

312316

317+
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """Liveness probe required by SageMaker; accepts both GET and POST.

    Simply delegates to the /health handler.
    """
    health_response = await health(raw_request)
    return health_response
321+
322+
313323
@router.post("/tokenize")
314324
@with_cancellation
315325
async def tokenize(request: TokenizeRequest, raw_request: Request):
@@ -483,6 +493,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
483493
return await create_score(request, raw_request)
484494

485495

496+
# Dispatch table for the SageMaker /invocations endpoint: maps the served
# model's task to (request model, handler coroutine) pairs. The "messages"
# entry is used when the request body contains a `messages` key (chat-style
# input); "default" covers prompt-style input. Tasks without a "messages"
# entry (e.g. "score") only accept the default form.
TASK_HANDLERS = {
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),
        "default": (CompletionRequest, create_completion),
    },
    "embed": {
        "messages": (EmbeddingChatRequest, create_embedding),
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    "score": {
        "default": (ScoreRequest, create_score),
    },
    "reward": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
    "classify": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
}
517+
518+
519+
@router.post("/invocations")
async def invocations(raw_request: Request):
    """SageMaker entry point: dispatch to the handler matching the model task.

    Picks the (request model, handler) pair from ``TASK_HANDLERS`` using the
    task stored on the app state, validates the raw JSON body against that
    model, and forwards the call.
    """
    payload = await raw_request.json()
    task = raw_request.app.state.task

    handlers_for_task = TASK_HANDLERS.get(task)
    if handlers_for_task is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported task: '{task}' for '/invocations'. "
            f"Expected one of {set(TASK_HANDLERS.keys())}")

    # A `messages` key marks a chat-style body; anything else is prompt-style.
    style = "messages" if "messages" in payload else "default"
    request_model, handler = handlers_for_task[style]

    # FastAPI's automatic request casting is bypassed on this route, so the
    # body must be validated into the pydantic model explicitly.
    request = request_model.model_validate(payload)
    return await handler(request, raw_request)
542+
543+
486544
if envs.VLLM_TORCH_PROFILER_DIR:
487545
logger.warning(
488546
"Torch Profiler is enabled in the API server. This should ONLY be "
@@ -687,6 +745,7 @@ def init_app_state(
687745
chat_template=resolved_chat_template,
688746
chat_template_content_format=args.chat_template_content_format,
689747
)
748+
state.task = model_config.task
690749

691750

692751
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:

0 commit comments

Comments
 (0)