diff --git a/cmd/epp/main.go b/cmd/epp/main.go
index 9c023f26d..e1bb094ee 100644
--- a/cmd/epp/main.go
+++ b/cmd/epp/main.go
@@ -43,9 +43,9 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/multi/prefix"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
@@ -157,13 +157,19 @@ func run() error {
 	})
 	setupLog.Info("Flags processed", "flags", flags)
 
-	// Init runtime.
+	// --- Load Configurations from Environment Variables ---
+	// Note: Scheduler config is loaded via its package init currently. We may
+	// want to load it here explicitly in the future.
+	sdConfig := saturationdetector.LoadConfigFromEnv()
+
+	// --- Get Kubernetes Config ---
 	cfg, err := ctrl.GetConfig()
 	if err != nil {
-		setupLog.Error(err, "Failed to get rest config")
+		setupLog.Error(err, "Failed to get Kubernetes rest config")
 		return err
 	}
 
+	// --- Setup Manager ---
 	poolNamespacedName := types.NamespacedName{
 		Name:      *poolName,
 		Namespace: *poolNamespace,
 	}
@@ -174,7 +180,7 @@ func run() error {
 		return err
 	}
 
-	// Set up mapper for metric scraping.
+	// --- Setup Datastore ---
 	mapping, err := backendmetrics.NewMetricMapping(
 		*totalQueuedRequestsMetric,
 		*kvCacheUsagePercentageMetric,
@@ -185,14 +191,12 @@ func run() error {
 		return err
 	}
 	verifyMetricMapping(*mapping, setupLog)
-
 	pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{MetricMapping: mapping}, *refreshMetricsInterval)
 
-	// Setup runner.
 	ctx := ctrl.SetupSignalHandler()
+	appDatastore := datastore.NewDatastore(ctx, pmf)
 
-	datastore := datastore.NewDatastore(ctx, pmf)
-
-	scheduler := scheduling.NewScheduler(datastore)
+	// --- Initialize EPP Components ---
+	appScheduler := scheduling.NewScheduler(appDatastore)
 	if schedulerV2 == "true" {
 		queueScorerWeight := envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog)
 		kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog)
@@ -207,47 +211,62 @@ func run() error {
 		}
 		schedulerConfig := scheduling.NewSchedulerConfig(
 			[]plugins.PreSchedule{},
-			[]plugins.Filter{filter.NewSheddableCapacityFilter()},
+			[]plugins.Filter{},
 			scorers,
 			picker.NewMaxScorePicker(),
 			[]plugins.PostSchedule{},
 			[]plugins.PostResponse{},
 			schedConfigOpts...)
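+		// The V2 config drops the sheddable-capacity filter; shedding of
+		// non-critical requests is handled instead by the saturation check in
+		// the request-control layer (see Director.PreDispatch below).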
-		scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig)
+		appScheduler = scheduling.NewSchedulerWithConfig(appDatastore, schedulerConfig)
+	}
+
+	appSaturationDetector, err := saturationdetector.NewDetector(
+		*sdConfig,
+		appDatastore,
+		ctrl.Log.WithName("saturation-detector"),
+	)
+	if err != nil {
+		setupLog.Error(err, "Failed to create SaturationDetector")
+		return err
 	}
+
+	// --- Setup ExtProc Server Runner ---
 	serverRunner := &runserver.ExtProcServerRunner{
 		GrpcPort:                                 *grpcPort,
 		DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace,
 		DestinationEndpointHintKey:               *destinationEndpointHintKey,
 		PoolNamespacedName:                       poolNamespacedName,
-		Datastore:                                datastore,
+		Datastore:                                appDatastore,
 		SecureServing:                            *secureServing,
 		CertPath:                                 *certPath,
 		RefreshPrometheusMetricsInterval:         *refreshPrometheusMetricsInterval,
-		Scheduler:                                scheduler,
+		Scheduler:                                appScheduler,
+		SaturationDetector:                       appSaturationDetector,
 	}
 	if err := serverRunner.SetupWithManager(ctx, mgr); err != nil {
-		setupLog.Error(err, "Failed to setup ext-proc controllers")
+		setupLog.Error(err, "Failed to setup EPP controllers")
 		return err
 	}
 
+	// --- Add Runnables to Manager ---
+
 	// Register health server.
-	if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), datastore, *grpcHealthPort); err != nil {
+	if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), appDatastore, *grpcHealthPort); err != nil {
 		return err
 	}
 
 	// Register ext-proc server.
-	if err := mgr.Add(serverRunner.AsRunnable(ctrl.Log.WithName("ext-proc"))); err != nil {
-		setupLog.Error(err, "Failed to register ext-proc gRPC server")
+	if err := registerExtProcServer(mgr, serverRunner, ctrl.Log.WithName("ext-proc")); err != nil {
 		return err
 	}
 
 	// Register metrics handler.
-	if err := registerMetricsHandler(mgr, *metricsPort, cfg, datastore); err != nil {
+	if err := registerMetricsHandler(mgr, *metricsPort, cfg, appDatastore); err != nil {
 		return err
 	}
 
-	// Start the manager. This blocks until a signal is received.
+	// --- Start Manager ---
+	// This blocks until a signal is received.
 	setupLog.Info("Controller manager starting")
 	if err := mgr.Start(ctx); err != nil {
 		setupLog.Error(err, "Error starting controller manager")
@@ -275,6 +294,17 @@ func initLogging(opts *zap.Options) {
 	ctrl.SetLogger(logger)
 }
 
+// registerExtProcServer adds the ExtProcServerRunner as a Runnable to the
+// manager.
+func registerExtProcServer(mgr manager.Manager, runner *runserver.ExtProcServerRunner, logger logr.Logger) error {
+	if err := mgr.Add(runner.AsRunnable(logger)); err != nil {
+		setupLog.Error(err, "Failed to register ext-proc gRPC server runnable")
+		return err
+	}
+	setupLog.Info("ExtProc server runner added to manager.")
+	return nil
+}
+
 // registerHealthServer adds the Health gRPC server as a Runnable to the given manager.
 func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds datastore.Datastore, port int) error {
 	srv := grpc.NewServer()
@@ -364,5 +394,4 @@ func verifyMetricMapping(mapping backendmetrics.MetricMapping, logger logr.Logge
 	if mapping.LoraRequestInfo == nil {
 		logger.Info("Not scraping metric: LoraRequestInfo")
 	}
-
 }
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index 85c8ee34f..7b1363bff 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
+// Package requestcontrol defines the Director component responsible for
+// orchestrating request processing after initial parsing.
 package requestcontrol
 
 import (
@@ -34,106 +36,203 @@ import (
 	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 )
 
+// Scheduler defines the interface required by the Director for scheduling.
 type Scheduler interface {
 	Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
 	OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string)
 }
 
+// SaturationDetector provides a signal indicating whether the backends are
+// considered saturated.
+type SaturationDetector interface {
+	IsSaturated(ctx context.Context) bool
+}
+
+// Director orchestrates the request handling flow, including scheduling.
 type Director struct {
-	datastore datastore.Datastore
-	scheduler Scheduler
+	datastore          datastore.Datastore
+	scheduler          Scheduler
+	saturationDetector SaturationDetector
 }
 
-func NewDirector(datastore datastore.Datastore, scheduler Scheduler) *Director {
+// NewDirector creates a new Director instance with all dependencies.
+func NewDirector(ds datastore.Datastore, sched Scheduler, sd SaturationDetector) *Director {
 	return &Director{
-		datastore: datastore,
-		scheduler: scheduler,
+		datastore:          ds,
+		scheduler:          sched,
+		saturationDetector: sd,
 	}
 }
 
-// HandleRequest always returns the requestContext even in the error case, as the request context is used in error handling.
+// HandleRequest orchestrates the request lifecycle:
+//  1. Parses request details.
+//  2. Calls PreDispatch for admission control.
+//  3. Calls Dispatch (which calls Scheduler) if request is approved.
+//  4. Calls PostDispatch to populate RequestContext with results.
+//
+// It always returns the requestContext even in the error case, as the request
+// context is used in error handling.
 func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
 	logger := log.FromContext(ctx)
 
-	// Resolve target models.
+	// --- 1. Parse Request, Resolve Target Models, and Determine Parameters ---
 	var ok bool
 	requestBodyMap := reqCtx.Request.Body
 	reqCtx.Model, ok = requestBodyMap["model"].(string)
 	if !ok {
-		return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
+		return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"}
 	}
 	prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
 	if err != nil {
 		return reqCtx, err
 	}
 
-	// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
-	// This might be a security risk in the future where adapters not registered in the InferenceModel
-	// are able to be requested by using their distinct name.
+	// NOTE: The nil check on the modelObject means that we DO allow
+	// passthrough currently.
+	// This might become a security risk in the future, as adapters not
+	// registered in the InferenceModel can be requested using their distinct
+	// names.
 	modelObj := d.datastore.ModelGet(reqCtx.Model)
 	if modelObj == nil {
-		return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", reqCtx.Model)}
+		logger.V(logutil.DEFAULT).Info("InferenceModel not found in datastore",
+			"model", reqCtx.Model)
+		return reqCtx, errutil.Error{
+			Code: errutil.BadConfiguration,
+			Msg:  fmt.Sprintf("InferenceModel %s not found", reqCtx.Model),
+		}
 	}
 
 	reqCtx.ResolvedTargetModel = reqCtx.Model
 	if len(modelObj.Spec.TargetModels) > 0 {
 		reqCtx.ResolvedTargetModel = RandomWeightedDraw(logger, modelObj, 0)
 		if reqCtx.ResolvedTargetModel == "" {
-			return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
+			logger.Error(nil, "Failed to get a resolved target model from TargetModels spec",
+				"model", reqCtx.Model)
+			return reqCtx, errutil.Error{
+				Code: errutil.BadConfiguration,
+				Msg:  "error resolving target model for " + reqCtx.Model,
+			}
 		}
 		reqCtx.Request.Body["model"] = reqCtx.ResolvedTargetModel // Update target model in the body.
 	}
 
+	requestCriticality := v1alpha2.Standard
+	if modelObj.Spec.Criticality != nil {
+		requestCriticality = *modelObj.Spec.Criticality
+	}
+
+	// Prepare LLMRequest (needed for both saturation detection and Scheduler)
 	llmReq := &schedulingtypes.LLMRequest{
 		TargetModel: reqCtx.ResolvedTargetModel,
 		RequestId:   reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
-		Critical:    modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
+		Critical:    requestCriticality == v1alpha2.Critical,
 		Prompt:      prompt,
 		Headers:     reqCtx.Request.Headers,
 	}
-	logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq)
 
-	results, err := d.Dispatch(ctx, llmReq)
-	if err != nil {
-		return reqCtx, err
+	logger = logger.WithValues("model", reqCtx.Model,
+		"resolvedTargetModel", llmReq.TargetModel,
+		"criticality", requestCriticality,
+		"isCriticalFlag", llmReq.Critical)
+	ctx = log.IntoContext(ctx, logger)
+	logger.V(logutil.DEBUG).Info("LLM request assembled")
+
+	// --- 2. Saturation Check ---
+	logger.V(logutil.DEBUG).Info("Calling PreDispatch")
+	preDispatchErr := d.PreDispatch(ctx, reqCtx, requestCriticality)
+	if preDispatchErr != nil {
+		logger.Error(preDispatchErr, "PreDispatch failed")
+		return reqCtx, preDispatchErr
 	}
 
-	// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
-	// Attach the port number
-	reqCtx, err = d.PostDispatch(ctx, reqCtx, results)
-	if err != nil {
-		return reqCtx, err
+	// --- 3. Dispatch (Calls Scheduler) ---
+	logger.V(logutil.DEBUG).Info("Calling Dispatch")
+	results, dispatchErr := d.Dispatch(ctx, llmReq)
+	if dispatchErr != nil {
+		logger.Error(dispatchErr, "Dispatch failed")
+		return reqCtx, dispatchErr
+	}
+
+	// --- 4. PostDispatch (Populates RequestContext) ---
+	// Insert target endpoint to instruct Envoy to route requests to the
+	// specified target pod.
+	// Attach the port number.
+	logger.V(logutil.DEBUG).Info("Calling PostDispatch")
+	reqCtx, postDispatchErr := d.PostDispatch(ctx, reqCtx, results)
+	if postDispatchErr != nil {
+		logger.Error(postDispatchErr, "PostDispatch failed")
+		return reqCtx, postDispatchErr
 	}
 
 	return reqCtx, nil
 }
 
+// PreDispatch handles admission control before dispatch.
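+// It fails open: if no SaturationDetector is configured, the request is
+// admitted, and critical requests always bypass the saturation check;
+// non-critical requests are rejected with InferencePoolResourceExhausted
+// when the detector reports saturation.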
+func (d *Director) PreDispatch(ctx context.Context, reqCtx *handlers.RequestContext, reqCriticality v1alpha2.Criticality) error {
+	logger := log.FromContext(ctx)
+	logger.V(logutil.DEBUG).Info("Performing saturation check for non-critical requests.")
+	if d.saturationDetector == nil {
+		// Should we fail closed here?
+		logger.Error(nil, "SaturationDetector is nil; cannot perform direct saturation check. Proceeding.")
+		return nil
+	}
+
+	// Check saturation directly ONLY for non-critical requests.
+	if reqCriticality != v1alpha2.Critical && d.saturationDetector.IsSaturated(ctx) {
+		logger.Info("System saturated, dropping non-critical request")
+		return errutil.Error{
+			Code: errutil.InferencePoolResourceExhausted,
+			Msg:  "system saturated, non-critical request dropped",
+		}
+	}
+	logger.V(logutil.DEBUG).Info("Proceeding to Dispatch (request is critical or system not saturated).")
+	return nil
+}
+
 // Dispatch runs one or many scheduling cycles.
 func (d *Director) Dispatch(ctx context.Context, llmReq *schedulingtypes.LLMRequest) ([]*schedulingtypes.Result, error) {
-	var err error
 	res, err := d.scheduler.Schedule(ctx, llmReq)
 	if err != nil {
-		return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
+		return nil, errutil.Error{
+			Code: errutil.InferencePoolResourceExhausted,
+			Msg:  fmt.Errorf("scheduler failed: %w", err).Error(),
+		}
+	}
+	if res == nil { // Defensive check
+		return nil, errutil.Error{
+			Code: errutil.Internal,
+			Msg:  "scheduler returned nil result without error",
+		}
 	}
 	return []*schedulingtypes.Result{res}, nil
 }
 
+// PostDispatch populates the RequestContext based on scheduling results.
 func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestContext, results []*schedulingtypes.Result) (*handlers.RequestContext, error) {
 	logger := log.FromContext(ctx)
-	// currently only get a single result. Will refactor to pluggably implement the PostSchedule
-	if len(results) == 0 {
-		return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
+	// We currently only get a single result; this will be refactored to
+	// pluggably implement PostSchedule.
+	if len(results) == 0 || results[0] == nil || results[0].TargetPod == nil || results[0].TargetPod.GetPod() == nil {
+		logger.Error(nil, "PostDispatch called with invalid scheduling results")
+		return reqCtx, errutil.Error{
+			Code: errutil.Internal,
+			Msg:  "invalid scheduling result in PostDispatch",
+		}
 	}
 	targetPod := results[0].TargetPod.GetPod()
 
 	pool, err := d.datastore.PoolGet()
 	if err != nil {
-		return reqCtx, err
+		logger.Error(err, "Failed to get InferencePool info for port")
+		return reqCtx, errutil.Error{
+			Code: errutil.Internal,
+			Msg:  "failed to get pool configuration",
+		}
 	}
 	endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
 
-	logger.V(logutil.DEFAULT).Info("Request handled", "model", reqCtx.Model, "targetModel", reqCtx.ResolvedTargetModel, "endpoint", targetPod)
-
+	logger.V(logutil.DEFAULT).Info("Request scheduled",
		"targetPod", targetPod.NamespacedName.String(), "endpoint", endpoint)
 	reqCtx.TargetPod = targetPod.NamespacedName.String()
 	reqCtx.TargetEndpoint = endpoint
 
@@ -154,6 +253,7 @@ func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestC
 	return reqCtx, nil
 }
 
+// GetRandomPod selects a random pod.
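+// It returns nil when the datastore reports no pods.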
 func (d *Director) GetRandomPod() *backend.Pod {
 	pods := d.datastore.PodGetAll()
 	if len(pods) == 0 {
@@ -164,16 +264,18 @@ func (d *Director) GetRandomPod() *backend.Pod {
 	return pod.GetPod()
 }
 
+// RandomWeightedDraw selects a model name based on a weighted random
+// distribution.
 func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string {
-	// TODO: after we are down to 1 server implementation, make these methods a part of the struct
-	// and handle random seeding on the struct.
+	// TODO: after we are down to 1 server implementation, make these methods a
+	// part of the struct and handle random seeding on the struct.
 	source := rand.NewSource(rand.Int63())
 	if seed > 0 {
 		source = rand.NewSource(seed)
 	}
 	r := rand.New(source)
 
-	// all the weight values are nil, then we should return random model name
+	// If all the weight values are nil, we should return a random model name.
 	if model.Spec.TargetModels[0].Weight == nil {
 		index := r.Int31n(int32(len(model.Spec.TargetModels)))
 		return model.Spec.TargetModels[index].Name
@@ -183,7 +285,8 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed
 	for _, model := range model.Spec.TargetModels {
 		weights += *model.Weight
 	}
-	logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
+	logger.V(logutil.TRACE).Info("Weights for model computed",
+		"model", model.Name, "weights", weights)
 	randomVal := r.Int31n(weights)
 	// TODO: optimize this without using loop
 	for _, model := range model.Spec.TargetModels {
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index e4384a80b..35a53098e 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -18,120 +18,216 @@ package requestcontrol
 
 import (
 	"context"
-	"strings"
+	"errors"
 	"testing"
 	"time"
 
 	"github.com/google/go-cmp/cmp"
-
 	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
+	k8stypes "k8s.io/apimachinery/pkg/types"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 	"sigs.k8s.io/controller-runtime/pkg/client/fake"
-
 	"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 	testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
 )
 
-func TestHandleRequest(t *testing.T) {
+// --- Mocks ---
+
+// mockSaturationDetector provides a minimal mock for testing.
+type mockSaturationDetector struct {
+	isSaturated bool
+}
+
+func (m *mockSaturationDetector) IsSaturated(_ context.Context) bool {
+	return m.isSaturated
+}
+
+// mockScheduler is a configurable mock for the Scheduler interface.
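+// Its Schedule behavior can be overridden per test via scheduleFunc,
+// scheduleErr, or scheduleResult; when none are set it returns a default
+// valid result.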
+type mockScheduler struct {
+	// Fields for Schedule
+	scheduleResult *schedulingtypes.Result
+	scheduleErr    error
+	scheduleFunc   func(ctx context.Context, b *schedulingtypes.LLMRequest) (*schedulingtypes.Result, error)
+	scheduleCalled bool
+
+	// Fields for OnResponse
+	onResponseFunc          func(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string)
+	onResponseCalled        bool
+	lastCtxOnResponse       context.Context
+	lastRespOnResponse      *schedulingtypes.LLMResponse
+	lastTargetPodOnResponse string
+}
+
+func (m *mockScheduler) Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (*schedulingtypes.Result, error) {
+	m.scheduleCalled = true
+	if m.scheduleFunc != nil {
+		return m.scheduleFunc(ctx, b)
+	}
+	if m.scheduleErr != nil {
+		return nil, m.scheduleErr
+	}
+	if m.scheduleResult == nil {
+		// Provide a default valid pod if not specified.
+		return &schedulingtypes.Result{
+			TargetPod: &schedulingtypes.ScoredPod{
+				Pod: &schedulingtypes.PodMetrics{
+					Pod: &backend.Pod{
+						Address: "192.168.1.100",
+						NamespacedName: k8stypes.NamespacedName{
+							Name:      "pod1",
+							Namespace: "default",
+						},
+					},
+				},
+			},
+		}, nil
+	}
+	return m.scheduleResult, nil
+}
+
+func (m *mockScheduler) OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string) {
+	m.onResponseCalled = true
+	m.lastCtxOnResponse = ctx
+	m.lastRespOnResponse = resp
+	m.lastTargetPodOnResponse = targetPodName
+	if m.onResponseFunc != nil {
+		m.onResponseFunc(ctx, resp, targetPodName)
+	}
+}
+
+func (m *mockScheduler) Reset() {
+	m.scheduleResult = nil
+	m.scheduleErr = nil
+	m.scheduleFunc = nil
+	m.scheduleCalled = false
+	m.onResponseFunc = nil
+	m.onResponseCalled = false
+	m.lastCtxOnResponse = nil
+	m.lastRespOnResponse = nil
+	m.lastTargetPodOnResponse = ""
+}
+
+func TestDirector_HandleRequest(t *testing.T) {
 	ctx := logutil.NewTestLoggerIntoContext(context.Background())
 
-	// Setup datastore
-	tsModel := "food-review"
-	modelWithTarget := "food-review-0"
-	model1 := testutil.MakeInferenceModel("model1").
+	// --- Setup common objects ---
+	model := "food-review"
+	modelSheddable := "food-review-sheddable"
+	modelWithResolvedTarget := "food-review-resolve"
+
+	// InferenceModel definitions
+	imFoodReview := testutil.MakeInferenceModel("imFoodReview").
+		CreationTimestamp(metav1.Unix(1000, 0)).
+		ModelName(model).
+		Criticality(v1alpha2.Critical).
+		ObjRef()
+	imFoodReviewSheddable := testutil.MakeInferenceModel("imFoodReviewSheddable").
 		CreationTimestamp(metav1.Unix(1000, 0)).
-		ModelName(tsModel).ObjRef()
-	model2 := testutil.MakeInferenceModel("model2").
+		ModelName(modelSheddable).
+		Criticality(v1alpha2.Sheddable).
+		ObjRef()
+	imFoodReviewResolve := testutil.MakeInferenceModel("imFoodReviewResolve").
 		CreationTimestamp(metav1.Unix(1000, 0)).
-		ModelName(modelWithTarget).ObjRef()
+		ModelName(modelWithResolvedTarget).
+		Criticality(v1alpha2.Standard).
+		TargetModel("resolved-target-model-A").
+		ObjRef()
+
+	// Datastore setup
 	pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
 	ds := datastore.NewDatastore(t.Context(), pmf)
-	ds.ModelSetIfOlder(model1)
-	ds.ModelSetIfOlder(model2)
+	ds.ModelSetIfOlder(imFoodReview)
+	ds.ModelSetIfOlder(imFoodReviewResolve)
+	ds.ModelSetIfOlder(imFoodReviewSheddable)
 
 	pool := &v1alpha2.InferencePool{
+		ObjectMeta: metav1.ObjectMeta{Name: "test-pool", Namespace: "default"},
 		Spec: v1alpha2.InferencePoolSpec{
 			TargetPortNumber: int32(8000),
 			Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{
-				"some-key": "some-val",
+				"app": "inference",
 			},
 		},
 	}
-	pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-1"}}
+
+	// Pod setup
+	testPod := &corev1.Pod{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "pod1",
+			Namespace: "default",
+			Labels:    map[string]string{"app": "inference"},
+		},
+		Status: corev1.PodStatus{
+			PodIP:      "192.168.1.100",
+			Phase:      corev1.PodRunning,
+			Conditions: []corev1.PodCondition{{Type: corev1.PodReady, Status: corev1.ConditionTrue}},
+		},
+	}
 	scheme := runtime.NewScheme()
 	_ = clientgoscheme.AddToScheme(scheme)
-	fakeClient := fake.NewClientBuilder().
-		WithScheme(scheme).
-		Build()
+	fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build()
 	if err := ds.PoolSet(ctx, fakeClient, pool); err != nil {
-		t.Error(err, "Error while setting inference pool")
+		t.Fatalf("Error while setting inference pool: %v", err)
 	}
-	ds.PodUpdateOrAddIfNotExist(pod)
+	ds.PodUpdateOrAddIfNotExist(testPod)
 
 	tests := []struct {
-		name         string
-		reqBodyMap   map[string]interface{}
-		wantErrCode  string
-		wantReqCtx   *handlers.RequestContext
-		wantRespBody map[string]interface{}
+		name                   string
+		reqBodyMap             map[string]interface{}
+		mockSaturationDetector *mockSaturationDetector
+		schedulerMockSetup     func(m *mockScheduler)   // Configure the scheduler mock for this test
+		wantErrCode            string                   // Expected errutil code string
+		wantReqCtx             *handlers.RequestContext // Fields to check in the returned RequestContext
+		wantMutatedBodyModel   string                   // Expected model in reqCtx.Request.Body after PostDispatch
 	}{
 		{
-			name: "successful completions request",
+			name: "successful completions request (critical, saturation ignored)",
 			reqBodyMap: map[string]interface{}{
-				"model":  tsModel,
-				"prompt": "test prompt",
+				"model":  model,
+				"prompt": "critical prompt",
 			},
+			mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
 			wantReqCtx: &handlers.RequestContext{
-				Model:               tsModel,
-				ResolvedTargetModel: tsModel,
-				TargetPod:           "/pod1",
-				TargetEndpoint:      "address-1:8000",
-			},
-			wantRespBody: map[string]interface{}{
-				"model":  tsModel,
-				"prompt": "test prompt",
+				Model:               model,
+				ResolvedTargetModel: model,
+				TargetPod:           "default/pod1",
+				TargetEndpoint:      "192.168.1.100:8000",
 			},
+			wantMutatedBodyModel: model,
 		},
 		{
-			name: "successful chat completions request",
+			name: "successful chat completions request (critical, saturation ignored)",
 			reqBodyMap: map[string]interface{}{
-				"model": tsModel,
+				"model": model,
 				"messages": []interface{}{
 					map[string]interface{}{
 						"role":    "user",
-						"content": "test prompt",
+						"content": "critical prompt",
 					},
 				},
 			},
+			mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
 			wantReqCtx: &handlers.RequestContext{
-				Model:               tsModel,
-				ResolvedTargetModel: tsModel,
-				TargetPod:           "/pod1",
-				TargetEndpoint:      "address-1:8000",
-			},
-			wantRespBody: map[string]interface{}{
-				"model": tsModel,
-				"messages": []interface{}{
-					map[string]interface{}{
-						"role":    "user",
-						"content": "test prompt",
-					},
-				},
+				Model:               model,
+				ResolvedTargetModel: model,
+				TargetPod:           "default/pod1",
+				TargetEndpoint:      "192.168.1.100:8000",
 			},
+			wantMutatedBodyModel: model,
 		},
 		{
-			name: "successful chat completions request with multiple messages",
+			name: "successful chat completions request with multiple messages (critical, saturation ignored)",
 			reqBodyMap: map[string]interface{}{
-				"model": tsModel,
+				"model": model,
 				"messages": []interface{}{
 					map[string]interface{}{
 						"role":    "developer",
@@ -144,57 +240,84 @@ func TestHandleRequest(t *testing.T) {
 				},
 			},
+			mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
 			wantReqCtx: &handlers.RequestContext{
-				Model:               tsModel,
-				ResolvedTargetModel: tsModel,
-				TargetPod:           "/pod1",
-				TargetEndpoint:      "address-1:8000",
-			},
-			wantRespBody: map[string]interface{}{
-				"model": tsModel,
-				"messages": []interface{}{
-					map[string]interface{}{
-						"role":    "developer",
-						"content": "You are a helpful assistant.",
-					},
-					map[string]interface{}{
-						"role":    "user",
-						"content": "Hello!",
-					},
-				},
+				Model:               model,
+				ResolvedTargetModel: model,
+				TargetPod:           "default/pod1",
+				TargetEndpoint:      "192.168.1.100:8000",
 			},
+			wantMutatedBodyModel: model,
 		},
 		{
-			name: "successful completions request with target model",
+			name: "successful completions request (sheddable, not saturated)",
 			reqBodyMap: map[string]interface{}{
-				"model":  modelWithTarget,
-				"prompt": "test prompt",
+				"model":  modelSheddable,
+				"prompt": "sheddable prompt",
 			},
+			mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
 			wantReqCtx: &handlers.RequestContext{
-				Model:               modelWithTarget,
-				ResolvedTargetModel: modelWithTarget,
-				TargetPod:           "/pod1",
-				TargetEndpoint:      "address-1:8000",
+				Model:               modelSheddable,
+				ResolvedTargetModel: modelSheddable,
+				TargetPod:           "default/pod1",
+				TargetEndpoint:      "192.168.1.100:8000",
 			},
-			wantRespBody: map[string]interface{}{
-				"model":  modelWithTarget,
-				"prompt": "test prompt",
+			wantMutatedBodyModel: modelSheddable,
+		},
+		{
+			name: "successful request with target model resolution",
+			reqBodyMap: map[string]interface{}{
+				"model":  modelWithResolvedTarget,
+				"prompt": "prompt for target resolution",
 			},
+			mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
+			wantReqCtx: &handlers.RequestContext{
+				Model:               modelWithResolvedTarget,
+				ResolvedTargetModel: "resolved-target-model-A",
+				TargetPod:           "default/pod1",
+				TargetEndpoint:      "192.168.1.100:8000",
+			},
+			wantMutatedBodyModel: "resolved-target-model-A",
 		},
 		{
-			name:        "no model defined, expect err",
-			wantErrCode: errutil.BadRequest,
+			name: "request dropped (sheddable, saturated)",
+			reqBodyMap: map[string]interface{}{
+				"model":  modelSheddable,
+				"prompt": "sheddable prompt",
+			},
+			mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
+			wantErrCode:            errutil.InferencePoolResourceExhausted,
 		},
 		{
-			name: "prompt or messages not found, expect err",
+			name: "nil saturation detector, sheddable request proceeds",
 			reqBodyMap: map[string]interface{}{
-				"model": tsModel,
+				"model":  modelSheddable,
+				"prompt": "sheddable prompt",
+			},
+			mockSaturationDetector: nil, // Simulate detector not being configured
+			wantReqCtx: &handlers.RequestContext{
+				Model:               modelSheddable,
+				ResolvedTargetModel: modelSheddable,
+				TargetPod:           "default/pod1",
+				TargetEndpoint:      "192.168.1.100:8000",
 			},
+			wantMutatedBodyModel: modelSheddable,
+		},
+		{
+			name:                   "model not found, expect err",
+			reqBodyMap:             map[string]interface{}{"prompt": "p"},
+			mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
+			wantErrCode:            errutil.BadRequest,
+		},
+		{
+			name:        "prompt or messages not found, expect err",
map[string]interface{}{"model": model}, wantErrCode: errutil.BadRequest, }, { name: "empty messages, expect err", reqBodyMap: map[string]interface{}{ - "model": tsModel, + "model": model, "messages": []interface{}{}, }, wantErrCode: errutil.BadRequest, @@ -205,7 +328,8 @@ func TestHandleRequest(t *testing.T) { "model": "non-existent-model", "prompt": "test prompt", }, - wantErrCode: errutil.BadConfiguration, + mockSaturationDetector: &mockSaturationDetector{isSaturated: false}, + wantErrCode: errutil.BadConfiguration, }, { name: "invalid target defined, expect err", @@ -213,46 +337,117 @@ func TestHandleRequest(t *testing.T) { "model": "food-review-1", "prompt": "test prompt", }, - wantErrCode: errutil.BadConfiguration, + mockSaturationDetector: &mockSaturationDetector{isSaturated: false}, + wantErrCode: errutil.BadConfiguration, + }, + { + name: "scheduler returns error", + reqBodyMap: map[string]interface{}{ + "model": model, + "prompt": "prompt that causes scheduler error", + }, + schedulerMockSetup: func(m *mockScheduler) { + m.scheduleErr = errors.New("simulated scheduler failure") + }, + wantErrCode: errutil.InferencePoolResourceExhausted, + }, + { + name: "scheduler returns nil result and nil error", + reqBodyMap: map[string]interface{}{ + "model": model, + "prompt": "prompt for nil,nil scheduler return", + }, + schedulerMockSetup: func(m *mockScheduler) { + // Explicitly set scheduleFunc to return nil, nil + m.scheduleFunc = func(ctx context.Context, b *schedulingtypes.LLMRequest) (*schedulingtypes.Result, error) { + return nil, nil + } + }, + wantErrCode: errutil.Internal, + }, + { + name: "scheduler returns result with nil TargetPod", + reqBodyMap: map[string]interface{}{ + "model": model, + "prompt": "prompt for nil TargetPod in result", + }, + schedulerMockSetup: func(m *mockScheduler) { + m.scheduleResult = &schedulingtypes.Result{TargetPod: nil} + }, + wantErrCode: errutil.Internal, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - server := NewDirector(ds, scheduling.NewScheduler(ds)) + mockSched := &mockScheduler{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSched.Reset() + if tt.schedulerMockSetup != nil { + tt.schedulerMockSetup(mockSched) + } + + var sd SaturationDetector + if tt.mockSaturationDetector != nil { + sd = tt.mockSaturationDetector + } + director := NewDirector(ds, mockSched, sd) + reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ - Body: test.reqBodyMap, + // Create a copy of the map for each test run to avoid mutation + // issues. + Body: make(map[string]interface{}), + Headers: make(map[string]string), // Initialize headers }, } - reqCtx, err := server.HandleRequest(ctx, reqCtx) + // Deep copy the body map + for k, v := range tt.reqBodyMap { + reqCtx.Request.Body[k] = v + } + + returnedReqCtx, err := director.HandleRequest(ctx, reqCtx) - if test.wantErrCode != "" { + if tt.wantErrCode != "" { if err == nil { - t.Fatalf("HandleRequestBody should have returned an error containing '%s', but got nil", test.wantErrCode) + t.Fatalf("HandleRequest() should have returned an error with code '%s', but got nil", tt.wantErrCode) } - if !strings.Contains(err.Error(), test.wantErrCode) { - t.Fatalf("HandleRequestBody returned error '%v', which does not contain expected substring '%s'", err, test.wantErrCode) + if e, ok := err.(errutil.Error); ok { + if e.Code != tt.wantErrCode { + t.Fatalf("HandleRequest() returned error with code %s, want %s. 
Full error: %v", e.Code, tt.wantErrCode, err) + } + } else { + t.Fatalf("HandleRequest() returned error of type %T, want errutil.Error. Full error: %v", err, err) } return } if err != nil { - t.Fatalf("HandleRequestBody returned unexpected error: %v", err) + t.Fatalf("HandleRequest() returned unexpected error: %v", err) } - if test.wantReqCtx != nil { - if diff := cmp.Diff(test.wantReqCtx.Model, reqCtx.Model); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.Model, diff(-want, +got): %v", diff) + if tt.wantReqCtx != nil { + if diff := cmp.Diff(tt.wantReqCtx.Model, returnedReqCtx.Model); diff != "" { + t.Errorf("reqCtx.Model mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(test.wantReqCtx.ResolvedTargetModel, reqCtx.ResolvedTargetModel); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.ResolvedTargetModel, diff(-want, +got): %v", diff) + if diff := cmp.Diff(tt.wantReqCtx.ResolvedTargetModel, returnedReqCtx.ResolvedTargetModel); diff != "" { + t.Errorf("reqCtx.ResolvedTargetModel mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(test.wantReqCtx.TargetPod, reqCtx.TargetPod); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetPod, diff(-want, +got): %v", diff) + if diff := cmp.Diff(tt.wantReqCtx.TargetPod, returnedReqCtx.TargetPod); diff != "" { + t.Errorf("reqCtx.TargetPod mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(test.wantReqCtx.TargetEndpoint, reqCtx.TargetEndpoint); diff != "" { - t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetEndpoint, diff(-want, +got): %v", diff) + if diff := cmp.Diff(tt.wantReqCtx.TargetEndpoint, returnedReqCtx.TargetEndpoint); diff != "" { + t.Errorf("reqCtx.TargetEndpoint mismatch (-want +got):\n%s", diff) + } + } + + if tt.wantMutatedBodyModel != "" { + if returnedReqCtx.Request.Body == nil { + t.Errorf("Expected mutated body with model %s, but reqCtx.Request.Body is nil", tt.wantMutatedBodyModel) + } else { + if gotModel, ok := returnedReqCtx.Request.Body["model"].(string); !ok || gotModel != tt.wantMutatedBodyModel { + t.Errorf("Mutated reqCtx.Request.Body model = %q, want %q. Full body: %v", gotModel, tt.wantMutatedBodyModel, returnedReqCtx.Request.Body) + } } } }) @@ -261,13 +456,15 @@ func TestHandleRequest(t *testing.T) { func TestRandomWeightedDraw(t *testing.T) { logger := logutil.NewTestLogger() + // Note: These tests verify deterministic outcomes for a fixed seed (420). + // They do not test the statistical properties of the random draw. 
 	tests := []struct {
 		name  string
 		model *v1alpha2.InferenceModel
 		want  string
 	}{
 		{
-			name: "'random' distribution",
+			name: "deterministic draw: 50/50 weights, seed 420",
 			model: &v1alpha2.InferenceModel{
 				Spec: v1alpha2.InferenceModelSpec{
 					TargetModels: []v1alpha2.TargetModel{
@@ -285,7 +482,7 @@ func TestRandomWeightedDraw(t *testing.T) {
 			want: "canary",
 		},
 		{
-			name: "'random' distribution",
+			name: "deterministic draw: 25/55/50 weights, seed 420",
 			model: &v1alpha2.InferenceModel{
 				Spec: v1alpha2.InferenceModelSpec{
 					TargetModels: []v1alpha2.TargetModel{
@@ -307,7 +504,7 @@ func TestRandomWeightedDraw(t *testing.T) {
 			want: "v1",
 		},
 		{
-			name: "'random' distribution",
+			name: "deterministic draw: 20/20/10 weights, seed 420",
 			model: &v1alpha2.InferenceModel{
 				Spec: v1alpha2.InferenceModelSpec{
 					TargetModels: []v1alpha2.TargetModel{
@@ -329,7 +526,7 @@ func TestRandomWeightedDraw(t *testing.T) {
 			want: "v1.1",
 		},
 		{
-			name: "weighted distribution with weight unset",
+			name: "deterministic draw: nil weights (uniform), seed 420",
 			model: &v1alpha2.InferenceModel{
 				Spec: v1alpha2.InferenceModelSpec{
 					TargetModels: []v1alpha2.TargetModel{
@@ -351,12 +548,9 @@ func TestRandomWeightedDraw(t *testing.T) {
 	var seedVal int64 = 420
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			for range 10000 {
-				model := RandomWeightedDraw(logger, test.model, seedVal)
-				if model != test.want {
-					t.Errorf("Model returned: %v != %v", model, test.want)
-					break
-				}
+			model := RandomWeightedDraw(logger, test.model, seedVal)
+			if model != test.want {
+				t.Errorf("RandomWeightedDraw() with seed %d = %q, want %q", seedVal, model, test.want)
 			}
 		})
 	}
@@ -393,7 +587,7 @@ func TestGetRandomPod(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			pmf := metrics.NewPodMetricsFactory(&metrics.FakePodMetricsClient{}, time.Millisecond)
+			pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Millisecond)
 			ds := datastore.NewDatastore(t.Context(), pmf)
 			for _, pod := range test.storePods {
 				ds.PodUpdateOrAddIfNotExist(pod)
@@ -414,3 +608,40 @@ func TestGetRandomPod(t *testing.T) {
 func pointer(v int32) *int32 {
 	return &v
 }
+
+func TestDirector_HandleResponse(t *testing.T) {
+	ctx := logutil.NewTestLoggerIntoContext(context.Background())
+	ds := datastore.NewDatastore(t.Context(), nil)
+	mockSched := &mockScheduler{}
+	director := NewDirector(ds, mockSched, nil)
+
+	reqCtx := &handlers.RequestContext{
+		Request: &handlers.Request{
+			Headers: map[string]string{
+				requtil.RequestIdHeaderKey: "test-req-id-for-response",
+			},
+		},
+		Response: &handlers.Response{ // Simulate some response headers
+			Headers: map[string]string{"X-Test-Response-Header": "TestValue"},
+		},
+		TargetPod: "namespace1/test-pod-name",
+	}
+
+	_, err := director.HandleResponse(ctx, reqCtx)
+	if err != nil {
+		t.Fatalf("HandleResponse() returned unexpected error: %v", err)
+	}
+
+	if !mockSched.onResponseCalled {
+		t.Fatal("Scheduler.OnResponse was not called")
+	}
+	if diff := cmp.Diff("test-req-id-for-response", mockSched.lastRespOnResponse.RequestId); diff != "" {
+		t.Errorf("Scheduler.OnResponse RequestId mismatch (-want +got):\n%s", diff)
+	}
+	if diff := cmp.Diff(reqCtx.Response.Headers, mockSched.lastRespOnResponse.Headers); diff != "" {
+		t.Errorf("Scheduler.OnResponse Headers mismatch (-want +got):\n%s", diff)
+	}
+	if diff := cmp.Diff("namespace1/test-pod-name", mockSched.lastTargetPodOnResponse); diff != "" {
+		t.Errorf("Scheduler.OnResponse TargetPodName mismatch (-want +got):\n%s", diff)
+	}
+}
diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go
index 4b8620826..d5b9ef2f4 100644
--- a/pkg/epp/server/runserver.go
+++ b/pkg/epp/server/runserver.go
@@ -50,8 +50,10 @@ type ExtProcServerRunner struct {
 	UseStreaming                     bool
 	RefreshPrometheusMetricsInterval time.Duration
 	Scheduler                        requestcontrol.Scheduler
+	SaturationDetector               requestcontrol.SaturationDetector
 
-	// This should only be used in tests. We won't need this once we don't inject metrics in the tests.
+	// This should only be used in tests. We won't need this once we do not
+	// inject metrics in the tests.
 	// TODO:(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/432) Cleanup
 	TestPodMetricsClient *backendmetrics.FakePodMetricsClient
 }
@@ -68,6 +70,8 @@ const (
 	DefaultSecureServing = true // default for --secureServing
 )
 
+// NewDefaultExtProcServerRunner creates a runner with default values.
+// Note: Dependencies like Datastore, Scheduler, and SaturationDetector need to be set separately.
 func NewDefaultExtProcServerRunner() *ExtProcServerRunner {
 	return &ExtProcServerRunner{
 		GrpcPort:                         DefaultGrpcPort,
@@ -76,7 +80,7 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner {
 		PoolNamespacedName:               types.NamespacedName{Name: DefaultPoolName, Namespace: DefaultPoolNamespace},
 		SecureServing:                    DefaultSecureServing,
 		RefreshPrometheusMetricsInterval: DefaultRefreshPrometheusMetricsInterval,
-		// Datastore can be assigned later.
+		// Dependencies can be assigned later.
 	}
 }
 
@@ -138,7 +142,14 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable {
 	} else {
 		srv = grpc.NewServer()
 	}
-	extProcServer := handlers.NewStreamingServer(r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore, requestcontrol.NewDirector(r.Datastore, r.Scheduler))
+
+	director := requestcontrol.NewDirector(r.Datastore, r.Scheduler, r.SaturationDetector)
+	extProcServer := handlers.NewStreamingServer(
+		r.DestinationEndpointHintMetadataNamespace,
+		r.DestinationEndpointHintKey,
+		r.Datastore,
+		director,
+	)
 	extProcPb.RegisterExternalProcessorServer(
 		srv,
 		extProcServer,
diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go
index 923392530..e15d73ce7 100644
--- a/test/integration/epp/hermetic_test.go
+++ b/test/integration/epp/hermetic_test.go
@@ -65,6 +65,7 @@ import (
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+	sdconfig "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
 	runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
@@ -1414,6 +1415,17 @@ func BeforeSuite() func() {
 	serverRunner.PoolNamespacedName = types.NamespacedName{Name: "vllm-llama3-8b-instruct-pool", Namespace: "default"}
 	serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf)
 	serverRunner.Scheduler = scheduling.NewScheduler(serverRunner.Datastore)
+
+	saturationDetectorConfig := sdconfig.Config{
+		QueueDepthThreshold:       sdconfig.DefaultQueueDepthThreshold,
+		KVCacheUtilThreshold:      sdconfig.DefaultKVCacheUtilThreshold,
+		MetricsStalenessThreshold: sdconfig.DefaultMetricsStalenessThreshold,
+	}
+	detector, err := sdconfig.NewDetector(saturationDetectorConfig, serverRunner.Datastore, logger.WithName("saturation-detector"))
+	if err != nil {
+		logutil.Fatal(logger, err, "Failed to create saturation detector for hermetic tests")
+	}
+	serverRunner.SaturationDetector = detector
 	serverRunner.SecureServing = false
 
 	if err := serverRunner.SetupWithManager(context.Background(), mgr); err != nil {
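
The Director's only new coupling point is the single-method requestcontrol.SaturationDetector interface introduced above, so alternative detectors can be injected into ExtProcServerRunner without touching the scheduling path. A minimal sketch of a static stand-in that satisfies the contract — the staticDetector type, package name, and wiring line below are illustrative assumptions, not code from this change:

	package detectorexample

	import (
		"context"

		"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
	)

	// staticDetector reports a fixed saturation state; useful for wiring
	// experiments or tests that need deterministic admission behavior.
	type staticDetector struct{ saturated bool }

	// IsSaturated implements requestcontrol.SaturationDetector.
	func (d *staticDetector) IsSaturated(_ context.Context) bool { return d.saturated }

	// Compile-time assertion that the interface is satisfied.
	var _ requestcontrol.SaturationDetector = (*staticDetector)(nil)

It would be wired the same way the hermetic test wires its detector, e.g. serverRunner.SaturationDetector = &staticDetector{saturated: false}.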