Skip to content

Integrating LLMService with weight splitting #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions pkg/ext-proc/backend/datastore.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package backend

import (
"math/rand"
"sync"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
corev1 "k8s.io/api/core/v1"
"k8s.io/klog/v2"
)

// The datastore is a local cache of relevant data for the given LLMServerPool (currently all pulled from k8s-api)
Expand All @@ -22,3 +24,33 @@ func (ds *K8sDatastore) GetPodIPs() []string {
})
return ips
}

// RandomWeightedDraw picks a target model name from the model's TargetModels,
// with probability proportional to each target's Weight. A positive seed makes
// the draw deterministic (used by tests); otherwise a random seed is drawn.
// Returns "" when the total weight is not positive (e.g. no target models).
func RandomWeightedDraw(model *v1alpha1.Model, seed int64) string {
	totalWeight := 0
	for _, tm := range model.TargetModels {
		totalWeight += tm.Weight
	}
	// Guard: rand.Intn panics when its argument is <= 0.
	if totalWeight <= 0 {
		return ""
	}
	klog.V(3).Infof("Weights for Model(%v) total to: %v", model.Name, totalWeight)

	source := rand.NewSource(rand.Int63())
	if seed > 0 {
		source = rand.NewSource(seed)
	}
	r := rand.New(source)

	randomVal := r.Intn(totalWeight)
	for _, tm := range model.TargetModels {
		if randomVal < tm.Weight {
			return tm.Name
		}
		randomVal -= tm.Weight
	}
	// Unreachable: randomVal < totalWeight guarantees a match above.
	return ""
}

// ModelHasObjective reports whether the model declares a latency objective.
// A nil model is allowed (the caller permits passthrough for models not
// registered in any LLMService) and has no objective.
func ModelHasObjective(model *v1alpha1.Model) bool {
	if model == nil {
		return false
	}
	return model.Objective != nil && model.Objective.DesiredAveragePerOutputTokenLatencyAtP95OverMultipleRequests != nil
}
88 changes: 88 additions & 0 deletions pkg/ext-proc/backend/datastore_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package backend

import (
"testing"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
)

var ()

func TestRandomWeightedDraw(t *testing.T) {
tests := []struct {
name string
datastore K8sDatastore
model *v1alpha1.Model
want string
}{
{
name: "'random' distribution",
model: &v1alpha1.Model{
TargetModels: []v1alpha1.TargetModel{
{
Name: "canary",
Weight: 50,
},
{
Name: "v1",
Weight: 50,
},
},
},
want: "canary",
},
{
name: "'random' distribution",
model: &v1alpha1.Model{
TargetModels: []v1alpha1.TargetModel{
{
Name: "canary",
Weight: 25,
},
{
Name: "v1.1",
Weight: 55,
},
{
Name: "v1",
Weight: 50,
},
},
},
want: "v1",
},
{
name: "'random' distribution",
model: &v1alpha1.Model{
TargetModels: []v1alpha1.TargetModel{
{
Name: "canary",
Weight: 20,
},
{
Name: "v1.1",
Weight: 20,
},
{
Name: "v1",
Weight: 10,
},
},
},
want: "v1.1",
},
}
var seedVal int64
seedVal = 420
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for range 10000 {
model := RandomWeightedDraw(test.model, seedVal)
if model != test.want {
t.Errorf("Model returned!: %v", model)
break
}
}
})
}
}
61 changes: 46 additions & 15 deletions pkg/ext-proc/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
klog "k8s.io/klog/v2"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
)

Expand All @@ -33,25 +35,38 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
return nil, fmt.Errorf("model not found in request")
}
klog.V(3).Infof("Model requested: %v", model)
modelName := model

// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
// This might be a security risk in the future where adapters not registered in the LLMService
// are able to be requested by using their distinct name.
modelObj := s.FetchModelData(model)
if modelObj != nil && len(modelObj.TargetModels) > 0 {
modelName = backend.RandomWeightedDraw(modelObj)
if modelName == "" {
return nil, fmt.Errorf("Error getting target model name for model %v", modelObj.Name)
}
}
klog.Infof("Model is null %v", modelObj == nil)
llmReq := &scheduling.LLMRequest{
Model: model,
// For now use the model as the target model.
// TODO: Once the API is approved, read the "LLMUseCase" configuration and apply traffic split.
TargetModels: map[string]int{model: 100},
ResolvedTargetModel: model,
// TODO: Read from LLMService CRD.
Critical: true,
Model: model,
ResolvedTargetModel: modelName,
Critical: backend.ModelHasObjective(modelObj),
}
klog.V(3).Infof("LLM Request: %+v", llmReq)

requestBody := v.RequestBody.Body
var err error
// Update target models in the body.
rb["model"] = llmReq.ResolvedTargetModel
updatedBody, err := json.Marshal(rb)
if err != nil {
klog.Errorf("Error marshaling request body: %v", err)
return nil, fmt.Errorf("error marshaling request body: %v", err)
if llmReq.Model != llmReq.ResolvedTargetModel {
rb["model"] = llmReq.ResolvedTargetModel
requestBody, err = json.Marshal(rb)
if err != nil {
klog.Errorf("Error marshaling request body: %v", err)
return nil, fmt.Errorf("error marshaling request body: %v", err)
}
klog.V(3).Infof("Updated body: %v", string(requestBody))
}
klog.V(3).Infof("Updated body: %v", string(updatedBody))

targetPod, err := s.scheduler.Schedule(llmReq)
if err != nil {
Expand All @@ -75,7 +90,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
{
Header: &configPb.HeaderValue{
Key: "Content-Length",
RawValue: []byte(strconv.Itoa(len(updatedBody))),
RawValue: []byte(strconv.Itoa(len(requestBody))),
},
},
}
Expand All @@ -93,7 +108,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
},
BodyMutation: &extProcPb.BodyMutation{
Mutation: &extProcPb.BodyMutation_Body{
Body: updatedBody,
Body: requestBody,
},
},
},
Expand All @@ -103,6 +118,22 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
return resp, nil
}

// FetchModelData looks up modelName across all LLMServices cached in the
// datastore and returns the first matching Model spec, or nil if no service
// declares a model with that name.
func (s *Server) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
	s.datastore.LLMServices.Range(func(k, v any) bool {
		service := v.(*v1alpha1.LLMService)
		// V(3): this runs on every request; avoid log spam at default verbosity.
		klog.V(3).Infof("Service name: %v", service.Name)
		for _, model := range service.Spec.Models {
			if model.Name == modelName {
				// Return a pointer to a stable copy, never to the range loop
				// variable (which is shared across iterations pre-Go 1.22).
				model := model
				returnModel = &model
				// Match found; return false to stop the Range iteration.
				return false
			}
		}
		return true
	})
	return
}

func HandleRequestHeaders(reqCtx *RequestContext, req *extProcPb.ProcessingRequest) *extProcPb.ProcessingResponse {
klog.V(3).Info("--- In RequestHeaders processing ...")
r := req.Request
Expand Down
4 changes: 3 additions & 1 deletion pkg/ext-proc/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
)

func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string) *Server {
func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string, datastore *backend.K8sDatastore) *Server {
return &Server{
scheduler: scheduler,
podProvider: pp,
targetPodHeader: targetPodHeader,
datastore: datastore,
}
}

Expand All @@ -29,6 +30,7 @@ type Server struct {
// The key of the header to specify the target pod address. This value needs to match Envoy
// configuration.
targetPodHeader string
datastore *backend.K8sDatastore
}

type Scheduler interface {
Expand Down
2 changes: 1 addition & 1 deletion pkg/ext-proc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func main() {
if err := pp.Init(*refreshPodsInterval, *refreshMetricsInterval); err != nil {
klog.Fatalf("failed to initialize: %v", err)
}
extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), *targetPodHeader))
extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), *targetPodHeader, datastore))
healthPb.RegisterHealthServer(s, &healthServer{})

klog.Infof("Starting gRPC server on port :%v", *port)
Expand Down