Commit afe4314

Merge pull request #64 from kfswain/model-splitting
Integrating LLMService with weight splitting
2 parents: fcaaef2 + c26e1b2

5 files changed, +170 −17 lines

pkg/ext-proc/backend/datastore.go

+32
@@ -1,10 +1,12 @@
 package backend
 
 import (
+	"math/rand"
 	"sync"
 
 	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
 	corev1 "k8s.io/api/core/v1"
+	"k8s.io/klog/v2"
 )
 
 // The datastore is a local cache of relevant data for the given LLMServerPool (currently all pulled from k8s-api)
@@ -22,3 +24,33 @@ func (ds *K8sDatastore) GetPodIPs() []string {
 	})
 	return ips
 }
+
+func RandomWeightedDraw(model *v1alpha1.Model, seed int64) string {
+	weights := 0
+
+	source := rand.NewSource(rand.Int63())
+	if seed > 0 {
+		source = rand.NewSource(seed)
+	}
+	r := rand.New(source)
+	for _, model := range model.TargetModels {
+		weights += model.Weight
+	}
+	klog.Infof("Weights for Model(%v) total to: %v", model.Name, weights)
+	randomVal := r.Intn(weights)
+	for _, model := range model.TargetModels {
+		if randomVal < model.Weight {
+			return model.Name
+		}
+		randomVal -= model.Weight
+	}
+	return ""
+}
+
+func ModelHasObjective(model *v1alpha1.Model) bool {
+	if model.Objective != nil && model.Objective.DesiredAveragePerOutputTokenLatencyAtP95OverMultipleRequests != nil {
+		return true
+	}
+
+	return false
+}
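
The two helpers above carry the new traffic-splitting logic: RandomWeightedDraw picks a target model name in proportion to the configured weights (a seed greater than zero makes the draw reproducible), and ModelHasObjective reports whether the Model declares a latency objective. A minimal usage sketch with a made-up Model and weights (illustrative only, not part of this commit):

package main

import (
	"fmt"

	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
)

func main() {
	// Hypothetical 90/10 split between a stable target and a canary.
	model := &v1alpha1.Model{
		Name: "my-model",
		TargetModels: []v1alpha1.TargetModel{
			{Name: "my-model-v1", Weight: 90},
			{Name: "my-model-canary", Weight: 10},
		},
	}
	// A seed of 0 falls through to a randomly seeded source, so over many
	// calls roughly 10% of draws should land on the canary.
	fmt.Println("drawn target:", backend.RandomWeightedDraw(model, 0))
	// No Objective is set on this Model, so it would not be marked critical.
	fmt.Println("has objective:", backend.ModelHasObjective(model))
}
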
pkg/ext-proc/backend/datastore_test.go

+88
@@ -0,0 +1,88 @@
+package backend
+
+import (
+	"testing"
+
+	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
+)
+
+var ()
+
+func TestRandomWeightedDraw(t *testing.T) {
+	tests := []struct {
+		name      string
+		datastore K8sDatastore
+		model     *v1alpha1.Model
+		want      string
+	}{
+		{
+			name: "'random' distribution",
+			model: &v1alpha1.Model{
+				TargetModels: []v1alpha1.TargetModel{
+					{
+						Name:   "canary",
+						Weight: 50,
+					},
+					{
+						Name:   "v1",
+						Weight: 50,
+					},
+				},
+			},
+			want: "canary",
+		},
+		{
+			name: "'random' distribution",
+			model: &v1alpha1.Model{
+				TargetModels: []v1alpha1.TargetModel{
+					{
+						Name:   "canary",
+						Weight: 25,
+					},
+					{
+						Name:   "v1.1",
+						Weight: 55,
+					},
+					{
+						Name:   "v1",
+						Weight: 50,
+					},
+				},
+			},
+			want: "v1",
+		},
+		{
+			name: "'random' distribution",
+			model: &v1alpha1.Model{
+				TargetModels: []v1alpha1.TargetModel{
+					{
+						Name:   "canary",
+						Weight: 20,
+					},
+					{
+						Name:   "v1.1",
+						Weight: 20,
+					},
+					{
+						Name:   "v1",
+						Weight: 10,
+					},
+				},
+			},
+			want: "v1.1",
+		},
+	}
+	var seedVal int64
+	seedVal = 420
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			for range 10000 {
+				model := RandomWeightedDraw(test.model, seedVal)
+				if model != test.want {
+					t.Errorf("Model returned!: %v", model)
+					break
+				}
+			}
+		})
+	}
+}
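
Because the test pins seed 420, RandomWeightedDraw builds an identically seeded source on every call, so all 10,000 draws return the same name and an exact equality check is enough. A rough sketch of how one might instead eyeball the unseeded distribution (same API, but this snippet is illustrative and not part of this commit):

package main

import (
	"fmt"

	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
)

func main() {
	// Same 50/50 shape as the first test case above.
	model := &v1alpha1.Model{
		Name: "test-model",
		TargetModels: []v1alpha1.TargetModel{
			{Name: "canary", Weight: 50},
			{Name: "v1", Weight: 50},
		},
	}
	// Seed 0 means a fresh, randomly seeded source per call, so the tally
	// should come out roughly even rather than identical.
	counts := map[string]int{}
	for i := 0; i < 10000; i++ {
		counts[backend.RandomWeightedDraw(model, 0)]++
	}
	fmt.Println(counts)
}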

pkg/ext-proc/handlers/request.go

+46 −15
@@ -9,6 +9,8 @@ import (
 	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
 	klog "k8s.io/klog/v2"
 
+	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
+	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
 	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
 )
 
@@ -33,25 +35,38 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 		return nil, fmt.Errorf("model not found in request")
 	}
 	klog.V(3).Infof("Model requested: %v", model)
+	modelName := model
+
+	// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
+	// This might be a security risk in the future where adapters not registered in the LLMService
+	// are able to be requested by using their distinct name.
+	modelObj := s.FetchModelData(model)
+	if modelObj != nil && len(modelObj.TargetModels) > 0 {
+		modelName = backend.RandomWeightedDraw(modelObj, 0)
+		if modelName == "" {
+			return nil, fmt.Errorf("Error getting target model name for model %v", modelObj.Name)
+		}
+	}
+	klog.Infof("Model is null %v", modelObj == nil)
 	llmReq := &scheduling.LLMRequest{
-		Model: model,
-		// For now use the model as the target model.
-		// TODO: Once the API is approved, read the "LLMUseCase" configuration and apply traffic split.
-		TargetModels:        map[string]int{model: 100},
-		ResolvedTargetModel: model,
-		// TODO: Read from LLMService CRD.
-		Critical: true,
+		Model:               model,
+		ResolvedTargetModel: modelName,
+		Critical:            backend.ModelHasObjective(modelObj),
 	}
 	klog.V(3).Infof("LLM Request: %+v", llmReq)
 
+	requestBody := v.RequestBody.Body
+	var err error
 	// Update target models in the body.
-	rb["model"] = llmReq.ResolvedTargetModel
-	updatedBody, err := json.Marshal(rb)
-	if err != nil {
-		klog.Errorf("Error marshaling request body: %v", err)
-		return nil, fmt.Errorf("error marshaling request body: %v", err)
+	if llmReq.Model != llmReq.ResolvedTargetModel {
+		rb["model"] = llmReq.ResolvedTargetModel
+		requestBody, err = json.Marshal(rb)
+		if err != nil {
+			klog.Errorf("Error marshaling request body: %v", err)
+			return nil, fmt.Errorf("error marshaling request body: %v", err)
+		}
+		klog.V(3).Infof("Updated body: %v", string(requestBody))
 	}
-	klog.V(3).Infof("Updated body: %v", string(updatedBody))
 
 	targetPod, err := s.scheduler.Schedule(llmReq)
 	if err != nil {
@@ -75,7 +90,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 		{
 			Header: &configPb.HeaderValue{
 				Key:      "Content-Length",
-				RawValue: []byte(strconv.Itoa(len(updatedBody))),
+				RawValue: []byte(strconv.Itoa(len(requestBody))),
 			},
 		},
 	}
@@ -93,7 +108,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 			},
 			BodyMutation: &extProcPb.BodyMutation{
 				Mutation: &extProcPb.BodyMutation_Body{
-					Body: updatedBody,
+					Body: requestBody,
 				},
 			},
 		},
@@ -103,6 +118,22 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 	return resp, nil
 }
 
+func (s *Server) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
+	s.datastore.LLMServices.Range(func(k, v any) bool {
+		service := v.(*v1alpha1.LLMService)
+		klog.Infof("Service name: %v", service.Name)
+		for _, model := range service.Spec.Models {
+			if model.Name == modelName {
+				returnModel = &model
+				// We want to stop iterating, return false.
+				return false
+			}
+		}
+		return true
+	})
+	return
+}
+
 func HandleRequestHeaders(reqCtx *RequestContext, req *extProcPb.ProcessingRequest) *extProcPb.ProcessingResponse {
 	klog.V(3).Info("--- In RequestHeaders processing ...")
 	r := req.Request
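
Putting the request-path changes together: HandleRequestBody looks the requested model up in the cached LLMServices via FetchModelData, resolves it to a weighted target with RandomWeightedDraw, and only rewrites and re-marshals the JSON body when the resolved name differs from what the client sent; otherwise the original body bytes pass through untouched. A stripped-down sketch of that rewrite step, with a hand-built body and Model standing in for what the handler actually receives (names and weights are illustrative):

package main

import (
	"encoding/json"
	"fmt"

	"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
)

func main() {
	// rb plays the role of the decoded request body in HandleRequestBody.
	rb := map[string]interface{}{"model": "my-model", "prompt": "hello"}
	requested := rb["model"].(string)

	// Stand-in for FetchModelData(requested): here the only target is a canary.
	modelObj := &v1alpha1.Model{
		Name: "my-model",
		TargetModels: []v1alpha1.TargetModel{
			{Name: "my-model-canary", Weight: 1},
		},
	}

	resolved := requested
	if modelObj != nil && len(modelObj.TargetModels) > 0 {
		resolved = backend.RandomWeightedDraw(modelObj, 0)
	}

	// Mirror the handler: only mutate and re-marshal when the name changed.
	if resolved != requested {
		rb["model"] = resolved
		body, err := json.Marshal(rb)
		if err != nil {
			panic(err)
		}
		fmt.Println(string(body)) // {"model":"my-model-canary","prompt":"hello"}
	}
}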

pkg/ext-proc/handlers/server.go

+3 −1
@@ -13,11 +13,12 @@ import (
 	"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
 )
 
-func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string) *Server {
+func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string, datastore *backend.K8sDatastore) *Server {
 	return &Server{
 		scheduler:       scheduler,
 		podProvider:     pp,
 		targetPodHeader: targetPodHeader,
+		datastore:       datastore,
 	}
 }
 
@@ -29,6 +30,7 @@ type Server struct {
 	// The key of the header to specify the target pod address. This value needs to match Envoy
 	// configuration.
 	targetPodHeader string
+	datastore       *backend.K8sDatastore
 }
 
 type Scheduler interface {

pkg/ext-proc/main.go

+1 −1
@@ -132,7 +132,7 @@ func main() {
 	if err := pp.Init(*refreshPodsInterval, *refreshMetricsInterval); err != nil {
 		klog.Fatalf("failed to initialize: %v", err)
 	}
-	extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), *targetPodHeader))
+	extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), *targetPodHeader, datastore))
 	healthPb.RegisterHealthServer(s, &healthServer{})
 
 	klog.Infof("Starting gRPC server on port :%v", *port)
