Skip to content

Commit 32e03ec

Browse files
authored
Add support for OpenAI API streaming protocol (#469)
* Add support for OpenAI API streaming protocol
* Add streaming integration tests
* Reverting go mod changes
* Uncommenting previous tests
* Fix errant typo
* Updating test infra to work for multiple tests
* Always marshal responseBody, add test case to check for this
1 parent 07df631 commit 32e03ec

File tree

7 files changed

+1293
-126
lines changed

7 files changed

+1293
-126
lines changed

.golangci.yml

-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ linters:
2525
- makezero
2626
- errcheck
2727
- goconst
28-
- gocyclo
2928
- gofmt
3029
- goimports
3130
- gosimple

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ test: manifests generate fmt vet envtest ## Run tests.
123123

124124
.PHONY: test-integration
125125
test-integration: manifests generate fmt vet envtest ## Run tests.
126-
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./test/integration -coverprofile cover.out
126+
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./test/integration/epp/... -race -coverprofile cover.out
127127

128128
.PHONY: test-e2e
129129
test-e2e: ## Run end-to-end tests against an existing Kubernetes cluster with at least 3 available GPUs.

pkg/epp/handlers/streamingserver.go

+66-32
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"strconv"
9+
"strings"
910
"time"
1011

1112
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -131,9 +132,14 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
131132
case *extProcPb.ProcessingRequest_ResponseHeaders:
132133
loggerVerbose.Info("got response headers", "headers", v.ResponseHeaders.Headers.GetHeaders())
133134
for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
134-
code := header.RawValue[0]
135-
if header.Key == "status" && string(code) != "200" {
135+
value := string(header.RawValue)
136+
logger.Error(nil, "header", "key", header.Key, "value", value)
137+
if header.Key == "status" && value != "200" {
136138
reqCtx.ResponseStatusCode = errutil.ModelServerError
139+
} else if header.Key == "content-type" && strings.Contains(value, "text/event-stream") {
140+
reqCtx.modelServerStreaming = true
141+
loggerVerbose.Info("model server is streaming response")
142+
logger.Error(nil, "made it here")
137143
}
138144
}
139145
reqCtx.RequestState = ResponseRecieved
@@ -158,36 +164,57 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
158164
}
159165

160166
case *extProcPb.ProcessingRequest_ResponseBody:
161-
go func() {
162-
_, err := writer.Write(v.ResponseBody.Body)
163-
if err != nil {
164-
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
165-
}
166-
}()
167-
168-
// Message is buffered, we can read and decode.
169-
if v.ResponseBody.EndOfStream {
170-
err = decoder.Decode(&responseBody)
171-
if err != nil {
172-
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
167+
if reqCtx.modelServerStreaming {
168+
// Currently we punt on response parsing if the modelServer is streaming, and we just passthrough.
169+
reqCtx.respBodyResp = &extProcPb.ProcessingResponse{
170+
Response: &extProcPb.ProcessingResponse_ResponseBody{
171+
ResponseBody: &extProcPb.BodyResponse{
172+
Response: &extProcPb.CommonResponse{
173+
BodyMutation: &extProcPb.BodyMutation{
174+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
175+
StreamedResponse: &extProcPb.StreamedBodyResponse{
176+
Body: v.ResponseBody.Body,
177+
EndOfStream: v.ResponseBody.EndOfStream,
178+
},
179+
},
180+
},
181+
},
182+
},
183+
},
173184
}
174-
// Body stream complete. Close the reader pipe.
175-
reader.Close()
176-
177-
reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
178-
if err == nil && reqCtx.ResponseComplete {
179-
reqCtx.ResponseCompleteTimestamp = time.Now()
180-
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
181-
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
182-
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
183-
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
185+
} else {
186+
go func() {
187+
_, err := writer.Write(v.ResponseBody.Body)
188+
if err != nil {
189+
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
190+
}
191+
}()
192+
193+
// Message is buffered, we can read and decode.
194+
if v.ResponseBody.EndOfStream {
195+
err = decoder.Decode(&responseBody)
196+
if err != nil {
197+
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
198+
}
199+
// Body stream complete. Close the reader pipe.
200+
reader.Close()
201+
202+
reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
203+
if err == nil && reqCtx.ResponseComplete {
204+
reqCtx.ResponseCompleteTimestamp = time.Now()
205+
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
206+
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
207+
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
208+
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
209+
}
210+
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
184211
}
185-
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
186212
}
187213
case *extProcPb.ProcessingRequest_ResponseTrailers:
188214
// This is currently unused.
189215
}
190216

217+
// Handle the err and fire an immediate response.
191218
if err != nil {
192219
logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
193220
resp, err := BuildErrResponse(err)
@@ -246,7 +273,11 @@ func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.Exter
246273
if err := srv.Send(r.respBodyResp); err != nil {
247274
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
248275
}
249-
r.RequestState = BodyResponseResponsesComplete
276+
277+
body := r.respBodyResp.Response.(*extProcPb.ProcessingResponse_ResponseBody)
278+
if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() {
279+
r.RequestState = BodyResponseResponsesComplete
280+
}
250281
// Dump the response so a new stream message can begin
251282
r.reqBodyResp = nil
252283
}
@@ -273,6 +304,8 @@ type StreamingRequestContext struct {
273304
ResponseComplete bool
274305
ResponseStatusCode string
275306

307+
modelServerStreaming bool
308+
276309
reqHeaderResp *extProcPb.ProcessingResponse
277310
reqBodyResp *extProcPb.ProcessingResponse
278311
reqTrailerResp *extProcPb.ProcessingResponse
@@ -339,14 +372,15 @@ func (s *StreamingServer) HandleRequestBody(
339372
// Update target models in the body.
340373
if llmReq.Model != llmReq.ResolvedTargetModel {
341374
requestBodyMap["model"] = llmReq.ResolvedTargetModel
342-
requestBodyBytes, err = json.Marshal(requestBodyMap)
343-
if err != nil {
344-
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
345-
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
346-
}
347-
loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))
348375
}
349376

377+
requestBodyBytes, err = json.Marshal(requestBodyMap)
378+
if err != nil {
379+
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
380+
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
381+
}
382+
loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))
383+
350384
target, err := s.scheduler.Schedule(ctx, llmReq)
351385
if err != nil {
352386
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}

