Skip to content

Commit 359af69

Browse files
committed
Add support for OpenAI API streaming protocol
1 parent a70d66e commit 359af69

File tree

1 file changed

+58
-26
lines changed

1 file changed

+58
-26
lines changed

pkg/epp/handlers/streamingserver.go

+58-26
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"strconv"
9+
"strings"
910
"time"
1011

1112
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -131,9 +132,13 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
131132
case *extProcPb.ProcessingRequest_ResponseHeaders:
132133
loggerVerbose.Info("got response headers", "headers", v.ResponseHeaders.Headers.GetHeaders())
133134
for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
134-
code := header.RawValue[0]
135-
if header.Key == "status" && string(code) != "200" {
135+
value := string(header.RawValue)
136+
137+
if header.Key == "status" && value != "200" {
136138
reqCtx.ResponseStatusCode = errutil.ModelServerError
139+
} else if header.Key == "content-type" && strings.Contains(value, "text/event-stream") {
140+
reqCtx.modelServerStreaming = true
141+
loggerVerbose.Info("model server is streaming response")
137142
}
138143
}
139144
reqCtx.RequestState = ResponseRecieved
@@ -158,36 +163,57 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
158163
}
159164

160165
case *extProcPb.ProcessingRequest_ResponseBody:
161-
go func() {
162-
_, err := writer.Write(v.ResponseBody.Body)
163-
if err != nil {
164-
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
165-
}
166-
}()
167-
168-
// Message is buffered, we can read and decode.
169-
if v.ResponseBody.EndOfStream {
170-
err = decoder.Decode(&responseBody)
171-
if err != nil {
172-
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
166+
if reqCtx.modelServerStreaming {
167+
// Currently we punt on response parsing if the modelServer is streaming, and we just pass it through.
168+
reqCtx.respBodyResp = &extProcPb.ProcessingResponse{
169+
Response: &extProcPb.ProcessingResponse_ResponseBody{
170+
ResponseBody: &extProcPb.BodyResponse{
171+
Response: &extProcPb.CommonResponse{
172+
BodyMutation: &extProcPb.BodyMutation{
173+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
174+
StreamedResponse: &extProcPb.StreamedBodyResponse{
175+
Body: v.ResponseBody.Body,
176+
EndOfStream: v.ResponseBody.EndOfStream,
177+
},
178+
},
179+
},
180+
},
181+
},
182+
},
173183
}
174-
// Body stream complete. Close the reader pipe.
175-
reader.Close()
176-
177-
reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
178-
if err == nil && reqCtx.ResponseComplete {
179-
reqCtx.ResponseCompleteTimestamp = time.Now()
180-
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
181-
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
182-
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
183-
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
184+
} else {
185+
go func() {
186+
_, err := writer.Write(v.ResponseBody.Body)
187+
if err != nil {
188+
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
189+
}
190+
}()
191+
192+
// Message is buffered, we can read and decode.
193+
if v.ResponseBody.EndOfStream {
194+
err = decoder.Decode(&responseBody)
195+
if err != nil {
196+
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
197+
}
198+
// Body stream complete. Close the reader pipe.
199+
reader.Close()
200+
201+
reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
202+
if err == nil && reqCtx.ResponseComplete {
203+
reqCtx.ResponseCompleteTimestamp = time.Now()
204+
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
205+
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
206+
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
207+
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
208+
}
209+
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
184210
}
185-
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
186211
}
187212
case *extProcPb.ProcessingRequest_ResponseTrailers:
188213
// This is currently unused.
189214
}
190215

216+
// Handle the err and fire an immediate response.
191217
if err != nil {
192218
logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
193219
resp, err := BuildErrResponse(err)
@@ -246,7 +272,11 @@ func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.Exter
246272
if err := srv.Send(r.respBodyResp); err != nil {
247273
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
248274
}
249-
r.RequestState = BodyResponseResponsesComplete
275+
276+
body := r.respBodyResp.Response.(*extProcPb.ProcessingResponse_ResponseBody)
277+
if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() {
278+
r.RequestState = BodyResponseResponsesComplete
279+
}
250280
// Dump the response so a new stream message can begin
251281
r.reqBodyResp = nil
252282
}
@@ -273,6 +303,8 @@ type StreamingRequestContext struct {
273303
ResponseComplete bool
274304
ResponseStatusCode string
275305

306+
modelServerStreaming bool
307+
276308
reqHeaderResp *extProcPb.ProcessingResponse
277309
reqBodyResp *extProcPb.ProcessingResponse
278310
reqTrailerResp *extProcPb.ProcessingResponse

0 commit comments

Comments
 (0)