
Fix build and test #65


Merged (1 commit, Dec 4, 2024)

5 changes: 3 additions & 2 deletions go.mod
@@ -1,7 +1,8 @@
module inference.networking.x-k8s.io/llm-instance-gateway

-go 1.22.0
-toolchain go1.22.9
+go 1.22.7
+
+toolchain go1.23.2

require (
github.com/bojand/ghz v0.120.0
18 changes: 17 additions & 1 deletion pkg/ext-proc/backend/datastore.go
@@ -25,6 +25,22 @@ func (ds *K8sDatastore) GetPodIPs() []string {
return ips
}

func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
s.LLMServices.Range(func(k, v any) bool {
service := v.(*v1alpha1.LLMService)
klog.V(3).Infof("Service name: %v", service.Name)
for _, model := range service.Spec.Models {
if model.Name == modelName {
returnModel = &model
// We want to stop iterating, return false.
return false
}
}
return true
})
return
}

func RandomWeightedDraw(model *v1alpha1.Model, seed int64) string {
weights := 0

@@ -36,7 +52,7 @@ func RandomWeightedDraw(model *v1alpha1.Model, seed int64) string {
for _, model := range model.TargetModels {
weights += model.Weight
}
klog.Infof("Weights for Model(%v) total to: %v", model.Name, weights)
klog.V(3).Infof("Weights for Model(%v) total to: %v", model.Name, weights)
randomVal := r.Intn(weights)
for _, model := range model.TargetModels {
if randomVal < model.Weight {
14 changes: 12 additions & 2 deletions pkg/ext-proc/backend/fake.go
@@ -2,7 +2,9 @@ package backend

import (
"context"
"fmt"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
klog "k8s.io/klog/v2"
)

type FakePodMetricsClient struct {
@@ -14,6 +16,14 @@ func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existi
if err, ok := f.Err[pod]; ok {
return nil, err
}
fmt.Printf("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod])
klog.V(1).Infof("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod])
return f.Res[pod], nil
}

type FakeDataStore struct {
Res map[string]*v1alpha1.Model
}

func (fds *FakeDataStore) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
return fds.Res[modelName]
}
25 changes: 4 additions & 21 deletions pkg/ext-proc/handlers/request.go
@@ -9,7 +9,6 @@ import (
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
klog "k8s.io/klog/v2"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
)
@@ -40,14 +39,14 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
// This might be a security risk in the future where adapters not registered in the LLMService
// are able to be requested by using their distinct name.
-modelObj := s.FetchModelData(model)
+modelObj := s.datastore.FetchModelData(model)
if modelObj != nil && len(modelObj.TargetModels) > 0 {
-modelName = backend.RandomWeightedDraw(modelObj)
+modelName = backend.RandomWeightedDraw(modelObj, 0)
Review comment from the PR author: This is the bug, missing the seed value argument.
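To illustrate why the seed argument matters, here is a minimal, hypothetical sketch of a seeded weighted draw. weightedDraw and candidate are illustrative stand-ins rather than the repository's implementation, and the real backend.RandomWeightedDraw may treat the seed differently (for example, falling back to the clock when it is zero).

```go
package main

import (
	"fmt"
	"math/rand"
)

// candidate mirrors the shape of a target model entry: a name plus a weight.
type candidate struct {
	name   string
	weight int
}

// weightedDraw picks a name in proportion to its weight, using an RNG seeded
// from the caller-supplied value so that a fixed seed gives a reproducible pick.
func weightedDraw(candidates []candidate, seed int64) string {
	total := 0
	for _, c := range candidates {
		total += c.weight
	}
	r := rand.New(rand.NewSource(seed))
	n := r.Intn(total)
	for _, c := range candidates {
		if n < c.weight {
			return c.name
		}
		n -= c.weight
	}
	return ""
}

func main() {
	cs := []candidate{{name: "my-model-v1", weight: 100}}
	// A fixed seed always yields the same pick for the same candidate list.
	fmt.Println(weightedDraw(cs, 0))
}
```

In the hermetic test later in this diff, the single target model carries the full weight of 100, so the assertion on "my-model-v1" holds regardless of how the seed is interpreted.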

if modelName == "" {
return nil, fmt.Errorf("Error getting target model name for model %v", modelObj.Name)
return nil, fmt.Errorf("error getting target model name for model %v", modelObj.Name)
}
}
klog.Infof("Model is null %v", modelObj == nil)
klog.V(3).Infof("Model is null %v", modelObj == nil)
llmReq := &scheduling.LLMRequest{
Model: model,
ResolvedTargetModel: modelName,
@@ -118,22 +117,6 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
return resp, nil
}

func (s *Server) FetchModelData(modelName string) (returnModel *v1alpha1.Model) {
Review comment from the PR author: Moved to datastore for better encapsulation. Also makes it easier to fake the datastore in tests. (A simplified sketch of this pattern follows this file's diff.)

s.datastore.LLMServices.Range(func(k, v any) bool {
service := v.(*v1alpha1.LLMService)
klog.Infof("Service name: %v", service.Name)
for _, model := range service.Spec.Models {
if model.Name == modelName {
returnModel = &model
// We want to stop iterating, return false.
return false
}
}
return true
})
return
}

func HandleRequestHeaders(reqCtx *RequestContext, req *extProcPb.ProcessingRequest) *extProcPb.ProcessingResponse {
klog.V(3).Info("--- In RequestHeaders processing ...")
r := req.Request
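As referenced in the review comment above, here is a minimal sketch, under simplified stand-in types, of the encapsulation gained by moving the lookup into the datastore: the handler only needs a narrow interface, so tests can inject a map-backed fake instead of a Kubernetes-backed K8sDatastore. Model, ModelDataStore, and fakeDataStore below stand in for v1alpha1.Model, the handlers.ModelDataStore interface, and backend.FakeDataStore added elsewhere in this PR.

```go
package main

import "fmt"

// Model is a trimmed, illustrative stand-in for v1alpha1.Model.
type Model struct {
	Name string
}

// ModelDataStore is the narrow dependency the server needs: a lookup by name.
type ModelDataStore interface {
	FetchModelData(modelName string) *Model
}

// fakeDataStore satisfies the interface from a plain map, so tests can supply
// exactly the models they need without any Kubernetes client or informer.
type fakeDataStore struct {
	res map[string]*Model
}

func (f *fakeDataStore) FetchModelData(modelName string) *Model {
	return f.res[modelName]
}

func main() {
	var ds ModelDataStore = &fakeDataStore{
		res: map[string]*Model{"my-model": {Name: "my-model"}},
	}
	if m := ds.FetchModelData("my-model"); m != nil {
		fmt.Println("resolved:", m.Name)
	}
}
```

This is the same wiring the updated test/utils.go uses when it passes &backend.FakeDataStore{Res: models} into handlers.NewServer.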
9 changes: 7 additions & 2 deletions pkg/ext-proc/handlers/server.go
@@ -9,11 +9,12 @@ import (
"google.golang.org/grpc/status"
klog "k8s.io/klog/v2"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
)

-func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string, datastore *backend.K8sDatastore) *Server {
+func NewServer(pp PodProvider, scheduler Scheduler, targetPodHeader string, datastore ModelDataStore) *Server {
return &Server{
scheduler: scheduler,
podProvider: pp,
@@ -30,7 +31,7 @@ type Server struct {
// The key of the header to specify the target pod address. This value needs to match Envoy
// configuration.
targetPodHeader string
-datastore *backend.K8sDatastore
+datastore ModelDataStore
}

type Scheduler interface {
@@ -43,6 +44,10 @@ type PodProvider interface {
UpdatePodMetrics(pod backend.Pod, pm *backend.PodMetrics)
}

type ModelDataStore interface {
FetchModelData(modelName string) (returnModel *v1alpha1.Model)
}

func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
klog.V(3).Info("Processing")
ctx := srv.Context()
15 changes: 14 additions & 1 deletion pkg/ext-proc/test/benchmark/benchmark.go
@@ -12,6 +12,7 @@ import (
"google.golang.org/protobuf/proto"
klog "k8s.io/klog/v2"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/test"
)
@@ -36,7 +37,7 @@ func main() {
flag.Parse()

if *localServer {
-test.StartExtProc(port, *refreshPodsInterval, *refreshMetricsInterval, fakePods())
+test.StartExtProc(port, *refreshPodsInterval, *refreshMetricsInterval, fakePods(), fakeModels())
time.Sleep(time.Second) // wait until server is up
klog.Info("Server started")
}
@@ -70,6 +71,18 @@ func generateRequest(mtd *desc.MethodDescriptor, callData *runner.CallData) []by
return data
}

func fakeModels() map[string]*v1alpha1.Model {
models := map[string]*v1alpha1.Model{}
for i := range *numFakePods {
for j := range *numModelsPerPod {
m := modelName(i*(*numModelsPerPod) + j)
models[m] = &v1alpha1.Model{Name: m}
}
}

return models
}

func fakePods() []*backend.PodMetrics {
pms := make([]*backend.PodMetrics, 0, *numFakePods)
for i := 0; i < *numFakePods; i++ {
29 changes: 21 additions & 8 deletions pkg/ext-proc/test/hermetic_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"time"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -28,13 +29,25 @@ func TestHandleRequestBody(t *testing.T) {
name string
req *extProcPb.ProcessingRequest
pods []*backend.PodMetrics
models map[string]*v1alpha1.Model
wantHeaders []*configPb.HeaderValueOption
wantBody []byte
wantErr bool
}{
{
name: "success",
req: GenerateRequest("my-model"),
models: map[string]*v1alpha1.Model{
"my-model": {
Name: "my-model",
TargetModels: []v1alpha1.TargetModel{
{
Name: "my-model-v1",
Weight: 100,
},
},
},
},
// pod-1 will be picked because it has relatively low queue size, with the requested
// model being active, and has low KV cache.
pods: []*backend.PodMetrics{
@@ -52,11 +65,11 @@ {
{
Pod: FakePod(1),
Metrics: backend.Metrics{
-WaitingQueueSize: 3,
+WaitingQueueSize: 0,
KVCacheUsagePercent: 0.1,
ActiveModels: map[string]int{
"foo": 1,
"my-model": 1,
"foo": 1,
"my-model-v1": 1,
},
},
},
@@ -81,17 +94,17 @@ {
{
Header: &configPb.HeaderValue{
Key: "Content-Length",
RawValue: []byte("70"),
RawValue: []byte("73"),
},
},
},
-wantBody: []byte("{\"max_tokens\":100,\"model\":\"my-model\",\"prompt\":\"hello\",\"temperature\":0}"),
+wantBody: []byte("{\"max_tokens\":100,\"model\":\"my-model-v1\",\"prompt\":\"hello\",\"temperature\":0}"),
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
-client, cleanup := setUpServer(t, test.pods)
+client, cleanup := setUpServer(t, test.pods, test.models)
t.Cleanup(cleanup)
want := &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_RequestBody{
@@ -123,8 +136,8 @@

}

-func setUpServer(t *testing.T, pods []*backend.PodMetrics) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
-server := StartExtProc(port, time.Second, time.Second, pods)
+func setUpServer(t *testing.T, pods []*backend.PodMetrics, models map[string]*v1alpha1.Model) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
+server := StartExtProc(port, time.Second, time.Second, pods, models)

address := fmt.Sprintf("localhost:%v", port)
// Create a grpc connection
9 changes: 5 additions & 4 deletions pkg/ext-proc/test/utils.go
@@ -13,12 +13,13 @@ import (

extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"

"inference.networking.x-k8s.io/llm-instance-gateway/api/v1alpha1"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/backend"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/handlers"
"inference.networking.x-k8s.io/llm-instance-gateway/pkg/ext-proc/scheduling"
)

-func StartExtProc(port int, refreshPodsInterval, refreshMetricsInterval time.Duration, pods []*backend.PodMetrics) *grpc.Server {
+func StartExtProc(port int, refreshPodsInterval, refreshMetricsInterval time.Duration, pods []*backend.PodMetrics, models map[string]*v1alpha1.Model) *grpc.Server {
ps := make(backend.PodSet)
pms := make(map[backend.Pod]*backend.PodMetrics)
for _, pod := range pods {
@@ -30,19 +31,19 @@ if err := pp.Init(refreshPodsInterval, refreshMetricsInterval); err != nil {
if err := pp.Init(refreshPodsInterval, refreshMetricsInterval); err != nil {
klog.Fatalf("failed to initialize: %v", err)
}
-return startExtProc(port, pp)
+return startExtProc(port, pp, models)
}

// startExtProc starts an extProc server with fake pods.
-func startExtProc(port int, pp *backend.Provider) *grpc.Server {
+func startExtProc(port int, pp *backend.Provider, models map[string]*v1alpha1.Model) *grpc.Server {
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
klog.Fatalf("failed to listen: %v", err)
}

s := grpc.NewServer()

-extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), "target-pod"))
+extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), "target-pod", &backend.FakeDataStore{Res: models}))

klog.Infof("Starting gRPC server on port :%v", port)
reflection.Register(s)
Expand Down