diff --git a/cmd/body-based-routing/main.go b/cmd/body-based-routing/main.go index 13f841b6..cfc584ce 100644 --- a/cmd/body-based-routing/main.go +++ b/cmd/body-based-routing/main.go @@ -44,7 +44,7 @@ import ( var ( grpcPort = flag.Int( "grpcPort", - runserver.DefaultGrpcPort, + 9004, "The gRPC port used for communicating with Envoy proxy") grpcHealthPort = flag.Int( "grpcHealthPort", @@ -52,6 +52,8 @@ var ( "The port used for gRPC liveness and readiness probes") metricsPort = flag.Int( "metricsPort", 9090, "The metrics port") + streaming = flag.Bool( + "streaming", false, "Enables streaming support for Envoy full-duplex streaming mode") logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity") setupLog = ctrl.Log.WithName("setup") @@ -92,7 +94,7 @@ func run() error { ctx := ctrl.SetupSignalHandler() // Setup runner. - serverRunner := &runserver.ExtProcServerRunner{GrpcPort: *grpcPort} + serverRunner := runserver.NewDefaultExtProcServerRunner(*grpcPort, *streaming) // Register health server. if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), *grpcHealthPort); err != nil { diff --git a/pkg/body-based-routing/handlers/request.go b/pkg/body-based-routing/handlers/request.go index 6596e191..c0be46ac 100644 --- a/pkg/body-based-routing/handlers/request.go +++ b/pkg/body-based-routing/handlers/request.go @@ -23,17 +23,21 @@ import ( basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +const modelHeader = "X-Gateway-Model-Name" + // HandleRequestBody handles request bodies. -func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*eppb.ProcessingResponse, error) { +func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([]*eppb.ProcessingResponse, error) { logger := log.FromContext(ctx) + var ret []*eppb.ProcessingResponse - var data map[string]any - if err := json.Unmarshal(body.GetBody(), &data); err != nil { + requestBodyBytes, err := json.Marshal(data) + if err != nil { return nil, err } @@ -41,37 +45,71 @@ func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*e if !ok { metrics.RecordModelNotInBodyCounter() logger.V(logutil.DEFAULT).Info("Request body does not contain model parameter") - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_RequestBody{ - RequestBody: &eppb.BodyResponse{}, - }, - }, nil + if s.streaming { + ret = append(ret, &eppb.ProcessingResponse{ + Response: &eppb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &eppb.HeadersResponse{}, + }, + }) + ret = addStreamedBodyResponse(ret, requestBodyBytes) + return ret, nil + } else { + ret = append(ret, &eppb.ProcessingResponse{ + Response: &eppb.ProcessingResponse_RequestBody{ + RequestBody: &eppb.BodyResponse{}, + }, + }) + } + return ret, nil } modelStr, ok := modelVal.(string) if !ok { metrics.RecordModelNotParsedCounter() logger.V(logutil.DEFAULT).Info("Model parameter value is not a string") - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_RequestBody{ - RequestBody: &eppb.BodyResponse{}, - }, - }, fmt.Errorf("the model parameter value %v is not a string", modelVal) + return nil, fmt.Errorf("the model parameter value %v is not a string", modelVal) } metrics.RecordSuccessCounter() - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_RequestBody{ - RequestBody: &eppb.BodyResponse{ - Response: &eppb.CommonResponse{ - // Necessary so that the new headers are used in the routing decision. - ClearRouteCache: true, - HeaderMutation: &eppb.HeaderMutation{ - SetHeaders: []*basepb.HeaderValueOption{ - { - Header: &basepb.HeaderValue{ - Key: "X-Gateway-Model-Name", - RawValue: []byte(modelStr), + + if s.streaming { + ret = append(ret, &eppb.ProcessingResponse{ + Response: &eppb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &eppb.HeadersResponse{ + Response: &eppb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &eppb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: modelHeader, + RawValue: []byte(modelStr), + }, + }, + }, + }, + }, + }, + }, + }) + ret = addStreamedBodyResponse(ret, requestBodyBytes) + return ret, nil + } + + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_RequestBody{ + RequestBody: &eppb.BodyResponse{ + Response: &eppb.CommonResponse{ + // Necessary so that the new headers are used in the routing decision. + ClearRouteCache: true, + HeaderMutation: &eppb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: modelHeader, + RawValue: []byte(modelStr), + }, }, }, }, @@ -82,20 +120,43 @@ func (s *Server) HandleRequestBody(ctx context.Context, body *eppb.HttpBody) (*e }, nil } +func addStreamedBodyResponse(responses []*eppb.ProcessingResponse, requestBodyBytes []byte) []*eppb.ProcessingResponse { + return append(responses, &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: requestBodyBytes, + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }) +} + // HandleRequestHeaders handles request headers. -func (s *Server) HandleRequestHeaders(headers *eppb.HttpHeaders) (*eppb.ProcessingResponse, error) { - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_RequestHeaders{ - RequestHeaders: &eppb.HeadersResponse{}, +func (s *Server) HandleRequestHeaders(headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) { + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &eppb.HeadersResponse{}, + }, }, }, nil } // HandleRequestTrailers handles request trailers. -func (s *Server) HandleRequestTrailers(trailers *eppb.HttpTrailers) (*eppb.ProcessingResponse, error) { - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_RequestTrailers{ - RequestTrailers: &eppb.TrailersResponse{}, +func (s *Server) HandleRequestTrailers(trailers *eppb.HttpTrailers) ([]*eppb.ProcessingResponse, error) { + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_RequestTrailers{ + RequestTrailers: &eppb.TrailersResponse{}, + }, }, }, nil } diff --git a/pkg/body-based-routing/handlers/request_test.go b/pkg/body-based-routing/handlers/request_test.go index 76f64e0c..0f088702 100644 --- a/pkg/body-based-routing/handlers/request_test.go +++ b/pkg/body-based-routing/handlers/request_test.go @@ -18,6 +18,7 @@ package handlers import ( "context" + "encoding/json" "strings" "testing" @@ -31,78 +32,138 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -const ( - bodyWithModel = ` - { - "model": "foo", - "prompt": "Tell me a joke" - } - ` - bodyWithModelNoStr = ` - { - "model": 1, - "prompt": "Tell me a joke" - } - ` - bodyWithoutModel = ` - { - "prompt": "Tell me a joke" - } - ` -) - func TestHandleRequestBody(t *testing.T) { metrics.Register() ctx := logutil.NewTestLoggerIntoContext(context.Background()) tests := []struct { - name string - body *extProcPb.HttpBody - want *extProcPb.ProcessingResponse - wantErr bool + name string + body map[string]any + streaming bool + want []*extProcPb.ProcessingResponse + wantErr bool }{ { - name: "malformed body", - body: &extProcPb.HttpBody{ - Body: []byte("malformed json"), + name: "model not found", + body: map[string]any{ + "prompt": "Tell me a joke", + }, + want: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{}, + }, + }, }, - wantErr: true, }, { - name: "model not found", - body: &extProcPb.HttpBody{ - Body: []byte(bodyWithoutModel), + name: "model not found with streaming", + body: map[string]any{ + "prompt": "Tell me a joke", }, - want: &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{}, + streaming: true, + want: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{}, + }, + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: mapToBytes(t, map[string]any{ + "prompt": "Tell me a joke", + }), + EndOfStream: true, + }, + }, + }, + }, + }, + }, }, }, }, { name: "model is not string", - body: &extProcPb.HttpBody{ - Body: []byte(bodyWithModelNoStr), + body: map[string]any{ + "model": 1, + "prompt": "Tell me a joke", }, wantErr: true, }, { name: "success", - body: &extProcPb.HttpBody{ - Body: []byte(bodyWithModel), + body: map[string]any{ + "model": "foo", + "prompt": "Tell me a joke", }, - want: &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - // Necessary so that the new headers are used in the routing decision. - ClearRouteCache: true, - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*basepb.HeaderValueOption{ - { - Header: &basepb.HeaderValue{ - Key: "X-Gateway-Model-Name", - RawValue: []byte("foo"), + want: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + // Necessary so that the new headers are used in the routing decision. + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: "X-Gateway-Model-Name", + RawValue: []byte("foo"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "success-with-streaming", + body: map[string]any{ + "model": "foo", + "prompt": "Tell me a joke", + }, + streaming: true, + want: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*basepb.HeaderValueOption{ + { + Header: &basepb.HeaderValue{ + Key: "X-Gateway-Model-Name", + RawValue: []byte("foo"), + }, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: mapToBytes(t, map[string]any{ + "model": "foo", + "prompt": "Tell me a joke", + }), + EndOfStream: true, }, }, }, @@ -116,7 +177,7 @@ func TestHandleRequestBody(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - server := &Server{} + server := &Server{streaming: test.streaming} resp, err := server.HandleRequestBody(ctx, test.body) if err != nil { if !test.wantErr { @@ -147,3 +208,12 @@ func TestHandleRequestBody(t *testing.T) { t.Error(err) } } + +func mapToBytes(t *testing.T, m map[string]any) []byte { + // Convert map to JSON byte array + bytes, err := json.Marshal(m) + if err != nil { + t.Fatalf("Marshal(): %v", err) + } + return bytes +} diff --git a/pkg/body-based-routing/handlers/response.go b/pkg/body-based-routing/handlers/response.go index a62aa076..fbcb75d6 100644 --- a/pkg/body-based-routing/handlers/response.go +++ b/pkg/body-based-routing/handlers/response.go @@ -21,28 +21,34 @@ import ( ) // HandleResponseHeaders handles response headers. -func (s *Server) HandleResponseHeaders(headers *eppb.HttpHeaders) (*eppb.ProcessingResponse, error) { - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_ResponseHeaders{ - ResponseHeaders: &eppb.HeadersResponse{}, +func (s *Server) HandleResponseHeaders(headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) { + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &eppb.HeadersResponse{}, + }, }, }, nil } // HandleResponseBody handles response bodies. -func (s *Server) HandleResponseBody(body *eppb.HttpBody) (*eppb.ProcessingResponse, error) { - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_ResponseBody{ - ResponseBody: &eppb.BodyResponse{}, +func (s *Server) HandleResponseBody(body *eppb.HttpBody) ([]*eppb.ProcessingResponse, error) { + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_ResponseBody{ + ResponseBody: &eppb.BodyResponse{}, + }, }, }, nil } // HandleResponseTrailers handles response trailers. -func (s *Server) HandleResponseTrailers(trailers *eppb.HttpTrailers) (*eppb.ProcessingResponse, error) { - return &eppb.ProcessingResponse{ - Response: &eppb.ProcessingResponse_ResponseTrailers{ - ResponseTrailers: &eppb.TrailersResponse{}, +func (s *Server) HandleResponseTrailers(trailers *eppb.HttpTrailers) ([]*eppb.ProcessingResponse, error) { + return []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_ResponseTrailers{ + ResponseTrailers: &eppb.TrailersResponse{}, + }, }, }, nil } diff --git a/pkg/body-based-routing/handlers/server.go b/pkg/body-based-routing/handlers/server.go index 813c55c8..36eb3c2f 100644 --- a/pkg/body-based-routing/handlers/server.go +++ b/pkg/body-based-routing/handlers/server.go @@ -18,23 +18,27 @@ package handlers import ( "context" + "encoding/json" "errors" "io" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/go-logr/logr" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -func NewServer() *Server { - return &Server{} +func NewServer(streaming bool) *Server { + return &Server{streaming: streaming} } // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto -type Server struct{} +type Server struct { + streaming bool +} func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { ctx := srv.Context() @@ -42,6 +46,8 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { loggerVerbose := logger.V(logutil.VERBOSE) loggerVerbose.Info("Processing") + reader, writer := io.Pipe() + for { select { case <-ctx.Done(): @@ -60,19 +66,25 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { return status.Errorf(codes.Unknown, "cannot receive stream request: %v", recvErr) } - var resp *extProcPb.ProcessingResponse + var responses []*extProcPb.ProcessingResponse var err error switch v := req.Request.(type) { case *extProcPb.ProcessingRequest_RequestHeaders: - resp, err = s.HandleRequestHeaders(req.GetRequestHeaders()) + if s.streaming && !req.GetRequestHeaders().GetEndOfStream() { + // If streaming and the body is not empty, then headers are handled when processing request body. + loggerVerbose.Info("Received headers, passing off header processing until body arrives...") + } else { + responses, err = s.HandleRequestHeaders(req.GetRequestHeaders()) + } case *extProcPb.ProcessingRequest_RequestBody: - resp, err = s.HandleRequestBody(ctx, req.GetRequestBody()) + loggerVerbose.Info("Incoming body chunk", "body", string(v.RequestBody.Body), "EoS", v.RequestBody.EndOfStream) + responses, err = s.processRequestBody(ctx, req.GetRequestBody(), writer, reader, logger) case *extProcPb.ProcessingRequest_RequestTrailers: - resp, err = s.HandleRequestTrailers(req.GetRequestTrailers()) + responses, err = s.HandleRequestTrailers(req.GetRequestTrailers()) case *extProcPb.ProcessingRequest_ResponseHeaders: - resp, err = s.HandleResponseHeaders(req.GetResponseHeaders()) + responses, err = s.HandleResponseHeaders(req.GetResponseHeaders()) case *extProcPb.ProcessingRequest_ResponseBody: - resp, err = s.HandleResponseBody(req.GetResponseBody()) + responses, err = s.HandleResponseBody(req.GetResponseBody()) default: logger.V(logutil.DEFAULT).Error(nil, "Unknown Request type", "request", v) return status.Error(codes.Unknown, "unknown request type") @@ -83,10 +95,56 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { return status.Errorf(status.Code(err), "failed to handle request: %v", err) } - loggerVerbose.Info("Response generated", "response", resp) - if err := srv.Send(resp); err != nil { - logger.V(logutil.DEFAULT).Error(err, "Send failed") - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) + for _, resp := range responses { + loggerVerbose.Info("Response generated", "response", resp) + if err := srv.Send(resp); err != nil { + logger.V(logutil.DEFAULT).Error(err, "Send failed") + return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) + } } } } + +func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBody, bufferWriter *io.PipeWriter, bufferReader *io.PipeReader, logger logr.Logger) ([]*extProcPb.ProcessingResponse, error) { + loggerVerbose := logger.V(logutil.VERBOSE) + + var requestBody map[string]interface{} + if s.streaming { + // In the stream case, we can receive multiple request bodies. + // To buffer the full message, we create a goroutine with a writer.Write() + // call, which will block until the corresponding reader reads from it. + // We do not read until we receive the EndofStream signal, and then + // decode the entire JSON body. + if !body.EndOfStream { + go func() { + loggerVerbose.Info("Writing to stream buffer") + _, err := bufferWriter.Write(body.Body) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error populating writer") + } + }() + + return nil, nil + } + + if body.EndOfStream { + loggerVerbose.Info("Flushing stream buffer") + decoder := json.NewDecoder(bufferReader) + if err := decoder.Decode(&requestBody); err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body") + } + bufferReader.Close() + } + } else { + if err := json.Unmarshal(body.GetBody(), &requestBody); err != nil { + return nil, err + } + } + + requestBodyResp, err := s.HandleRequestBody(ctx, requestBody) + if err != nil { + return nil, err + } + + return requestBodyResp, nil +} diff --git a/pkg/body-based-routing/server/runserver.go b/pkg/body-based-routing/server/runserver.go index 90a64b70..1646aa5a 100644 --- a/pkg/body-based-routing/server/runserver.go +++ b/pkg/body-based-routing/server/runserver.go @@ -34,17 +34,14 @@ import ( type ExtProcServerRunner struct { GrpcPort int SecureServing bool + Streaming bool } -// Default values for CLI flags in main -const ( - DefaultGrpcPort = 9004 // default for --grpcPort -) - -func NewDefaultExtProcServerRunner() *ExtProcServerRunner { +func NewDefaultExtProcServerRunner(port int, streaming bool) *ExtProcServerRunner { return &ExtProcServerRunner{ - GrpcPort: DefaultGrpcPort, + GrpcPort: port, SecureServing: true, + Streaming: streaming, } } @@ -65,7 +62,10 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { srv = grpc.NewServer() } - extProcPb.RegisterExternalProcessorServer(srv, handlers.NewServer()) + extProcPb.RegisterExternalProcessorServer( + srv, + handlers.NewServer(r.Streaming), + ) // Forward to the gRPC runnable. return runnable.GRPCServer("ext-proc", srv, r.GrpcPort).Start(ctx) diff --git a/test/integration/bbr/hermetic_test.go b/test/integration/bbr/hermetic_test.go index be8b2721..718bfedf 100644 --- a/test/integration/bbr/hermetic_test.go +++ b/test/integration/bbr/hermetic_test.go @@ -35,8 +35,6 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -const port = runserver.DefaultGrpcPort - var logger = logutil.NewTestLogger().V(logutil.VERBOSE) func TestBodyBasedRouting(t *testing.T) { @@ -102,8 +100,10 @@ func TestBodyBasedRouting(t *testing.T) { } func setUpHermeticServer() (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { + port := 9004 + serverCtx, stopServer := context.WithCancel(context.Background()) - serverRunner := runserver.NewDefaultExtProcServerRunner() + serverRunner := runserver.NewDefaultExtProcServerRunner(port, false) serverRunner.SecureServing = false go func() {