From 87162b4ce585c67ef2085746a02bb13bd78fbca8 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Tue, 1 Apr 2025 14:47:26 -0700 Subject: [PATCH] Refactor scheduler --- pkg/epp/backend/metrics/metrics.go | 3 +- pkg/epp/backend/metrics/metrics_test.go | 11 +- pkg/epp/backend/metrics/pod_metrics_test.go | 2 + pkg/epp/backend/metrics/types.go | 26 +- pkg/epp/datastore/datastore_test.go | 3 + pkg/epp/handlers/request.go | 4 +- pkg/epp/handlers/server.go | 5 +- pkg/epp/handlers/streamingserver.go | 4 +- pkg/epp/scheduling/filter.go | 151 +++++---- pkg/epp/scheduling/filter_test.go | 326 +++----------------- pkg/epp/scheduling/scheduler.go | 121 ++++---- pkg/epp/scheduling/scheduler_test.go | 232 ++++++++++++++ pkg/epp/scheduling/types.go | 27 -- pkg/epp/scheduling/types/types.go | 88 ++++++ test/integration/epp/hermetic_test.go | 39 ++- 15 files changed, 592 insertions(+), 450 deletions(-) create mode 100644 pkg/epp/scheduling/scheduler_test.go delete mode 100644 pkg/epp/scheduling/types.go create mode 100644 pkg/epp/scheduling/types/types.go diff --git a/pkg/epp/backend/metrics/metrics.go b/pkg/epp/backend/metrics/metrics.go index d48b1dc5..96814b4b 100644 --- a/pkg/epp/backend/metrics/metrics.go +++ b/pkg/epp/backend/metrics/metrics.go @@ -109,6 +109,7 @@ func (p *PodMetricsClientImpl) promToPodMetrics( if loraMetrics != nil { updated.ActiveModels = make(map[string]int) + updated.WaitingModels = make(map[string]int) for _, label := range loraMetrics.GetLabel() { if label.GetName() == LoraInfoRunningAdaptersMetricName { if label.GetValue() != "" { @@ -122,7 +123,7 @@ func (p *PodMetricsClientImpl) promToPodMetrics( if label.GetValue() != "" { adapterList := strings.Split(label.GetValue(), ",") for _, adapter := range adapterList { - updated.ActiveModels[adapter] = 0 + updated.WaitingModels[adapter] = 0 } } } diff --git a/pkg/epp/backend/metrics/metrics_test.go b/pkg/epp/backend/metrics/metrics_test.go index d0396bf7..e3b45b94 100644 --- a/pkg/epp/backend/metrics/metrics_test.go +++ b/pkg/epp/backend/metrics/metrics_test.go @@ -404,7 +404,8 @@ func TestPromToPodMetrics(t *testing.T) { expectedMetrics: &Metrics{ WaitingQueueSize: 7, KVCacheUsagePercent: 0.8, - ActiveModels: map[string]int{"lora1": 0, "lora2": 0, "lora3": 0}, + ActiveModels: map[string]int{"lora1": 0, "lora2": 0}, + WaitingModels: map[string]int{"lora3": 0}, MaxActiveModels: 3, }, }, @@ -416,8 +417,8 @@ func TestPromToPodMetrics(t *testing.T) { KVCacheUtilization: &MetricSpec{MetricName: "vllm_usage"}, LoraRequestInfo: &MetricSpec{MetricName: "vllm:lora_requests_info"}, }, - existingMetrics: &Metrics{ActiveModels: map[string]int{}}, - expectedMetrics: &Metrics{ActiveModels: map[string]int{}}, + existingMetrics: &Metrics{ActiveModels: map[string]int{}, WaitingModels: map[string]int{}}, + expectedMetrics: &Metrics{ActiveModels: map[string]int{}, WaitingModels: map[string]int{}}, expectedErr: multierr.Combine(errors.New("metric family \"vllm_waiting\" not found"), errors.New("metric family \"vllm_usage\" not found"), errors.New("metric family \"vllm:lora_requests_info\" not found")), }, { @@ -439,7 +440,8 @@ func TestPromToPodMetrics(t *testing.T) { expectedMetrics: &Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.8, - ActiveModels: map[string]int{"lora1": 0, "lora2": 0, "lora3": 0}, + ActiveModels: map[string]int{"lora1": 0, "lora2": 0}, + WaitingModels: map[string]int{"lora3": 0}, MaxActiveModels: 3, }, expectedErr: errors.New("metric family \"vllm_waiting\" not found"), @@ -457,6 +459,7 @@ func TestPromToPodMetrics(t *testing.T) { 
existingMetrics: &Metrics{}, expectedMetrics: &Metrics{ ActiveModels: map[string]int{"lora1": 0}, + WaitingModels: map[string]int{}, MaxActiveModels: 0, // Should still default to 0. }, diff --git a/pkg/epp/backend/metrics/pod_metrics_test.go b/pkg/epp/backend/metrics/pod_metrics_test.go index cf6698ca..e79c1bf0 100644 --- a/pkg/epp/backend/metrics/pod_metrics_test.go +++ b/pkg/epp/backend/metrics/pod_metrics_test.go @@ -44,6 +44,7 @@ var ( "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, } updated = &Metrics{ WaitingQueueSize: 9999, @@ -53,6 +54,7 @@ var ( "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, } ) diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go index 17db23b4..925a0cc5 100644 --- a/pkg/epp/backend/metrics/types.go +++ b/pkg/epp/backend/metrics/types.go @@ -41,6 +41,7 @@ type PodMetricsFactory struct { } func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics { + pod := toInternalPod(in) pm := &podMetrics{ pmc: f.pmc, ds: ds, @@ -48,9 +49,9 @@ func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1. parentCtx: parentCtx, once: sync.Once{}, done: make(chan struct{}), - logger: log.FromContext(parentCtx), + logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName), } - pm.pod.Store(toInternalPod(in)) + pm.pod.Store(pod) pm.metrics.Store(newMetrics()) pm.startRefreshLoop() @@ -77,9 +78,20 @@ func (p *Pod) String() string { return fmt.Sprintf("%+v", *p) } +func (p *Pod) Clone() *Pod { + return &Pod{ + NamespacedName: types.NamespacedName{ + Name: p.NamespacedName.Name, + Namespace: p.NamespacedName.Namespace, + }, + Address: p.Address, + } +} + type Metrics struct { // ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU. - ActiveModels map[string]int + ActiveModels map[string]int + WaitingModels map[string]int // MaxActiveModels is the maximum number of models that can be loaded to GPU. MaxActiveModels int RunningQueueSize int @@ -93,7 +105,8 @@ type Metrics struct { func newMetrics() *Metrics { return &Metrics{ - ActiveModels: make(map[string]int), + ActiveModels: make(map[string]int), + WaitingModels: make(map[string]int), } } @@ -109,8 +122,13 @@ func (m *Metrics) Clone() *Metrics { for k, v := range m.ActiveModels { cm[k] = v } + wm := make(map[string]int, len(m.WaitingModels)) + for k, v := range m.WaitingModels { + wm[k] = v + } clone := &Metrics{ ActiveModels: cm, + WaitingModels: wm, MaxActiveModels: m.MaxActiveModels, RunningQueueSize: m.RunningQueueSize, WaitingQueueSize: m.WaitingQueueSize, diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 22bb0365..abbff429 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -236,6 +236,7 @@ var ( "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, } pod2 = &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ @@ -250,6 +251,7 @@ var ( "foo1": 1, "bar1": 1, }, + WaitingModels: map[string]int{}, } pod1NamespacedName = types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace} pod2NamespacedName = types.NamespacedName{Name: pod2.Name, Namespace: pod2.Namespace} @@ -305,6 +307,7 @@ func TestMetrics(t *testing.T) { // Failed to fetch pod2 metrics so it remains the default values. 
{ ActiveModels: map[string]int{}, + WaitingModels: map[string]int{}, WaitingQueueSize: 0, KVCacheUsagePercent: 0, MaxActiveModels: 0, diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index d7678fad..b786a15d 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -27,7 +27,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -74,7 +74,7 @@ func (s *Server) HandleRequestBody( return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)} } } - llmReq := &scheduling.LLMRequest{ + llmReq := &schedulingtypes.LLMRequest{ Model: model, ResolvedTargetModel: modelName, Critical: datastore.IsCritical(modelObj), diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index a92f091c..f6f375dd 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -26,10 +26,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -57,7 +56,7 @@ type Server struct { } type Scheduler interface { - Schedule(ctx context.Context, b *scheduling.LLMRequest) (targetPod backendmetrics.PodMetrics, err error) + Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (targetPod schedulingtypes.Pod, err error) } func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { diff --git a/pkg/epp/handlers/streamingserver.go b/pkg/epp/handlers/streamingserver.go index 874dd734..0e9020d8 100644 --- a/pkg/epp/handlers/streamingserver.go +++ b/pkg/epp/handlers/streamingserver.go @@ -37,7 +37,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -343,7 +343,7 @@ func (s *StreamingServer) HandleRequestBody( return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)} } } - llmReq := &scheduling.LLMRequest{ + llmReq := &schedulingtypes.LLMRequest{ Model: model, ResolvedTargetModel: modelName, Critical: datastore.IsCritical(modelObj), diff --git a/pkg/epp/scheduling/filter.go b/pkg/epp/scheduling/filter.go index 
f4848089..99044e97 100644
--- a/pkg/epp/scheduling/filter.go
+++ b/pkg/epp/scheduling/filter.go
@@ -22,48 +22,63 @@ import (
 "math/rand"
 "time"
- "github.com/go-logr/logr"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 type Filter interface {
 Name() string
- Filter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error)
+ Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error)
 }
-// filter applies current filterFunc, and then recursively applies next filters depending success or
-// failure of the current filterFunc.
-// It can be used to construct a flow chart algorithm.
-type filter struct {
+type basicFilter struct {
 name string
 filter filterFunc
+}
+
+func (bf *basicFilter) Name() string {
+ if bf == nil {
+ return "nil"
+ }
+ return bf.name
+}
+
+func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
+ loggerTrace := ctx.Logger.V(logutil.TRACE)
+ loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods))
+
+ return bf.filter(ctx, pods)
+}
+
+// decisionTreeFilter applies current filterFunc, and then recursively applies next filters
+// depending on success or failure of the current filter.
+// It can be used to construct a flow chart algorithm.
+type decisionTreeFilter struct {
+ current Filter
 // nextOnSuccess filter will be applied after successfully applying the current filter.
 // The filtered results will be passed to the next filter.
- nextOnSuccess *filter
+ nextOnSuccess Filter
 // nextOnFailure filter will be applied if current filter fails.
 // The original input will be passed to the next filter.
- nextOnFailure *filter
+ nextOnFailure Filter
 // nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
 // success or failure of the current filter.
 // NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
 // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
 // nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
- nextOnSuccessOrFailure *filter
+ nextOnSuccessOrFailure Filter
 }
-func (f *filter) Name() string {
+func (f *decisionTreeFilter) Name() string {
 if f == nil {
 return "nil"
 }
- return f.name
+ return f.current.Name()
 }
-func (f *filter) Filter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) {
- loggerTrace := logger.V(logutil.TRACE)
- loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods))
-
- filtered, err := f.filter(logger, req, pods)
+func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
+ loggerTrace := ctx.Logger.V(logutil.TRACE)
+ filtered, err := f.current.Filter(ctx, pods)
 next := f.nextOnSuccessOrFailure
 if err == nil && len(filtered) > 0 {
@@ -76,7 +91,7 @@ func (f *filter) Filter(logger logr.Logger, req *LLMRequest, pods []backendmetri
 }
 loggerTrace.Info("Filter succeeded", "filter", f.Name(), "next", next.Name(), "filteredPodCount", len(filtered))
 // On success, pass the filtered result to the next filter.
- return next.Filter(logger, req, filtered) + return next.Filter(ctx, filtered) } else { if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil { // No succeeding filters to run, return. @@ -87,19 +102,19 @@ func (f *filter) Filter(logger logr.Logger, req *LLMRequest, pods []backendmetri } loggerTrace.Info("Filter failed", "filter", f.Name(), "next", next.Name()) // On failure, pass the initial set of pods to the next filter. - return next.Filter(logger, req, pods) + return next.Filter(ctx, pods) } } // filterFunc filters a set of input pods to a subset. -type filterFunc func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) +type filterFunc func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. func toFilterFunc(pp podPredicate) filterFunc { - return func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { - filtered := []backendmetrics.PodMetrics{} + return func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { + filtered := []*types.PodMetrics{} for _, pod := range pods { - pass := pp(req, pod) + pass := pp(ctx.Req, pod) if pass { filtered = append(filtered, pod) } @@ -111,6 +126,11 @@ func toFilterFunc(pp podPredicate) filterFunc { } } +var leastQueueFilter = &basicFilter{ + name: "least queuing", + filter: leastQueuingFilterFunc, +} + // leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range // (max-min) by the number of pods, and finds the pods that fall into the first range. // The intuition is that if there are multiple pods that share similar queue size in the low range, @@ -118,30 +138,36 @@ func toFilterFunc(pp podPredicate) filterFunc { // the least one as it gives more choices for the next filter, which on aggregate gave better // results. // TODO: Compare this strategy with other strategies such as top K. 
-func leastQueuingFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { +func leastQueuingFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { min := math.MaxInt max := 0 - filtered := []backendmetrics.PodMetrics{} + filtered := []*types.PodMetrics{} for _, pod := range pods { - if pod.GetMetrics().WaitingQueueSize <= min { - min = pod.GetMetrics().WaitingQueueSize + if pod.WaitingQueueSize <= min { + min = pod.WaitingQueueSize } - if pod.GetMetrics().WaitingQueueSize >= max { - max = pod.GetMetrics().WaitingQueueSize + if pod.WaitingQueueSize >= max { + max = pod.WaitingQueueSize } } for _, pod := range pods { - if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) { + if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) { filtered = append(filtered, pod) } } return filtered, nil } -func lowQueueingPodPredicate(_ *LLMRequest, pod backendmetrics.PodMetrics) bool { - return pod.GetMetrics().WaitingQueueSize < config.QueueingThresholdLoRA +var lowQueueFilter = &basicFilter{ + name: "low queueing filter", + filter: toFilterFunc((queueThresholdPredicate(config.QueueingThresholdLoRA))), +} + +var leastKVCacheFilter = &basicFilter{ + name: "least KV cache percent", + filter: leastKVCacheFilterFunc, } // leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range @@ -150,39 +176,31 @@ func lowQueueingPodPredicate(_ *LLMRequest, pod backendmetrics.PodMetrics) bool // should consider them all instead of the absolute minimum one. This worked better than picking the // least one as it gives more choices for the next filter, which on aggregate gave better results. // TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { +func leastKVCacheFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { min := math.MaxFloat64 var max float64 = 0 - filtered := []backendmetrics.PodMetrics{} + filtered := []*types.PodMetrics{} for _, pod := range pods { - if pod.GetMetrics().KVCacheUsagePercent <= min { - min = pod.GetMetrics().KVCacheUsagePercent + if pod.KVCacheUsagePercent <= min { + min = pod.KVCacheUsagePercent } - if pod.GetMetrics().KVCacheUsagePercent >= max { - max = pod.GetMetrics().KVCacheUsagePercent + if pod.KVCacheUsagePercent >= max { + max = pod.KVCacheUsagePercent } } for _, pod := range pods { - if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { + if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { filtered = append(filtered, pod) } } return filtered, nil } -// podPredicate is a filter function to check whether a pod is desired. -type podPredicate func(req *LLMRequest, pod backendmetrics.PodMetrics) bool - -// We consider serving an adapter low cost it the adapter is active in the model server, or the -// model server has room to load the adapter. The lowLoRACostPredicate ensures weak affinity by -// spreading the load of a LoRA adapter across multiple pods, avoiding "pinning" all requests to -// a single pod. This gave good performance in our initial benchmarking results in the scenario -// where # of lora slots > # of lora adapters. 
-func lowLoRACostPredicate(req *LLMRequest, pod backendmetrics.PodMetrics) bool { - _, ok := pod.GetMetrics().ActiveModels[req.ResolvedTargetModel] - return ok || len(pod.GetMetrics().ActiveModels) < pod.GetMetrics().MaxActiveModels +var loRAAffinityFilter = &basicFilter{ + name: "affinity LoRA", + filter: loRASoftAffinityFilterFunc, } // loRASoftAffinityPredicate implements a pod selection strategy that prioritizes pods @@ -201,18 +219,20 @@ func lowLoRACostPredicate(req *LLMRequest, pod backendmetrics.PodMetrics) bool { // Returns: // - Filtered slice of pod metrics based on affinity and availability // - Error if any issues occur during filtering -func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { +func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { // Pre-allocate slices with estimated capacity - filtered_affinity := make([]backendmetrics.PodMetrics, 0, len(pods)) - filtered_available := make([]backendmetrics.PodMetrics, 0, len(pods)) + filtered_affinity := make([]*types.PodMetrics, 0, len(pods)) + filtered_available := make([]*types.PodMetrics, 0, len(pods)) // Categorize pods based on affinity and availability for _, pod := range pods { + _, active := pod.ActiveModels[ctx.Req.ResolvedTargetModel] + _, waiting := pod.WaitingModels[ctx.Req.ResolvedTargetModel] - if _, exists := pod.GetMetrics().ActiveModels[req.ResolvedTargetModel]; exists { + if active || waiting { filtered_affinity = append(filtered_affinity, pod) - } else if len(pod.GetMetrics().ActiveModels) < pod.GetMetrics().MaxActiveModels { + } else if len(pod.ActiveModels)+len(pod.WaitingModels) < pod.MaxActiveModels { filtered_available = append(filtered_available, pod) } } @@ -237,12 +257,23 @@ func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []backendm return filtered_available, nil } -func criticalRequestPredicate(req *LLMRequest, _ backendmetrics.PodMetrics) bool { - return req.Critical +// podPredicate is a filter function to check whether a pod is desired. +type podPredicate func(req *types.LLMRequest, pod *types.PodMetrics) bool + +func queueThresholdPredicate(queueThreshold int) podPredicate { + return func(req *types.LLMRequest, pod *types.PodMetrics) bool { + return pod.WaitingQueueSize <= queueThreshold + } +} + +func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate { + return func(req *types.LLMRequest, pod *types.PodMetrics) bool { + return pod.KVCacheUsagePercent <= kvCacheThreshold + } } -func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate { - return func(req *LLMRequest, pod backendmetrics.PodMetrics) bool { - return pod.GetMetrics().WaitingQueueSize <= queueThreshold && pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold +func (pp podPredicate) and(another podPredicate) podPredicate { + return func(req *types.LLMRequest, pod *types.PodMetrics) bool { + return pp(req, pod) && another(req, pod) } } diff --git a/pkg/epp/scheduling/filter_test.go b/pkg/epp/scheduling/filter_test.go index 127e6c21..543826d0 100644 --- a/pkg/epp/scheduling/filter_test.go +++ b/pkg/epp/scheduling/filter_test.go @@ -17,217 +17,48 @@ limitations under the License. 
package scheduling import ( + "context" "errors" "testing" - "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" - "k8s.io/apimachinery/pkg/types" + k8stypes "k8s.io/apimachinery/pkg/types" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) func TestFilter(t *testing.T) { - logger := logutil.NewTestLogger() - tests := []struct { name string - req *LLMRequest - input []*backendmetrics.FakePodMetrics - output []*backendmetrics.FakePodMetrics + req *types.LLMRequest + input []*types.PodMetrics + output []*types.PodMetrics err bool - filter *filter + filter *decisionTreeFilter }{ { name: "simple filter without successor, failure", - filter: &filter{filter: func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { - return nil, errors.New("filter error") - }}, - err: true, - }, - { - name: "default filter, critical request", - filter: defaultFilter, - req: &LLMRequest{ - Model: "critical", - ResolvedTargetModel: "critical", - Critical: true, - }, - // pod2 will be picked because it has relatively low queue size, with the requested - // model being active, and has low KV cache. - input: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - }, - }, - { - name: "default filter, sheddable request, accepted", - filter: defaultFilter, - req: &LLMRequest{ - Model: "sheddable", - ResolvedTargetModel: "sheddable", - Critical: false, - }, - // pod1 will be picked because it has capacity for the sheddable request. 
- input: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, + filter: &decisionTreeFilter{ + current: &basicFilter{ + name: "error", + filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { + return nil, errors.New("filter error") }, }, }, - }, - { - name: "default filter, sheddable request, dropped", - filter: defaultFilter, - req: &LLMRequest{ - Model: "sheddable", - ResolvedTargetModel: "sheddable", - Critical: false, - }, - // All pods have higher KV cache thant the threshold, so the sheddable request will be - // dropped. - input: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.9, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.85, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.85, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{}, - err: true, + err: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.filter.Filter(logger, test.req, toInterface(test.input)) + ctx := types.NewContext(context.Background(), test.req, test.input) + got, err := test.filter.Filter(ctx, test.input) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, toStruct(got)); diff != "" { + if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -235,26 +66,24 @@ func TestFilter(t *testing.T) { } func TestFilterFunc(t *testing.T) { - logger := logutil.NewTestLogger() - tests := []struct { name string f filterFunc - req *LLMRequest - input []*backendmetrics.FakePodMetrics - output []*backendmetrics.FakePodMetrics + req *types.LLMRequest + input []*types.PodMetrics + output []*types.PodMetrics err bool }{ { name: "least queuing empty input", f: leastQueuingFilterFunc, - input: 
[]*backendmetrics.FakePodMetrics{}, - output: []*backendmetrics.FakePodMetrics{}, + input: []*types.PodMetrics{}, + output: []*types.PodMetrics{}, }, { name: "least queuing", f: leastQueuingFilterFunc, - input: []*backendmetrics.FakePodMetrics{ + input: []*types.PodMetrics{ { Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, @@ -271,7 +100,7 @@ func TestFilterFunc(t *testing.T) { }, }, }, - output: []*backendmetrics.FakePodMetrics{ + output: []*types.PodMetrics{ { Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, @@ -287,13 +116,13 @@ func TestFilterFunc(t *testing.T) { { name: "least kv cache empty input", f: leastKVCacheFilterFunc, - input: []*backendmetrics.FakePodMetrics{}, - output: []*backendmetrics.FakePodMetrics{}, + input: []*types.PodMetrics{}, + output: []*types.PodMetrics{}, }, { name: "least kv cache", f: leastKVCacheFilterFunc, - input: []*backendmetrics.FakePodMetrics{ + input: []*types.PodMetrics{ { Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, @@ -310,7 +139,7 @@ func TestFilterFunc(t *testing.T) { }, }, }, - output: []*backendmetrics.FakePodMetrics{ + output: []*types.PodMetrics{ { Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, @@ -324,9 +153,9 @@ func TestFilterFunc(t *testing.T) { }, }, { - name: "noQueueAndLessThanKVCacheThresholdPredicate", - f: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)), - input: []*backendmetrics.FakePodMetrics{ + name: "lowQueueAndLessThanKVCacheThresholdPredicate", + f: toFilterFunc(queueThresholdPredicate(0).and(kvCacheThresholdPredicate(0.8))), + input: []*types.PodMetrics{ { // This pod should be returned. Metrics: &backendmetrics.Metrics{ @@ -349,7 +178,7 @@ func TestFilterFunc(t *testing.T) { }, }, }, - output: []*backendmetrics.FakePodMetrics{ + output: []*types.PodMetrics{ { Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, @@ -358,72 +187,17 @@ func TestFilterFunc(t *testing.T) { }, }, }, - { - name: "low LoRA cost", - f: toFilterFunc(lowLoRACostPredicate), - req: &LLMRequest{ - Model: "model", - ResolvedTargetModel: "model", - }, - input: []*backendmetrics.FakePodMetrics{ - // ActiveModels include input model, should be returned. - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "model": 1, - }, - }, - }, - // Input model is not active, however the server has room to load another adapter. - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "another-model": 1, - }, - }, - }, - // Input is not active, and the server has reached max active models. 
- { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "model": 1, - }, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "another-model": 1, - }, - }, - }, - }, - }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.f(logger, test.req, toInterface(test.input)) + ctx := types.NewContext(context.Background(), test.req, test.input) + got, err := test.f(ctx, test.input) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, toStruct(got)); diff != "" { + if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -433,8 +207,6 @@ func TestFilterFunc(t *testing.T) { // TestLoRASoftAffinityDistribution tests that the loRASoftAffinityFilter function // properly distributes requests according to the loraAffinityThreshold func TestLoRASoftAffinityDistribution(t *testing.T) { - logger := logutil.NewTestLogger() - const ( testModelName = "test-model" testAffinityModel = "test-affinity-model" @@ -455,15 +227,15 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }() // Create a test request and pods - req := &LLMRequest{ + req := &types.LLMRequest{ Model: testAffinityModel, ResolvedTargetModel: testAffinityModel, } // Test setup: One affinity pod and one available pod - pods := []*backendmetrics.FakePodMetrics{ + pods := []*types.PodMetrics{ { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "affinity-pod"}}, + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "affinity-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{ @@ -472,13 +244,14 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "available-pod"}}, + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "available-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{}, }, }, } + ctx := types.NewContext(context.Background(), req, pods) // Run the filter function multiple times and count the results affinityCount := 0 @@ -489,7 +262,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { expectedAvailabilityPercent := 100 - expectedAffinityPercent for i := 0; i < numIterations; i++ { - result, err := loRASoftAffinityFilter(logger, req, toInterface(pods)) + result, err := loRASoftAffinityFilterFunc(ctx, pods) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -533,22 +306,3 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { actualAvailablePercent, availableLowerBound, availableUpperBound) } } - -func toInterface(input []*backendmetrics.FakePodMetrics) []backendmetrics.PodMetrics { - output := []backendmetrics.PodMetrics{} - for _, i := range input { - output = append(output, i) - } - return output -} - -func toStruct(input []backendmetrics.PodMetrics) []*backendmetrics.FakePodMetrics { - if input == nil { - return nil - } - output := []*backendmetrics.FakePodMetrics{} - for _, i := range input { - output = append(output, i.(*backendmetrics.FakePodMetrics)) - } - return output -} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 
e874724d..8679ffba 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -22,10 +22,9 @@ import ( "fmt" "math/rand" - "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -67,89 +66,91 @@ func LoadConfig() Config { var config = LoadConfig() var ( - defaultFilter = &filter{ - name: "critical request", - filter: toFilterFunc(criticalRequestPredicate), - nextOnSuccess: lowLatencyFilter, - nextOnFailure: sheddableRequestFilter, - } - - // queueLoRAAndKVCacheFilter applied least queue -> low cost lora -> least KV Cache filter - queueLoRAAndKVCacheFilter = &filter{ - name: "least queuing", - filter: leastQueuingFilterFunc, - nextOnSuccessOrFailure: &filter{ - name: "low cost LoRA", - filter: loRASoftAffinityFilter, - nextOnSuccessOrFailure: &filter{ - name: "least KV cache percent", - filter: leastKVCacheFilterFunc, + lowLatencyFilter = &decisionTreeFilter{ + current: lowQueueFilter, + nextOnSuccess: &decisionTreeFilter{ + current: loRAAffinityFilter, + nextOnSuccessOrFailure: &decisionTreeFilter{ + current: leastQueueFilter, + nextOnSuccessOrFailure: &decisionTreeFilter{ + current: leastKVCacheFilter, + }, }, }, - } - - // queueAndKVCacheFilter applies least queue followed by least KV Cache filter - queueAndKVCacheFilter = &filter{ - name: "least queuing", - filter: leastQueuingFilterFunc, - nextOnSuccessOrFailure: &filter{ - name: "least KV cache percent", - filter: leastKVCacheFilterFunc, - }, - } - - lowLatencyFilter = &filter{ - name: "low queueing filter", - filter: toFilterFunc((lowQueueingPodPredicate)), - nextOnSuccess: &filter{ - name: "affinity LoRA", - filter: loRASoftAffinityFilter, - nextOnSuccessOrFailure: queueAndKVCacheFilter, + nextOnFailure: &decisionTreeFilter{ + current: leastQueueFilter, + nextOnSuccessOrFailure: &decisionTreeFilter{ + current: loRAAffinityFilter, + nextOnSuccessOrFailure: &decisionTreeFilter{ + current: leastKVCacheFilter, + }, + }, }, - nextOnFailure: queueLoRAAndKVCacheFilter, } - sheddableRequestFilter = &filter{ + sheddableRequestFilter = &decisionTreeFilter{ // When there is at least one model server that's not queuing requests, and still has KV // cache below a certain threshold, we consider this model server has capacity to handle // a sheddable request without impacting critical requests. - name: "has capacity for sheddable requests", - filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(config.QueueThresholdCritical, config.KVCacheThreshold)), - nextOnSuccess: queueLoRAAndKVCacheFilter, + current: hasCapacityFilter, + nextOnSuccess: lowLatencyFilter, // If all pods are queuing or running above the KVCache threshold, we drop the sheddable // request to make room for critical requests. 
- nextOnFailure: &filter{ - name: "drop request", - filter: func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { - logger.V(logutil.DEFAULT).Info("Request dropped", "request", req) - return []backendmetrics.PodMetrics{}, errutil.Error{ - Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", - } - }, + nextOnFailure: dropRequestFilter, + } + + hasCapacityFilter = &basicFilter{ + name: "has capacity for sheddable requests", + filter: toFilterFunc(queueThresholdPredicate(config.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.KVCacheThreshold))), + } + + dropRequestFilter = &basicFilter{ + name: "drop request", + filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { + ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req) + return []*types.PodMetrics{}, errutil.Error{ + Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", + } }, } ) -func NewScheduler(datastore datastore.Datastore) *Scheduler { +func NewScheduler(datastore Datastore) *Scheduler { return &Scheduler{ - datastore: datastore, - filter: defaultFilter, + datastore: datastore, + criticalRequestFilter: lowLatencyFilter, + sheddableRequestFilter: sheddableRequestFilter, } } type Scheduler struct { - datastore datastore.Datastore - filter Filter + datastore Datastore + criticalRequestFilter Filter + sheddableRequestFilter Filter +} + +type Datastore interface { + PodGetAll() []backendmetrics.PodMetrics } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backendmetrics.PodMetrics, err error) { +func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (targetPod types.Pod, err error) { logger := log.FromContext(ctx).WithValues("request", req) - podMetrics := s.datastore.PodGetAll() - logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", podMetrics)) + // Snapshot pod metrics from the datastore to: + // 1. Reduce concurrent access to the datastore. + // 2. Ensure consistent data during the scheduling operation of a request. + sCtx := types.NewContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) + logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot)) + + var filter Filter + if req.Critical { + filter = s.criticalRequestFilter + } else { + filter = s.sheddableRequestFilter + } - pods, err := s.filter.Filter(logger, req, podMetrics) + pods, err := filter.Filter(sCtx, sCtx.PodsSnapshot) if err != nil || len(pods) == 0 { return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go new file mode 100644 index 00000000..3fd3fb24 --- /dev/null +++ b/pkg/epp/scheduling/scheduler_test.go @@ -0,0 +1,232 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduling + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + k8stypes "k8s.io/apimachinery/pkg/types" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +func TestSchedule(t *testing.T) { + tests := []struct { + name string + req *types.LLMRequest + input []*backendmetrics.FakePodMetrics + output types.Pod + err bool + }{ + { + name: "critical request", + req: &types.LLMRequest{ + Model: "critical", + ResolvedTargetModel: "critical", + Critical: true, + }, + // pod2 will be picked because it has relatively low queue size, with the requested + // model being active, and has low KV cache. + input: []*backendmetrics.FakePodMetrics{ + { + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + }, + }, + { + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.1, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "critical": 1, + }, + }, + }, + { + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + }, + }, + }, + }, + output: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.1, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "critical": 1, + }, + WaitingModels: map[string]int{}, + }, + }, + }, + { + name: "sheddable request, accepted", + req: &types.LLMRequest{ + Model: "sheddable", + ResolvedTargetModel: "sheddable", + Critical: false, + }, + // pod1 will be picked because it has capacity for the sheddable request. 
+ input: []*backendmetrics.FakePodMetrics{
+ {
+ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
+ Metrics: &backendmetrics.Metrics{
+ WaitingQueueSize: 0,
+ KVCacheUsagePercent: 0.2,
+ MaxActiveModels: 2,
+ ActiveModels: map[string]int{
+ "foo": 1,
+ "bar": 1,
+ },
+ },
+ },
+ {
+ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
+ Metrics: &backendmetrics.Metrics{
+ WaitingQueueSize: 3,
+ KVCacheUsagePercent: 0.1,
+ MaxActiveModels: 2,
+ ActiveModels: map[string]int{
+ "foo": 1,
+ "critical": 1,
+ },
+ },
+ },
+ {
+ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}},
+ Metrics: &backendmetrics.Metrics{
+ WaitingQueueSize: 10,
+ KVCacheUsagePercent: 0.2,
+ MaxActiveModels: 2,
+ ActiveModels: map[string]int{
+ "foo": 1,
+ },
+ },
+ },
+ },
+ output: &types.PodMetrics{
+ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
+ Metrics: &backendmetrics.Metrics{
+ WaitingQueueSize: 0,
+ KVCacheUsagePercent: 0.2,
+ MaxActiveModels: 2,
+ ActiveModels: map[string]int{
+ "foo": 1,
+ "bar": 1,
+ },
+ WaitingModels: map[string]int{},
+ },
+ },
+ },
+ {
+ name: "sheddable request, dropped",
+ req: &types.LLMRequest{
+ Model: "sheddable",
+ ResolvedTargetModel: "sheddable",
+ Critical: false,
+ },
+ // All pods have higher KV cache than the threshold, so the sheddable request will be
+ // dropped.
+ input: []*backendmetrics.FakePodMetrics{
+ {
+ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
+ Metrics: &backendmetrics.Metrics{
+ WaitingQueueSize: 10,
+ KVCacheUsagePercent: 0.9,
+ MaxActiveModels: 2,
+ ActiveModels: map[string]int{
+ "foo": 1,
+ "bar": 1,
+ },
+ },
+ },
+ {
+ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
+ Metrics: &backendmetrics.Metrics{
+ WaitingQueueSize: 3,
+ KVCacheUsagePercent: 0.85,
+ MaxActiveModels: 2,
+ ActiveModels: map[string]int{
+ "foo": 1,
+ "critical": 1,
+ },
+ },
+ },
+ {
+ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}},
+ Metrics: &backendmetrics.Metrics{
+ WaitingQueueSize: 10,
+ KVCacheUsagePercent: 0.85,
+ MaxActiveModels: 2,
+ ActiveModels: map[string]int{
+ "foo": 1,
+ },
+ },
+ },
+ },
+ output: nil,
+ err: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ scheduler := NewScheduler(&fakeDataStore{pods: test.input})
+ got, err := scheduler.Schedule(context.Background(), test.req)
+ if test.err != (err != nil) {
+ t.Errorf("Unexpected error, got %v, want %v", err, test.err)
+ }
+
+ if diff := cmp.Diff(test.output, got); diff != "" {
+ t.Errorf("Unexpected output (-want +got): %v", diff)
+ }
+ })
+ }
+}
+
+type fakeDataStore struct {
+ pods []*backendmetrics.FakePodMetrics
+}
+
+func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics {
+ pm := make([]backendmetrics.PodMetrics, 0, len(fds.pods))
+ for _, pod := range fds.pods {
+ pm = append(pm, pod)
+ }
+ return pm
+}
diff --git a/pkg/epp/scheduling/types.go b/pkg/epp/scheduling/types.go
deleted file mode 100644
index 29e6648d..00000000
--- a/pkg/epp/scheduling/types.go
+++ /dev/null
@@ -1,27 +0,0 @@
-/*
-Copyright 2025 The Kubernetes Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package scheduling - -// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. -type LLMRequest struct { - Model string - // Target models is a map of target model name to weight. - TargetModels map[string]int - // Resolved target model is the final target model after traffic split. - ResolvedTargetModel string - Critical bool -} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go new file mode 100644 index 00000000..9450652e --- /dev/null +++ b/pkg/epp/scheduling/types/types.go @@ -0,0 +1,88 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "context" + "fmt" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" +) + +// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. +type LLMRequest struct { + Model string + // Target models is a map of target model name to weight. + TargetModels map[string]int + // Resolved target model is the final target model after traffic split. + ResolvedTargetModel string + Critical bool +} + +// Context holds contextual information during a scheduling operation. 
+type Context struct { + context.Context + Logger logr.Logger + Req *LLMRequest + PodsSnapshot []*PodMetrics +} + +type Pod interface { + GetPod() *backendmetrics.Pod + GetMetrics() *backendmetrics.Metrics + String() string +} + +func (pm *PodMetrics) String() string { + if pm == nil { + return "" + } + return fmt.Sprintf("%+v", *pm) +} + +func (pm *PodMetrics) GetPod() *backendmetrics.Pod { + return pm.Pod +} + +func (pm *PodMetrics) GetMetrics() *backendmetrics.Metrics { + return pm.Metrics +} + +type PodMetrics struct { + *backendmetrics.Pod + *backendmetrics.Metrics +} + +func NewContext(ctx context.Context, req *LLMRequest, pods []*PodMetrics) *Context { + logger := log.FromContext(ctx).WithValues("request", req) + return &Context{ + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, + } +} + +func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []*PodMetrics { + pm := make([]*PodMetrics, 0, len(pods)) + for _, pod := range pods { + pm = append(pm, &PodMetrics{pod.GetPod().Clone(), pod.GetMetrics().Clone()}) + } + return pm +} diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 0ba0e14a..46bc7353 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -73,7 +73,7 @@ import ( const ( port = runserver.DefaultGrpcPort - metricsPort = 8888 + metricsPort = 8889 ) var ( @@ -157,6 +157,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -165,6 +166,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg2": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -173,6 +175,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, }, wantHeaders: []*configPb.HeaderValueOption{ @@ -212,6 +215,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 200, @@ -220,6 +224,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg2": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 6, @@ -227,6 +232,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { ActiveModels: map[string]int{ "foo": 1, }, + WaitingModels: map[string]int{}, }, }, wantHeaders: []*configPb.HeaderValueOption{ @@ -266,6 +272,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -274,6 +281,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -282,6 +290,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantHeaders: []*configPb.HeaderValueOption{}, @@ -308,6 +317,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -316,6 +326,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -324,6 +335,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantHeaders: 
[]*configPb.HeaderValueOption{ @@ -496,6 +508,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -504,6 +517,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg2": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -512,6 +526,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -578,6 +593,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 200, @@ -586,6 +602,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg2": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 6, @@ -593,6 +610,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { ActiveModels: map[string]int{ "foo": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -659,6 +677,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -667,6 +686,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -675,6 +695,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantErr: false, @@ -704,6 +725,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -712,6 +734,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -720,6 +743,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -812,6 +836,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -820,6 +845,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -828,6 +854,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -920,6 +947,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -928,6 +956,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -936,6 +965,7 @@ func 
TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -1029,6 +1059,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -1037,6 +1068,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -1045,6 +1077,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantErr: false, @@ -1125,6 +1158,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -1133,6 +1167,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -1141,6 +1176,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantErr: false, @@ -1470,6 +1506,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_pool_ready_pods`: `