diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go
index 180ee3ce..8330daa4 100644
--- a/pkg/ext-proc/backend/datastore.go
+++ b/pkg/ext-proc/backend/datastore.go
@@ -1,10 +1,12 @@
 package backend
 
 import (
+	"math/rand"
 	"sync"
 
 	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
 	corev1 "k8s.io/api/core/v1"
+	"k8s.io/klog/v2"
 )
 
 // The datastore is a local cache of relevant data for the given LLMServerPool (currently all pulled from k8s-api)
@@ -22,3 +24,33 @@ func (ds *K8sDatastore) GetPodIPs() []string {
 	})
 	return ips
 }
+
+// RandomWeightedDraw picks a target model name according to the weights in
+// model.TargetModels. A seed > 0 makes the draw deterministic (for tests).
+func RandomWeightedDraw(model *v1alpha1.Model, seed int64) string {
+	source := rand.NewSource(rand.Int63())
+	if seed > 0 {
+		source = rand.NewSource(seed)
+	}
+	r := rand.New(source)
+
+	weights := 0
+	for _, tm := range model.TargetModels {
+		weights += tm.Weight
+	}
+	klog.V(3).Infof("Weights for model %v sum to %v", model.Name, weights)
+	if weights <= 0 {
+		// Avoid a panic in r.Intn when no (or zero) weights are configured.
+		return ""
+	}
+	randomVal := r.Intn(weights)
+	for _, tm := range model.TargetModels {
+		if randomVal < tm.Weight {
+			return tm.Name
+		}
+		randomVal -= tm.Weight
+	}
+	return ""
+}
+
+// ModelHasObjective reports whether the model declares a latency objective.
+// It is nil-safe because callers may pass a model that was not found.
+func ModelHasObjective(model *v1alpha1.Model) bool {
+	return model != nil &&
+		model.Objective != nil &&
+		model.Objective.DesiredAveragePerOutputTokenLatencyAtP95OverMultipleRequests != nil
+}
diff --git a/pkg/ext-proc/backend/datastore_test.go b/pkg/ext-proc/backend/datastore_test.go
new file mode 100644
index 00000000..b47db1a5
--- /dev/null
+++ b/pkg/ext-proc/backend/datastore_test.go
@@ -0,0 +1,88 @@
+package backend
+
+import (
+	"testing"
+
+	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
+)
+
+func TestRandomWeightedDraw(t *testing.T) {
+	tests := []struct {
+		name  string
+		model *v1alpha1.Model
+		want  string
+	}{
+		{
+			name: "deterministic draw: two models, even weights",
+			model: &v1alpha1.Model{
+				TargetModels: []v1alpha1.TargetModel{
+					{
+						Name:   "canary",
+						Weight: 50,
+					},
+					{
+						Name:   "v1",
+						Weight: 50,
+					},
+				},
+			},
+			want: "canary",
+		},
+		{
+			name: "deterministic draw: three models, uneven weights",
+			model: &v1alpha1.Model{
+				TargetModels: []v1alpha1.TargetModel{
+					{
+						Name:   "canary",
+						Weight: 25,
+					},
+					{
+						Name:   "v1.1",
+						Weight: 55,
+					},
+					{
+						Name:   "v1",
+						Weight: 50,
+					},
+				},
+			},
+			want: "v1",
+		},
+		{
+			name: "deterministic draw: three models, small weights",
+			model: &v1alpha1.Model{
+				TargetModels: []v1alpha1.TargetModel{
+					{
+						Name:   "canary",
+						Weight: 20,
+					},
+					{
+						Name:   "v1.1",
+						Weight: 20,
+					},
+					{
+						Name:   "v1",
+						Weight: 10,
+					},
+				},
+			},
+			want: "v1.1",
+		},
+	}
+	seedVal := int64(420)
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			// With a fixed seed, every draw must return the same model.
+			for range 10000 {
+				model := RandomWeightedDraw(test.model, seedVal)
+				if model != test.want {
+					t.Errorf("RandomWeightedDraw() = %v, want %v", model, test.want)
+					break
+				}
+			}
+		})
+	}
+}
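Review note: the fixed-seed tests above only verify determinism, not the shape of the distribution. Below is a minimal standalone sketch for eyeballing the unseeded behavior; v1alpha1.TargetModel is replaced by a hypothetical local stub, and targetModel/weightedDraw are illustrative names, not part of this PR.

package main

import (
	"fmt"
	"math/rand"
)

// targetModel is a hypothetical stand-in for v1alpha1.TargetModel.
type targetModel struct {
	Name   string
	Weight int
}

// weightedDraw mirrors the RandomWeightedDraw logic in datastore.go.
func weightedDraw(models []targetModel, r *rand.Rand) string {
	total := 0
	for _, tm := range models {
		total += tm.Weight
	}
	if total <= 0 {
		return ""
	}
	v := r.Intn(total)
	for _, tm := range models {
		if v < tm.Weight {
			return tm.Name
		}
		v -= tm.Weight
	}
	return ""
}

func main() {
	models := []targetModel{{Name: "canary", Weight: 10}, {Name: "v1", Weight: 90}}
	r := rand.New(rand.NewSource(1))
	counts := map[string]int{}
	for i := 0; i < 100000; i++ {
		counts[weightedDraw(models, r)]++
	}
	fmt.Println(counts) // expect roughly a 10% canary / 90% v1 split
}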
diff --git a/pkg/ext-proc/handlers/request.go b/pkg/ext-proc/handlers/request.go
index 8fdea646..512148e8 100644
--- a/pkg/ext-proc/handlers/request.go
+++ b/pkg/ext-proc/handlers/request.go
@@ -9,6 +9,8 @@ import (
 	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
 	klog "k8s.io/klog/v2"
 
+	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
+	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
 	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
 )
 
@@ -33,25 +35,38 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 		return nil, fmt.Errorf("model not found in request")
 	}
 	klog.V(3).Infof("Model requested: %v", model)
+	modelName := model
+
+	// NOTE: Because we only act when modelObj is non-nil, unregistered models currently
+	// pass through unchanged. This could become a security risk: adapters that are not
+	// registered in the LLMService can still be requested directly by their distinct names.
+	modelObj := s.FetchModelData(model)
+	if modelObj != nil && len(modelObj.TargetModels) > 0 {
+		modelName = backend.RandomWeightedDraw(modelObj, 0)
+		if modelName == "" {
+			return nil, fmt.Errorf("error getting target model name for model %v", modelObj.Name)
+		}
+	}
+	klog.V(3).Infof("Model object is nil: %v", modelObj == nil)
 
 	llmReq := &scheduling.LLMRequest{
-		Model: model,
-		// For now use the model as the target model.
-		// TODO: Once the API is approved, read the "LLMUseCase" configuration and apply traffic split.
-		TargetModels:        map[string]int{model: 100},
-		ResolvedTargetModel: model,
-		// TODO: Read from LLMService CRD.
-		Critical: true,
+		Model:               model,
+		ResolvedTargetModel: modelName,
+		Critical:            backend.ModelHasObjective(modelObj),
 	}
 	klog.V(3).Infof("LLM Request: %+v", llmReq)
 
+	requestBody := v.RequestBody.Body
+	var err error
 	// Update target models in the body.
-	rb["model"] = llmReq.ResolvedTargetModel
-	updatedBody, err := json.Marshal(rb)
-	if err != nil {
-		klog.Errorf("Error marshaling request body: %v", err)
-		return nil, fmt.Errorf("error marshaling request body: %v", err)
+	if llmReq.Model != llmReq.ResolvedTargetModel {
+		rb["model"] = llmReq.ResolvedTargetModel
+		requestBody, err = json.Marshal(rb)
+		if err != nil {
+			klog.Errorf("Error marshaling request body: %v", err)
+			return nil, fmt.Errorf("error marshaling request body: %v", err)
+		}
+		klog.V(3).Infof("Updated body: %v", string(requestBody))
 	}
-	klog.V(3).Infof("Updated body: %v", string(updatedBody))
 
 	targetPod, err := s.scheduler.Schedule(llmReq)
 	if err != nil {
@@ -75,7 +90,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 		{
 			Header: &configPb.HeaderValue{
 				Key:      "Content-Length",
-				RawValue: []byte(strconv.Itoa(len(updatedBody))),
+				RawValue: []byte(strconv.Itoa(len(requestBody))),
 			},
 		},
 	}
@@ -93,7 +108,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 			},
 			BodyMutation: &extProcPb.BodyMutation{
 				Mutation: &extProcPb.BodyMutation_Body{
-					Body: updatedBody,
+					Body: requestBody,
 				},
 			},
 		},
@@ -103,6 +118,22 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 	return resp, nil
 }
 
+// FetchModelData looks up modelName across all cached LLMServices and returns
+// the matching Model spec, or nil if the model is not registered.
+func (s *Server) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
+	s.datastore.LLMServices.Range(func(k, v any) bool {
+		service := v.(*v1alpha1.LLMService)
+		klog.V(3).Infof("Service name: %v", service.Name)
+		for i := range service.Spec.Models {
+			if service.Spec.Models[i].Name == modelName {
+				// Take the address of the slice element rather than a loop
+				// variable so the pointer stays valid after Range returns.
+				returnModel = &service.Spec.Models[i]
+				// A match was found; return false to stop iterating.
+				return false
+			}
+		}
+		return true
+	})
+	return
+}
+
 func HandleRequestHeaders(reqCtx *RequestContext, req *extProcPb.ProcessingRequest) *extProcPb.ProcessingResponse {
 	klog.V(3).Info("--- In RequestHeaders processing ...")
 	r := req.Request
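Review note: the conditional re-marshal above avoids rewriting the body when no traffic split applies. Here is a self-contained sketch of that rewrite plus the Content-Length recompute Envoy requires, assuming a map-decoded JSON body; the literals and variable names are illustrative, not the handler's actual locals.

package main

import (
	"encoding/json"
	"fmt"
	"strconv"
)

func main() {
	raw := []byte(`{"model":"my-model","prompt":"hello"}`)
	var rb map[string]interface{}
	if err := json.Unmarshal(raw, &rb); err != nil {
		panic(err)
	}

	resolved := "my-model-canary" // hypothetical resolved target model
	body := raw
	if rb["model"] != resolved {
		// Only re-marshal when the resolved target differs from the request.
		rb["model"] = resolved
		var err error
		body, err = json.Marshal(rb)
		if err != nil {
			panic(err)
		}
	}
	// Envoy requires Content-Length to match the mutated body.
	fmt.Println(string(body), "Content-Length:", strconv.Itoa(len(body)))
}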
diff --git a/pkg/ext-proc/handlers/server.go b/pkg/ext-proc/handlers/server.go
index 8ff07edd..bc96d0b8 100644
--- a/pkg/ext-proc/handlers/server.go
+++ b/pkg/ext-proc/handlers/server.go
@@ -13,11 +13,12 @@ import (
 	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
 )
 
-func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string) *Server {
+func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string, datastore *backend.K8sDatastore) *Server {
 	return &Server{
 		scheduler:       scheduler,
 		podProvider:     pp,
 		targetPodHeader: targetPodHeader,
+		datastore:       datastore,
 	}
 }
 
@@ -29,6 +30,7 @@ type Server struct {
 	// The key of the header to specify the target pod address. This value needs to match Envoy
 	// configuration.
 	targetPodHeader string
+	datastore       *backend.K8sDatastore
 }
 
 type Scheduler interface {
diff --git a/pkg/ext-proc/main.go b/pkg/ext-proc/main.go
index 8a51c851..d949e9f4 100644
--- a/pkg/ext-proc/main.go
+++ b/pkg/ext-proc/main.go
@@ -132,7 +132,7 @@ func main() {
 	if err := pp.Init(*refreshPodsInterval, *refreshMetricsInterval); err != nil {
 		klog.Fatalf("failed to initialize: %v", err)
 	}
-	extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), *targetPodHeader))
+	extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), *targetPodHeader, datastore))
 	healthPb.RegisterHealthServer(s, &healthServer{})
 
 	klog.Infof("Starting gRPC server on port :%v", *port)
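Review note: threading the shared *backend.K8sDatastore through NewServer lets HandleRequestBody resolve LLMService models from the local cache instead of querying the API server per request. For reference, a minimal sketch of the sync.Map.Range early-exit pattern FetchModelData depends on; the service names and payload here are made up.

package main

import (
	"fmt"
	"sync"
)

func main() {
	var services sync.Map
	services.Store("svc-a", []string{"llama", "gemma"})
	services.Store("svc-b", []string{"my-model"})

	var found string
	services.Range(func(_, v any) bool {
		for _, name := range v.([]string) {
			if name == "my-model" {
				found = name
				return false // returning false stops Range at the first match
			}
		}
		return true // keep scanning the remaining services
	})
	fmt.Println("found:", found)
}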