6
6
"fmt"
7
7
"io"
8
8
"strconv"
9
+ "strings"
9
10
"time"
10
11
11
12
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -131,9 +132,14 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
131
132
case * extProcPb.ProcessingRequest_ResponseHeaders :
132
133
loggerVerbose .Info ("got response headers" , "headers" , v .ResponseHeaders .Headers .GetHeaders ())
133
134
for _ , header := range v .ResponseHeaders .Headers .GetHeaders () {
134
- code := header .RawValue [0 ]
135
- if header .Key == "status" && string (code ) != "200" {
135
+ value := string (header .RawValue )
136
+ logger .Error (nil , "header" , "key" , header .Key , "value" , value )
137
+ if header .Key == "status" && value != "200" {
136
138
reqCtx .ResponseStatusCode = errutil .ModelServerError
139
+ } else if header .Key == "content-type" && strings .Contains (value , "text/event-stream" ) {
140
+ reqCtx .modelServerStreaming = true
141
+ loggerVerbose .Info ("model server is streaming response" )
142
+ logger .Error (nil , "made it here" )
137
143
}
138
144
}
139
145
reqCtx .RequestState = ResponseRecieved
@@ -158,36 +164,57 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
158
164
}
159
165
160
166
case * extProcPb.ProcessingRequest_ResponseBody :
161
- go func () {
162
- _ , err := writer .Write (v .ResponseBody .Body )
163
- if err != nil {
164
- logger .V (logutil .DEFAULT ).Error (err , "Error populating writer" )
165
- }
166
- }()
167
-
168
- // Message is buffered, we can read and decode.
169
- if v .ResponseBody .EndOfStream {
170
- err = decoder .Decode (& responseBody )
171
- if err != nil {
172
- logger .V (logutil .DEFAULT ).Error (err , "Error unmarshaling request body" )
167
+ if reqCtx .modelServerStreaming {
168
+ // Currently we punt on response parsing if the modelServer is streaming, and we just passthrough.
169
+ reqCtx .respBodyResp = & extProcPb.ProcessingResponse {
170
+ Response : & extProcPb.ProcessingResponse_ResponseBody {
171
+ ResponseBody : & extProcPb.BodyResponse {
172
+ Response : & extProcPb.CommonResponse {
173
+ BodyMutation : & extProcPb.BodyMutation {
174
+ Mutation : & extProcPb.BodyMutation_StreamedResponse {
175
+ StreamedResponse : & extProcPb.StreamedBodyResponse {
176
+ Body : v .ResponseBody .Body ,
177
+ EndOfStream : v .ResponseBody .EndOfStream ,
178
+ },
179
+ },
180
+ },
181
+ },
182
+ },
183
+ },
173
184
}
174
- // Body stream complete. Close the reader pipe.
175
- reader .Close ()
176
-
177
- reqCtx , err = s .HandleResponseBody (ctx , reqCtx , responseBody )
178
- if err == nil && reqCtx .ResponseComplete {
179
- reqCtx .ResponseCompleteTimestamp = time .Now ()
180
- metrics .RecordRequestLatencies (ctx , reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .RequestReceivedTimestamp , reqCtx .ResponseCompleteTimestamp )
181
- metrics .RecordResponseSizes (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .ResponseSize )
182
- metrics .RecordInputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .PromptTokens )
183
- metrics .RecordOutputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .CompletionTokens )
185
+ } else {
186
+ go func () {
187
+ _ , err := writer .Write (v .ResponseBody .Body )
188
+ if err != nil {
189
+ logger .V (logutil .DEFAULT ).Error (err , "Error populating writer" )
190
+ }
191
+ }()
192
+
193
+ // Message is buffered, we can read and decode.
194
+ if v .ResponseBody .EndOfStream {
195
+ err = decoder .Decode (& responseBody )
196
+ if err != nil {
197
+ logger .V (logutil .DEFAULT ).Error (err , "Error unmarshaling request body" )
198
+ }
199
+ // Body stream complete. Close the reader pipe.
200
+ reader .Close ()
201
+
202
+ reqCtx , err = s .HandleResponseBody (ctx , reqCtx , responseBody )
203
+ if err == nil && reqCtx .ResponseComplete {
204
+ reqCtx .ResponseCompleteTimestamp = time .Now ()
205
+ metrics .RecordRequestLatencies (ctx , reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .RequestReceivedTimestamp , reqCtx .ResponseCompleteTimestamp )
206
+ metrics .RecordResponseSizes (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .ResponseSize )
207
+ metrics .RecordInputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .PromptTokens )
208
+ metrics .RecordOutputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .CompletionTokens )
209
+ }
210
+ loggerVerbose .Info ("Request context after HandleResponseBody" , "context" , reqCtx )
184
211
}
185
- loggerVerbose .Info ("Request context after HandleResponseBody" , "context" , reqCtx )
186
212
}
187
213
case * extProcPb.ProcessingRequest_ResponseTrailers :
188
214
// This is currently unused.
189
215
}
190
216
217
+ // Handle the err and fire an immediate response.
191
218
if err != nil {
192
219
logger .V (logutil .DEFAULT ).Error (err , "Failed to process request" , "request" , req )
193
220
resp , err := BuildErrResponse (err )
@@ -246,7 +273,11 @@ func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.Exter
246
273
if err := srv .Send (r .respBodyResp ); err != nil {
247
274
return status .Errorf (codes .Unknown , "failed to send response back to Envoy: %v" , err )
248
275
}
249
- r .RequestState = BodyResponseResponsesComplete
276
+
277
+ body := r .respBodyResp .Response .(* extProcPb.ProcessingResponse_ResponseBody )
278
+ if body .ResponseBody .Response .GetBodyMutation ().GetStreamedResponse ().GetEndOfStream () {
279
+ r .RequestState = BodyResponseResponsesComplete
280
+ }
250
281
// Dump the response so a new stream message can begin
251
282
r .reqBodyResp = nil
252
283
}
@@ -273,6 +304,8 @@ type StreamingRequestContext struct {
273
304
ResponseComplete bool
274
305
ResponseStatusCode string
275
306
307
+ modelServerStreaming bool
308
+
276
309
reqHeaderResp * extProcPb.ProcessingResponse
277
310
reqBodyResp * extProcPb.ProcessingResponse
278
311
reqTrailerResp * extProcPb.ProcessingResponse
@@ -339,14 +372,15 @@ func (s *StreamingServer) HandleRequestBody(
339
372
// Update target models in the body.
340
373
if llmReq .Model != llmReq .ResolvedTargetModel {
341
374
requestBodyMap ["model" ] = llmReq .ResolvedTargetModel
342
- requestBodyBytes , err = json .Marshal (requestBodyMap )
343
- if err != nil {
344
- logger .V (logutil .DEFAULT ).Error (err , "Error marshaling request body" )
345
- return reqCtx , errutil.Error {Code : errutil .Internal , Msg : fmt .Sprintf ("error marshaling request body: %v" , err )}
346
- }
347
- loggerVerbose .Info ("Updated request body marshalled" , "body" , string (requestBodyBytes ))
348
375
}
349
376
377
+ requestBodyBytes , err = json .Marshal (requestBodyMap )
378
+ if err != nil {
379
+ logger .V (logutil .DEFAULT ).Error (err , "Error marshaling request body" )
380
+ return reqCtx , errutil.Error {Code : errutil .Internal , Msg : fmt .Sprintf ("error marshaling request body: %v" , err )}
381
+ }
382
+ loggerVerbose .Info ("Updated request body marshalled" , "body" , string (requestBodyBytes ))
383
+
350
384
target , err := s .scheduler .Schedule (ctx , llmReq )
351
385
if err != nil {
352
386
return reqCtx , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : fmt .Errorf ("failed to find target pod: %w" , err ).Error ()}
0 commit comments