Commit a5d4bd8

Always marshal responseBody, add test case to check for this
1 parent: c6b3d50

File tree

4 files changed: +128 −8 lines changed

Diff for: go.mod

+2 −2

@@ -1,8 +1,8 @@
 module sigs.k8s.io/gateway-api-inference-extension
 
-go 1.23.0
+go 1.24.0
 
-toolchain go1.23.2
+toolchain go1.24.2
 
 require (
 	github.com/bojand/ghz v0.120.0

Diff for: pkg/epp/handlers/streamingserver.go

+7 −6

@@ -372,14 +372,15 @@ func (s *StreamingServer) HandleRequestBody(
 	// Update target models in the body.
 	if llmReq.Model != llmReq.ResolvedTargetModel {
 		requestBodyMap["model"] = llmReq.ResolvedTargetModel
-		requestBodyBytes, err = json.Marshal(requestBodyMap)
-		if err != nil {
-			logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
-			return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
-		}
-		loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))
 	}
 
+	requestBodyBytes, err = json.Marshal(requestBodyMap)
+	if err != nil {
+		logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
+		return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
+	}
+	loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))
+
 	target, err := s.scheduler.Schedule(ctx, llmReq)
 	if err != nil {
 		return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
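The hunk above moves the json.Marshal call out of the model-rewrite branch, so the decoded body map is re-marshaled on every request, including the passthrough case where llmReq.Model already equals llmReq.ResolvedTargetModel. Below is a minimal, self-contained sketch of that behavior; the map and variable names are stand-ins for illustration, not the extension's own request types.

package main

import (
	"encoding/json"
	"fmt"
	"log"
)

func main() {
	// Stand-in for the decoded request body map built from the streamed chunks.
	requestBodyMap := map[string]interface{}{
		"max_tokens":  100,
		"model":       "direct-model",
		"prompt":      "test6",
		"temperature": 0,
	}
	// Passthrough case: the resolved target model equals the requested model,
	// so the rename branch below does nothing.
	model, resolvedTargetModel := "direct-model", "direct-model"

	if model != resolvedTargetModel {
		requestBodyMap["model"] = resolvedTargetModel
	}

	// After this commit the marshal runs unconditionally, so the body mutation
	// sent downstream always reflects the decoded requestBodyMap, even when no
	// model rename occurred.
	requestBodyBytes, err := json.Marshal(requestBodyMap)
	if err != nil {
		log.Fatalf("error marshaling request body: %v", err)
	}
	fmt.Println(string(requestBodyBytes))
}

The new hermetic test case below exercises exactly this passthrough path and asserts that the streamed body mutation still carries the full JSON document.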

Diff for: test/integration/epp/hermetic_test.go

+108

@@ -873,6 +873,114 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				},
 			},
 		},
+		{
+			name: "inferencemodel's modelName is not translated, passthrough",
+			requests: []*extProcPb.ProcessingRequest{
+				{
+					Request: &extProcPb.ProcessingRequest_RequestHeaders{
+						RequestHeaders: &extProcPb.HttpHeaders{
+							Headers: &configPb.HeaderMap{
+								Headers: []*configPb.HeaderValue{
+									{
+										Key:   "hi",
+										Value: "mom",
+									},
+								},
+							},
+						},
+					},
+				},
+				{
+					Request: &extProcPb.ProcessingRequest_RequestBody{
+						RequestBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"direct-"), EndOfStream: false},
+					},
+				},
+				{
+					Request: &extProcPb.ProcessingRequest_RequestBody{
+						RequestBody: &extProcPb.HttpBody{Body: []byte("model\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: true},
+					},
+				},
+			},
+
+			//
+			// pod 0 will be picked as all other models are above threshold
+			pods: map[backendmetrics.Pod]*backendmetrics.Metrics{
+				fakePod(0): {
+					WaitingQueueSize:    4,
+					KVCacheUsagePercent: 0.2,
+					ActiveModels: map[string]int{
+						"foo":            1,
+						"bar":            1,
+						"sql-lora-1fdg3": 1,
+					},
+				},
+				fakePod(1): {
+					WaitingQueueSize:    0,
+					KVCacheUsagePercent: 0.85,
+					ActiveModels: map[string]int{
+						"foo":            1,
+						"sql-lora-1fdg3": 1,
+					},
+				},
+				fakePod(2): {
+					WaitingQueueSize:    10,
+					KVCacheUsagePercent: 0.9,
+					ActiveModels: map[string]int{
+						"foo":            1,
+						"sql-lora-1fdg3": 1,
+					},
+				},
+			},
+			wantMetrics: `
+			# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
+			# TYPE inference_model_request_total counter
+			inference_model_request_total{model_name="direct-model",target_model_name="direct-model"} 1
+			`,
+			wantErr: false,
+			wantResponses: []*extProcPb.ProcessingResponse{
+				{
+					Response: &extProcPb.ProcessingResponse_RequestHeaders{
+						RequestHeaders: &extProcPb.HeadersResponse{
+							Response: &extProcPb.CommonResponse{
+								ClearRouteCache: true,
+								HeaderMutation: &extProcPb.HeaderMutation{
+									SetHeaders: []*configPb.HeaderValueOption{
+										{
+											Header: &configPb.HeaderValue{
+												Key:      "x-gateway-destination-endpoint",
+												RawValue: []byte("192.168.1.2:8000"),
+											},
+										},
+										{
+											Header: &configPb.HeaderValue{
+												Key:      "Content-Length",
+												RawValue: []byte(strconv.Itoa(74)),
+											},
+										},
+									}},
+							},
+						},
+					},
+					DynamicMetadata: makeMetadata("192.168.1.2:8000"),
+				},
+				{
+					Response: &extProcPb.ProcessingResponse_RequestBody{
+						RequestBody: &extProcPb.BodyResponse{
+							Response: &extProcPb.CommonResponse{
+								BodyMutation: &extProcPb.BodyMutation{
+									Mutation: &extProcPb.BodyMutation_StreamedResponse{
+										StreamedResponse: &extProcPb.StreamedBodyResponse{
+											Body:        []byte("{\"max_tokens\":100,\"model\":\"direct-model\",\"prompt\":\"test6\",\"temperature\":0}"),
+											EndOfStream: true,
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
 		// Response flow tests
 		{
 			name: "responsebody sent over multiple requests, content-type is json, buffer",

Diff for: test/testdata/inferencepool-with-model-hermetic.yaml

+11

@@ -50,3 +50,14 @@ spec:
   targetModels:
   - name: my-model-12345
     weight: 100
+---
+apiVersion: inference.networking.x-k8s.io/v1alpha2
+kind: InferenceModel
+metadata:
+  name: inferencemodel-direct-model-name
+  namespace: default
+spec:
+  modelName: direct-model
+  criticality: Critical
+  poolRef:
+    name: vllm-llama2-7b-pool
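The new InferenceModel declares modelName: direct-model with no targetModels, so the requested model name resolves to itself and the request passes through without a rewrite, which is the scenario the new hermetic test exercises (its metric expects model_name and target_model_name both equal to "direct-model"). The following sketch illustrates that resolution rule with a hypothetical helper; it is not the repository's actual datastore API.

package main

import "fmt"

// resolveTargetModel is a hypothetical illustration: when an InferenceModel
// lists no targetModels, the requested model name is used as the target
// unchanged (the passthrough case); otherwise one of the weighted targets
// would be chosen.
func resolveTargetModel(modelName string, targetModels []string) string {
	if len(targetModels) == 0 {
		return modelName
	}
	// Weighted selection over targetModels happens here in the real extension.
	return targetModels[0]
}

func main() {
	// With the InferenceModel above, the resolved target stays "direct-model".
	fmt.Println(resolveTargetModel("direct-model", nil))
}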
