Commit 53e414f

Fix build and test

1 parent afe4314 commit 53e414f

File tree: 9 files changed, +87 -45 lines changed


go.mod (+3 -2)

@@ -1,7 +1,8 @@
 module inference.networking.x-k8s.io/llm-instance-gateway
 
-go 1.22.0
-toolchain go1.22.9
+go 1.22.7
+
+toolchain go1.23.2
 
 require (
     github.com/bojand/ghz v0.120.0

pkg/ext-proc/backend/datastore.go (+17 -1)

@@ -25,6 +25,22 @@ func (ds *K8sDatastore) GetPodIPs() []string {
     return ips
 }
 
+func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
+    s.LLMServices.Range(func(k, v any) bool {
+        service := v.(*v1alpha1.LLMService)
+        klog.V(3).Infof("Service name: %v", service.Name)
+        for _, model := range service.Spec.Models {
+            if model.Name == modelName {
+                returnModel = &model
+                // We want to stop iterating, return false.
+                return false
+            }
+        }
+        return true
+    })
+    return
+}
+
 func RandomWeightedDraw(model *v1alpha1.Model, seed int64) string {
     weights := 0
 
@@ -36,7 +52,7 @@ func RandomWeightedDraw(model *v1alpha1.Model, seed int64) string {
     for _, model := range model.TargetModels {
         weights += model.Weight
     }
-    klog.Infof("Weights for Model(%v) total to: %v", model.Name, weights)
+    klog.V(3).Infof("Weights for Model(%v) total to: %v", model.Name, weights)
     randomVal := r.Intn(weights)
     for _, model := range model.TargetModels {
         if randomVal < model.Weight {
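Note: the lookup that moved here out of the request handler relies on sync.Map.Range stopping as soon as its callback returns false, and on taking the address of the loop variable, which is safe on the Go 1.22+ toolchain pinned in go.mod because each iteration gets its own variable. A minimal, self-contained sketch of the same pattern, using a stand-in Model type rather than the project's v1alpha1 API:

package main

import (
    "fmt"
    "sync"
)

// Model stands in for v1alpha1.Model, reduced to the one field the lookup uses.
type Model struct {
    Name string
}

// findModel mirrors FetchModelData: Range visits entries until the callback
// returns false, which is how the search stops at the first match.
func findModel(services *sync.Map, modelName string) (found *Model) {
    services.Range(func(_, v any) bool {
        for _, m := range v.([]Model) {
            if m.Name == modelName {
                found = &m   // per-iteration loop variable (Go 1.22+)
                return false // stop iterating
            }
        }
        return true // keep iterating
    })
    return
}

func main() {
    var services sync.Map
    services.Store("svc-a", []Model{{Name: "other"}, {Name: "my-model"}})
    fmt.Println(findModel(&services, "my-model").Name) // my-model
}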

pkg/ext-proc/backend/datastore_test.go (+4 -4)

@@ -72,12 +72,12 @@ func TestRandomWeightedDraw(t *testing.T) {
             want: "v1.1",
         },
     }
-    var seedVal int64
-    seedVal = 420
+    // var seedVal int64
+    // seedVal = 420
     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
-            for range 10000 {
-                model := RandomWeightedDraw(test.model, seedVal)
+            for range 5 {
+                model := RandomWeightedDraw(test.model, 0)
                 if model != test.want {
                     t.Errorf("Model returned!: %v", model)
                     break

pkg/ext-proc/backend/fake.go (+12 -2)

@@ -2,7 +2,9 @@ package backend
 
 import (
     "context"
-    "fmt"
+
+    "inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
+    klog "k8s.io/klog/v2"
 )
 
 type FakePodMetricsClient struct {
@@ -14,6 +16,14 @@ func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existi
     if err, ok := f.Err[pod]; ok {
         return nil, err
     }
-    fmt.Printf("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod])
+    klog.V(1).Infof("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod])
     return f.Res[pod], nil
 }
+
+type FakeDataStore struct {
+    Res map[string]*v1alpha1.Model
+}
+
+func (fds *FakeDataStore) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
+    return fds.Res[modelName]
+}

pkg/ext-proc/handlers/request.go (+4 -21)

@@ -9,7 +9,6 @@ import (
     extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
     klog "k8s.io/klog/v2"
 
-    "inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
 )
@@ -40,14 +39,14 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
     // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
     // This might be a security risk in the future where adapters not registered in the LLMService
     // are able to be requested by using their distinct name.
-    modelObj := s.FetchModelData(model)
+    modelObj := s.datastore.FetchModelData(model)
     if modelObj != nil && len(modelObj.TargetModels) > 0 {
-        modelName = backend.RandomWeightedDraw(modelObj)
+        modelName = backend.RandomWeightedDraw(modelObj, 0)
         if modelName == "" {
-            return nil, fmt.Errorf("Error getting target model name for model %v", modelObj.Name)
+            return nil, fmt.Errorf("error getting target model name for model %v", modelObj.Name)
         }
     }
-    klog.Infof("Model is null %v", modelObj == nil)
+    klog.V(3).Infof("Model is null %v", modelObj == nil)
     llmReq := &scheduling.LLMRequest{
         Model:               model,
         ResolvedTargetModel: modelName,
@@ -118,22 +117,6 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
     return resp, nil
 }
 
-func (s *Server) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
-    s.datastore.LLMServices.Range(func(k, v any) bool {
-        service := v.(*v1alpha1.LLMService)
-        klog.Infof("Service name: %v", service.Name)
-        for _, model := range service.Spec.Models {
-            if model.Name == modelName {
-                returnModel = &model
-                // We want to stop iterating, return false.
-                return false
-            }
-        }
-        return true
-    })
-    return
-}
-
 func HandleRequestHeaders(reqCtx *RequestContext, req *extProcPb.ProcessingRequest) *extProcPb.ProcessingResponse {
     klog.V(3).Info("--- In RequestHeaders processing ...")
     r := req.Request
pkg/ext-proc/handlers/server.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ import (
99
"google.golang.org/grpc/status"
1010
klog "k8s.io/klog/v2"
1111

12+
"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
1213
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
1314
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
1415
)
1516

16-
func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string, datastore *backend.K8sDatastore) *Server {
17+
func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string, datastore ModelDataStore) *Server {
1718
return &Server{
1819
scheduler: scheduler,
1920
podProvider: pp,
@@ -30,7 +31,7 @@ type Server struct {
3031
// The key of the header to specify the target pod address. This value needs to match Envoy
3132
// configuration.
3233
targetPodHeader string
33-
datastore *backend.K8sDatastore
34+
datastore ModelDataStore
3435
}
3536

3637
type Scheduler interface {
@@ -43,6 +44,10 @@ type PodProvider interface {
4344
UpdatePodMetrics(pod backend.Pod, pm *backend.PodMetrics)
4445
}
4546

47+
type ModelDataStore interface {
48+
FetchModelData(modelName string) (returnModel *v1alpha1.Model)
49+
}
50+
4651
func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
4752
klog.V(3).Info("Processing")
4853
ctx := srv.Context()
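Note: the ModelDataStore interface is what lets the test and benchmark code below swap backend.FakeDataStore in for the real K8sDatastore; Server now depends only on FetchModelData, so anything implementing it can be passed to NewServer. A reduced sketch of the injection pattern, with stand-in types rather than the repository's actual definitions:

package main

import "fmt"

// Model stands in for v1alpha1.Model.
type Model struct {
    Name string
}

// ModelDataStore mirrors the interface added to server.go.
type ModelDataStore interface {
    FetchModelData(modelName string) *Model
}

// FakeDataStore mirrors the test double added in backend/fake.go:
// lookups are served from a plain map instead of watched LLMService objects.
type FakeDataStore struct {
    Res map[string]*Model
}

func (f *FakeDataStore) FetchModelData(modelName string) *Model {
    return f.Res[modelName]
}

// Server keeps only the interface, so tests and benchmarks can inject the fake.
type Server struct {
    datastore ModelDataStore
}

func main() {
    s := &Server{datastore: &FakeDataStore{Res: map[string]*Model{
        "my-model": {Name: "my-model"},
    }}}
    fmt.Println(s.datastore.FetchModelData("my-model").Name) // my-model
}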

pkg/ext-proc/test/benchmark/benchmark.go (+14 -1)

@@ -12,6 +12,7 @@ import (
     "google.golang.org/protobuf/proto"
     klog "k8s.io/klog/v2"
 
+    "inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/test"
 )
@@ -36,7 +37,7 @@ func main() {
     flag.Parse()
 
     if *localServer {
-        test.StartExtProc(port, *refreshPodsInterval, *refreshMetricsInterval, fakePods())
+        test.StartExtProc(port, *refreshPodsInterval, *refreshMetricsInterval, fakePods(), fakeModels())
         time.Sleep(time.Second) // wait until server is up
         klog.Info("Server started")
     }
@@ -70,6 +71,18 @@ func generateRequest(mtd *desc.MethodDescriptor, callData *runner.CallData) []by
     return data
 }
 
+func fakeModels() map[string]*v1alpha1.Model {
+    models := map[string]*v1alpha1.Model{}
+    for i := range *numFakePods {
+        for j := range *numModelsPerPod {
+            m := modelName(i*(*numModelsPerPod) + j)
+            models[m] = &v1alpha1.Model{Name: m}
+        }
+    }
+
+    return models
+}
+
 func fakePods() []*backend.PodMetrics {
     pms := make([]*backend.PodMetrics, 0, *numFakePods)
     for i := 0; i < *numFakePods; i++ {
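Note: fakeModels iterates with for i := range *numFakePods, the range-over-integer form that requires a Go 1.22+ language version, and keys each entry by i*numModelsPerPod+j so every pod/model pair gets a distinct name. A small sketch with the flags replaced by constants and a hypothetical modelName helper (not the benchmark's actual naming scheme):

package main

import "fmt"

// Constants stand in for the *numFakePods and *numModelsPerPod flags.
const (
    numFakePods     = 2
    numModelsPerPod = 3
)

// modelName is a hypothetical stand-in for the benchmark's naming helper.
func modelName(i int) string { return fmt.Sprintf("model-%d", i) }

func fakeModels() map[string]string {
    models := map[string]string{}
    for i := range numFakePods { // range over an int: Go 1.22+
        for j := range numModelsPerPod {
            m := modelName(i*numModelsPerPod + j) // indices 0..5 never collide
            models[m] = m
        }
    }
    return models
}

func main() {
    fmt.Println(len(fakeModels())) // 6 = numFakePods * numModelsPerPod
}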

pkg/ext-proc/test/hermetic_test.go (+21 -8)

@@ -8,6 +8,7 @@ import (
     "testing"
     "time"
 
+    "inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
 
     configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -28,13 +29,25 @@ func TestHandleRequestBody(t *testing.T) {
         name        string
         req         *extProcPb.ProcessingRequest
         pods        []*backend.PodMetrics
+        models      map[string]*v1alpha1.Model
         wantHeaders []*configPb.HeaderValueOption
         wantBody    []byte
         wantErr     bool
     }{
         {
             name: "success",
             req:  GenerateRequest("my-model"),
+            models: map[string]*v1alpha1.Model{
+                "my-model": {
+                    Name: "my-model",
+                    TargetModels: []v1alpha1.TargetModel{
+                        {
+                            Name:   "my-model-v1",
+                            Weight: 100,
+                        },
+                    },
+                },
+            },
             // pod-1 will be picked because it has relatively low queue size, with the requested
             // model being active, and has low KV cache.
             pods: []*backend.PodMetrics{
@@ -52,11 +65,11 @@ func TestHandleRequestBody(t *testing.T) {
                 {
                     Pod: FakePod(1),
                     Metrics: backend.Metrics{
-                        WaitingQueueSize:    3,
+                        WaitingQueueSize:    0,
                         KVCacheUsagePercent: 0.1,
                         ActiveModels: map[string]int{
-                            "foo":      1,
-                            "my-model": 1,
+                            "foo":         1,
+                            "my-model-v1": 1,
                         },
                     },
                 },
@@ -81,17 +94,17 @@ func TestHandleRequestBody(t *testing.T) {
                 {
                     Header: &configPb.HeaderValue{
                         Key:      "Content-Length",
-                        RawValue: []byte("70"),
+                        RawValue: []byte("73"),
                     },
                 },
             },
-            wantBody: []byte("{\"max_tokens\":100,\"model\":\"my-model\",\"prompt\":\"hello\",\"temperature\":0}"),
+            wantBody: []byte("{\"max_tokens\":100,\"model\":\"my-model-v1\",\"prompt\":\"hello\",\"temperature\":0}"),
         },
     }
 
     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
-            client, cleanup := setUpServer(t, test.pods)
+            client, cleanup := setUpServer(t, test.pods, test.models)
            t.Cleanup(cleanup)
             want := &extProcPb.ProcessingResponse{
                 Response: &extProcPb.ProcessingResponse_RequestBody{
@@ -123,8 +136,8 @@ func TestHandleRequestBody(t *testing.T) {
 
 }
 
-func setUpServer(t *testing.T, pods []*backend.PodMetrics) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
-    server := StartExtProc(port, time.Second, time.Second, pods)
+func setUpServer(t *testing.T, pods []*backend.PodMetrics, models map[string]*v1alpha1.Model) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
+    server := StartExtProc(port, time.Second, time.Second, pods, models)
 
     address := fmt.Sprintf("localhost:%v", port)
     // Create a grpc connection
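Note: the adjusted expectations follow from the new models fixture. With "my-model" mapping to a single target "my-model-v1" at weight 100, the handler rewrites the model field in the body, which adds three bytes to the previous 70-byte payload (hence Content-Length 73), and pod-1's metrics were updated so the resolved target model is the one marked active there. A quick length check of the body literal taken from the test above:

package main

import "fmt"

func main() {
    // wantBody from the updated test case; swapping "my-model" for
    // "my-model-v1" lengthens the old 70-byte payload by three bytes.
    body := `{"max_tokens":100,"model":"my-model-v1","prompt":"hello","temperature":0}`
    fmt.Println(len(body)) // 73
}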

pkg/ext-proc/test/utils.go (+5 -4)

@@ -13,12 +13,13 @@ import (
 
     extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
 
+    "inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/handlers"
     "inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
 )
 
-func StartExtProc(port int, refreshPodsInterval, refreshMetricsInterval time.Duration, pods []*backend.PodMetrics) *grpc.Server {
+func StartExtProc(port int, refreshPodsInterval, refreshMetricsInterval time.Duration, pods []*backend.PodMetrics, models map[string]*v1alpha1.Model) *grpc.Server {
     ps := make(backend.PodSet)
     pms := make(map[backend.Pod]*backend.PodMetrics)
     for _, pod := range pods {
@@ -30,19 +31,19 @@ func StartExtProc(port int, refreshPodsInterval, refreshMetricsInterval time.Dur
     if err := pp.Init(refreshPodsInterval, refreshMetricsInterval); err != nil {
         klog.Fatalf("failed to initialize: %v", err)
     }
-    return startExtProc(port, pp)
+    return startExtProc(port, pp, models)
 }
 
 // startExtProc starts an extProc server with fake pods.
-func startExtProc(port int, pp *backend.Provider) *grpc.Server {
+func startExtProc(port int, pp *backend.Provider, models map[string]*v1alpha1.Model) *grpc.Server {
     lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
     if err != nil {
         klog.Fatalf("failed to listen: %v", err)
     }
 
     s := grpc.NewServer()
 
-    extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), "target-pod"))
+    extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), "target-pod", &backend.FakeDataStore{Res: models}))
 
     klog.Infof("Starting gRPC server on port :%v", port)
     reflection.Register(s)
