6
6
"fmt"
7
7
"io"
8
8
"strconv"
9
+ "strings"
9
10
"time"
10
11
11
12
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -131,9 +132,13 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
131
132
case * extProcPb.ProcessingRequest_ResponseHeaders :
132
133
loggerVerbose .Info ("got response headers" , "headers" , v .ResponseHeaders .Headers .GetHeaders ())
133
134
for _ , header := range v .ResponseHeaders .Headers .GetHeaders () {
134
- code := header .RawValue [0 ]
135
- if header .Key == "status" && string (code ) != "200" {
135
+ value := string (header .RawValue )
136
+
137
+ if header .Key == "status" && value != "200" {
136
138
reqCtx .ResponseStatusCode = errutil .ModelServerError
139
+ } else if header .Key == "content-type" && strings .Contains (value , "text/event-stream" ) {
140
+ reqCtx .modelServerStreaming = true
141
+ loggerVerbose .Info ("model server is streaming response" )
137
142
}
138
143
}
139
144
reqCtx .RequestState = ResponseRecieved
@@ -158,36 +163,57 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
158
163
}
159
164
160
165
case * extProcPb.ProcessingRequest_ResponseBody :
161
- go func () {
162
- _ , err := writer .Write (v .ResponseBody .Body )
163
- if err != nil {
164
- logger .V (logutil .DEFAULT ).Error (err , "Error populating writer" )
165
- }
166
- }()
167
-
168
- // Message is buffered, we can read and decode.
169
- if v .ResponseBody .EndOfStream {
170
- err = decoder .Decode (& responseBody )
171
- if err != nil {
172
- logger .V (logutil .DEFAULT ).Error (err , "Error unmarshaling request body" )
166
+ if reqCtx .modelServerStreaming {
167
+ // Response parsing is currently skipped when the model server is streaming; each body chunk is passed through unchanged.
168
+ reqCtx .respBodyResp = & extProcPb.ProcessingResponse {
169
+ Response : & extProcPb.ProcessingResponse_ResponseBody {
170
+ ResponseBody : & extProcPb.BodyResponse {
171
+ Response : & extProcPb.CommonResponse {
172
+ BodyMutation : & extProcPb.BodyMutation {
173
+ Mutation : & extProcPb.BodyMutation_StreamedResponse {
174
+ StreamedResponse : & extProcPb.StreamedBodyResponse {
175
+ Body : v .ResponseBody .Body ,
176
+ EndOfStream : v .ResponseBody .EndOfStream ,
177
+ },
178
+ },
179
+ },
180
+ },
181
+ },
182
+ },
173
183
}
174
- // Body stream complete. Close the reader pipe.
175
- reader .Close ()
176
-
177
- reqCtx , err = s .HandleResponseBody (ctx , reqCtx , responseBody )
178
- if err == nil && reqCtx .ResponseComplete {
179
- reqCtx .ResponseCompleteTimestamp = time .Now ()
180
- metrics .RecordRequestLatencies (ctx , reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .RequestReceivedTimestamp , reqCtx .ResponseCompleteTimestamp )
181
- metrics .RecordResponseSizes (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .ResponseSize )
182
- metrics .RecordInputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .PromptTokens )
183
- metrics .RecordOutputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .CompletionTokens )
184
+ } else {
185
+ go func () {
186
+ _ , err := writer .Write (v .ResponseBody .Body )
187
+ if err != nil {
188
+ logger .V (logutil .DEFAULT ).Error (err , "Error populating writer" )
189
+ }
190
+ }()
191
+
192
+ // Message is buffered, we can read and decode.
193
+ if v .ResponseBody .EndOfStream {
194
+ err = decoder .Decode (& responseBody )
195
+ if err != nil {
196
+ logger .V (logutil .DEFAULT ).Error (err , "Error unmarshaling request body" )
197
+ }
198
+ // Body stream complete. Close the reader pipe.
199
+ reader .Close ()
200
+
201
+ reqCtx , err = s .HandleResponseBody (ctx , reqCtx , responseBody )
202
+ if err == nil && reqCtx .ResponseComplete {
203
+ reqCtx .ResponseCompleteTimestamp = time .Now ()
204
+ metrics .RecordRequestLatencies (ctx , reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .RequestReceivedTimestamp , reqCtx .ResponseCompleteTimestamp )
205
+ metrics .RecordResponseSizes (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .ResponseSize )
206
+ metrics .RecordInputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .PromptTokens )
207
+ metrics .RecordOutputTokens (reqCtx .Model , reqCtx .ResolvedTargetModel , reqCtx .Usage .CompletionTokens )
208
+ }
209
+ loggerVerbose .Info ("Request context after HandleResponseBody" , "context" , reqCtx )
184
210
}
185
- loggerVerbose .Info ("Request context after HandleResponseBody" , "context" , reqCtx )
186
211
}
187
212
case * extProcPb.ProcessingRequest_ResponseTrailers :
188
213
// This is currently unused.
189
214
}
190
215
216
+ // Handle the err and fire an immediate response.
191
217
if err != nil {
192
218
logger .V (logutil .DEFAULT ).Error (err , "Failed to process request" , "request" , req )
193
219
resp , err := BuildErrResponse (err )
@@ -246,7 +272,11 @@ func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.Exter
246
272
if err := srv .Send (r .respBodyResp ); err != nil {
247
273
return status .Errorf (codes .Unknown , "failed to send response back to Envoy: %v" , err )
248
274
}
249
- r .RequestState = BodyResponseResponsesComplete
275
+
276
+ body := r .respBodyResp .Response .(* extProcPb.ProcessingResponse_ResponseBody )
277
+ if body .ResponseBody .Response .GetBodyMutation ().GetStreamedResponse ().GetEndOfStream () {
278
+ r .RequestState = BodyResponseResponsesComplete
279
+ }
250
280
// Dump the response so a new stream message can begin
251
281
r .reqBodyResp = nil
252
282
}
@@ -273,6 +303,8 @@ type StreamingRequestContext struct {
273
303
ResponseComplete bool
274
304
ResponseStatusCode string
275
305
306
+ modelServerStreaming bool
307
+
276
308
reqHeaderResp * extProcPb.ProcessingResponse
277
309
reqBodyResp * extProcPb.ProcessingResponse
278
310
reqTrailerResp * extProcPb.ProcessingResponse
0 commit comments