pkg/epp/server/controller_manager.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
ctrl "sigs.k8s.io/controller-runtime"
2929
"sigs.k8s.io/controller-runtime/pkg/cache"
3030
"sigs.k8s.io/controller-runtime/pkg/client"
31+
"sigs.k8s.io/controller-runtime/pkg/manager"
3132
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
3233
)
3334

@@ -40,7 +41,7 @@ func init() {
4041

4142
// NewDefaultManager creates a new controller manager with default configuration.
4243
func NewDefaultManager(namespace, name string, restConfig *rest.Config) (ctrl.Manager, error) {
43-
manager, err := ctrl.NewManager(restConfig, ctrl.Options{
44+
defaultOpts := ctrl.Options{
4445
Scheme: scheme,
4546
Cache: cache.Options{
4647
ByObject: map[client.Object]cache.ByObject{
@@ -65,7 +66,13 @@ func NewDefaultManager(namespace, name string, restConfig *rest.Config) (ctrl.Ma
6566
},
6667
},
6768
},
68-
})
69+
}
70+
return NewManagerWithOptions(restConfig, defaultOpts)
71+
}
72+
73+
// NewManagerWithOptions creates a new controller manager with injectable options.
74+
func NewManagerWithOptions(restConfig *rest.Config, opts manager.Options) (ctrl.Manager, error) {
75+
manager, err := ctrl.NewManager(restConfig, opts)
6976
if err != nil {
7077
return nil, fmt.Errorf("failed to create controller manager: %v", err)
7178
}

pkg/epp/util/testing/request.go

+23-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package testing
1919
import (
2020
"encoding/json"
2121

22+
envoyCorev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2223
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2324
"github.com/go-logr/logr"
2425
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -38,8 +39,29 @@ func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.Proces
3839
}
3940
req := &extProcPb.ProcessingRequest{
4041
Request: &extProcPb.ProcessingRequest_RequestBody{
41-
RequestBody: &extProcPb.HttpBody{Body: llmReq},
42+
RequestBody: &extProcPb.HttpBody{Body: llmReq, EndOfStream: true},
4243
},
4344
}
4445
return req
4546
}
47+
48+
func GenerateStreamedRequestSet(logger logr.Logger, prompt, model string) []*extProcPb.ProcessingRequest {
49+
requests := []*extProcPb.ProcessingRequest{}
50+
headerReq := &extProcPb.ProcessingRequest{
51+
Request: &extProcPb.ProcessingRequest_RequestHeaders{
52+
RequestHeaders: &extProcPb.HttpHeaders{
53+
Headers: &envoyCorev3.HeaderMap{
54+
Headers: []*envoyCorev3.HeaderValue{
55+
{
56+
Key: "hi",
57+
Value: "mom",
58+
},
59+
},
60+
},
61+
},
62+
},
63+
}
64+
requests = append(requests, headerReq)
65+
requests = append(requests, GenerateRequest(logger, prompt, model))
66+
return requests
67+
}

0 commit comments

Comments
 (0)