@@ -184,6 +184,20 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
184
184
case * extProcPb.ProcessingRequest_ResponseTrailers :
185
185
// This is currently unused.
186
186
}
187
+
188
+ if err != nil {
189
+ logger .V (logutil .DEFAULT ).Error (err , "Failed to process request" , "request" , req )
190
+ resp , err := BuildErrResponse (err )
191
+ if err != nil {
192
+ return err
193
+ } else {
194
+ if err := srv .Send (resp ); err != nil {
195
+ logger .V (logutil .DEFAULT ).Error (err , "Send failed" )
196
+ return status .Errorf (codes .Unknown , "failed to send response back to Envoy: %v" , err )
197
+ }
198
+ return nil
199
+ }
200
+ }
187
201
loggerVerbose .Info ("checking" , "request state" , reqCtx .RequestState )
188
202
if err := reqCtx .updateStateAndSendIfNeeded (srv , loggerVerbose ); err != nil {
189
203
return err
@@ -280,6 +294,7 @@ const (
280
294
TrailerResponseResponsesComplete StreamRequestState = 7
281
295
)
282
296
297
+ // HandleRequestBody always returns reqCtx, even when it also returns an error, because the request context is used in error handling.
283
298
func (s * StreamingServer ) HandleRequestBody (
284
299
ctx context.Context ,
285
300
reqCtx * StreamingRequestContext ,
@@ -294,7 +309,7 @@ func (s *StreamingServer) HandleRequestBody(
294
309
// Resolve target models.
295
310
model , ok := requestBodyMap ["model" ].(string )
296
311
if ! ok {
297
- return nil , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request" }
312
+ return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request" }
298
313
}
299
314
loggerVerbose .Info ("Model requested" , "model" , model )
300
315
modelName := model
@@ -304,12 +319,12 @@ func (s *StreamingServer) HandleRequestBody(
304
319
// are able to be requested by using their distinct name.
305
320
modelObj := s .datastore .ModelGet (model )
306
321
if modelObj == nil {
307
- return nil , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error finding a model object in InferenceModel for input %v" , model )}
322
+ return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error finding a model object in InferenceModel for input %v" , model )}
308
323
}
309
324
if len (modelObj .Spec .TargetModels ) > 0 {
310
325
modelName = datastore .RandomWeightedDraw (logger , modelObj , 0 )
311
326
if modelName == "" {
312
- return nil , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error getting target model name for model %v" , modelObj .Name )}
327
+ return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error getting target model name for model %v" , modelObj .Name )}
313
328
}
314
329
}
315
330
llmReq := & scheduling.LLMRequest {
@@ -326,21 +341,21 @@ func (s *StreamingServer) HandleRequestBody(
326
341
requestBodyBytes , err = json .Marshal (requestBodyMap )
327
342
if err != nil {
328
343
logger .V (logutil .DEFAULT ).Error (err , "Error marshaling request body" )
329
- return nil , errutil.Error {Code : errutil .Internal , Msg : fmt .Sprintf ("error marshaling request body: %v" , err )}
344
+ return reqCtx , errutil.Error {Code : errutil .Internal , Msg : fmt .Sprintf ("error marshaling request body: %v" , err )}
330
345
}
331
346
loggerVerbose .Info ("Updated request body marshalled" , "body" , string (requestBodyBytes ))
332
347
}
333
348
334
349
targetPod , err := s .scheduler .Schedule (ctx , llmReq )
335
350
if err != nil {
336
- return nil , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : fmt .Errorf ("failed to find target pod: %w" , err ).Error ()}
351
+ return reqCtx , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : fmt .Errorf ("failed to find target pod: %w" , err ).Error ()}
337
352
}
338
353
339
354
// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
340
355
// Attach the port number
341
356
pool , err := s .datastore .PoolGet ()
342
357
if err != nil {
343
- return nil , err
358
+ return reqCtx , err
344
359
}
345
360
endpoint := targetPod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
346
361
@@ -432,6 +447,7 @@ func (s *StreamingServer) HandleRequestBody(
432
447
return reqCtx , nil
433
448
}
434
449
450
+ // HandleResponseBody always returns reqCtx, even when it also returns an error, because the request context is used in error handling.
435
451
func (s * StreamingServer ) HandleResponseBody (
436
452
ctx context.Context ,
437
453
reqCtx * StreamingRequestContext ,
@@ -443,7 +459,7 @@ func (s *StreamingServer) HandleResponseBody(
443
459
responseBytes , err := json .Marshal (response )
444
460
if err != nil {
445
461
logger .V (logutil .DEFAULT ).Error (err , "error marshalling responseBody" )
446
- return nil , err
462
+ return reqCtx , err
447
463
}
448
464
if response ["usage" ] != nil {
449
465
usg := response ["usage" ].(map [string ]interface {})
0 commit comments