diff --git a/Makefile b/Makefile index 61b17f5b..257d2cbb 100644 --- a/Makefile +++ b/Makefile @@ -119,7 +119,7 @@ vet: ## Run go vet against code. .PHONY: test test: manifests generate fmt vet envtest ## Run tests. - KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -coverprofile cover.out + KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -race -coverprofile cover.out .PHONY: test-integration test-integration: manifests generate fmt vet envtest ## Run tests. diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 1f62d94a..e1cd5015 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -37,7 +37,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/metrics/filters" "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/vllm" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" @@ -143,22 +143,20 @@ func run() error { ctx := ctrl.SetupSignalHandler() + pmf := backendmetrics.NewPodMetricsFactory(&vllm.PodMetricsClientImpl{}, *refreshMetricsInterval) // Setup runner. - datastore := datastore.NewDatastore() - provider := backend.NewProvider(&vllm.PodMetricsClientImpl{}, datastore) + datastore := datastore.NewDatastore(ctx, pmf) serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace, DestinationEndpointHintKey: *destinationEndpointHintKey, PoolName: *poolName, PoolNamespace: *poolNamespace, - RefreshMetricsInterval: *refreshMetricsInterval, - RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval, Datastore: datastore, SecureServing: *secureServing, CertPath: *certPath, - Provider: provider, UseStreaming: useStreamingServer, + RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval, } if err := serverRunner.SetupWithManager(ctx, mgr); err != nil { setupLog.Error(err, "Failed to setup ext-proc controllers") diff --git a/pkg/epp/backend/fake.go b/pkg/epp/backend/fake.go deleted file mode 100644 index 584486c2..00000000 --- a/pkg/epp/backend/fake.go +++ /dev/null @@ -1,48 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package backend - -import ( - "context" - - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -type FakePodMetricsClient struct { - Err map[types.NamespacedName]error - Res map[types.NamespacedName]*datastore.PodMetrics -} - -func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, existing *datastore.PodMetrics, port int32) (*datastore.PodMetrics, error) { - if err, ok := f.Err[existing.NamespacedName]; ok { - return nil, err - } - log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "existing", existing, "new", f.Res[existing.NamespacedName]) - return f.Res[existing.NamespacedName], nil -} - -type FakeDataStore struct { - Res map[string]*v1alpha2.InferenceModel -} - -func (fds *FakeDataStore) FetchModelData(modelName string) (returnModel *v1alpha2.InferenceModel) { - return fds.Res[modelName] -} diff --git a/pkg/epp/backend/metrics/fake.go b/pkg/epp/backend/metrics/fake.go new file mode 100644 index 00000000..fae7149d --- /dev/null +++ b/pkg/epp/backend/metrics/fake.go @@ -0,0 +1,90 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package metrics + +import ( + "context" + "fmt" + "sync" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop. 
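+// A minimal construction sketch for tests (field values are illustrative, not required defaults):
+//
+//	pm := &FakePodMetrics{
+//		Pod:     &Pod{NamespacedName: types.NamespacedName{Name: "pod1", Namespace: "default"}},
+//		Metrics: &Metrics{WaitingQueueSize: 1, ActiveModels: map[string]int{"foo": 1}},
+//	}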
+type FakePodMetrics struct { + Pod *Pod + Metrics *Metrics +} + +func (fpm *FakePodMetrics) GetPod() *Pod { + return fpm.Pod +} +func (fpm *FakePodMetrics) GetMetrics() *Metrics { + return fpm.Metrics +} +func (fpm *FakePodMetrics) UpdatePod(pod *corev1.Pod) { + fpm.Pod = toInternalPod(pod) +} +func (fpm *FakePodMetrics) StopRefreshLoop() {} // noop + +type FakePodMetricsClient struct { + errMu sync.RWMutex + Err map[types.NamespacedName]error + resMu sync.RWMutex + Res map[types.NamespacedName]*Metrics +} + +func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *Pod, existing *Metrics, port int32) (*Metrics, error) { + f.errMu.RLock() + err, ok := f.Err[pod.NamespacedName] + f.errMu.RUnlock() + if ok { + return nil, err + } + f.resMu.RLock() + res, ok := f.Res[pod.NamespacedName] + f.resMu.RUnlock() + if !ok { + return nil, fmt.Errorf("no pod found: %v", pod.NamespacedName) + } + log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "existing", existing, "new", res) + return res.Clone(), nil +} + +func (f *FakePodMetricsClient) SetRes(new map[types.NamespacedName]*Metrics) { + f.resMu.Lock() + defer f.resMu.Unlock() + f.Res = new +} + +func (f *FakePodMetricsClient) SetErr(new map[types.NamespacedName]error) { + f.errMu.Lock() + defer f.errMu.Unlock() + f.Err = new +} + +type FakeDataStore struct { + Res map[string]*v1alpha2.InferenceModel +} + +func (fds *FakeDataStore) FetchModelData(modelName string) (returnModel *v1alpha2.InferenceModel) { + return fds.Res[modelName] +} diff --git a/pkg/epp/backend/metrics/logger.go b/pkg/epp/backend/metrics/logger.go new file mode 100644 index 00000000..664115eb --- /dev/null +++ b/pkg/epp/backend/metrics/logger.go @@ -0,0 +1,111 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package metrics + +import ( + "context" + "time" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + // Note currently the EPP treats stale metrics same as fresh. + // TODO: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/336 + metricsValidityPeriod = 5 * time.Second +) + +type Datastore interface { + PoolGet() (*v1alpha2.InferencePool, error) + // PodMetrics operations + // PodGetAll returns all pods and metrics, including fresh and stale. + PodGetAll() []PodMetrics + PodList(func(PodMetrics) bool) []PodMetrics +} + +// StartMetricsLogger starts goroutines to 1) Print metrics debug logs if the DEBUG log level is +// enabled; 2) flushes Prometheus metrics about the backend servers. 
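+// A minimal usage sketch (the interval value is illustrative; ds is any implementation of the
+// Datastore interface above):
+//
+//	ctx, cancel := context.WithCancel(context.Background())
+//	defer cancel()
+//	StartMetricsLogger(ctx, ds, 5*time.Second)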
+func StartMetricsLogger(ctx context.Context, datastore Datastore, refreshPrometheusMetricsInterval time.Duration) { + logger := log.FromContext(ctx) + + // Periodically flush prometheus metrics for inference pool + go func() { + for { + select { + case <-ctx.Done(): + logger.V(logutil.DEFAULT).Info("Shutting down prometheus metrics thread") + return + default: + time.Sleep(refreshPrometheusMetricsInterval) + flushPrometheusMetricsOnce(logger, datastore) + } + } + }() + + // Periodically print out the pods and metrics for DEBUGGING. + if logger := logger.V(logutil.DEBUG); logger.Enabled() { + go func() { + for { + select { + case <-ctx.Done(): + logger.V(logutil.DEFAULT).Info("Shutting down metrics logger thread") + return + default: + time.Sleep(5 * time.Second) + podsWithFreshMetrics := datastore.PodList(func(pm PodMetrics) bool { + return time.Since(pm.GetMetrics().UpdateTime) <= metricsValidityPeriod + }) + podsWithStaleMetrics := datastore.PodList(func(pm PodMetrics) bool { + return time.Since(pm.GetMetrics().UpdateTime) > metricsValidityPeriod + }) + logger.Info("Current Pods and metrics gathered", "fresh metrics", podsWithFreshMetrics, "stale metrics", podsWithStaleMetrics) + } + } + }() + } +} + +func flushPrometheusMetricsOnce(logger logr.Logger, datastore Datastore) { + pool, err := datastore.PoolGet() + if err != nil { + // No inference pool or not initialize. + logger.V(logutil.VERBOSE).Info("pool is not initialized, skipping flushing metrics") + return + } + + var kvCacheTotal float64 + var queueTotal int + + podMetrics := datastore.PodGetAll() + logger.V(logutil.VERBOSE).Info("Flushing Prometheus Metrics", "ReadyPods", len(podMetrics)) + if len(podMetrics) == 0 { + return + } + + for _, pod := range podMetrics { + kvCacheTotal += pod.GetMetrics().KVCacheUsagePercent + queueTotal += pod.GetMetrics().WaitingQueueSize + } + + podTotalCount := len(podMetrics) + metrics.RecordInferencePoolAvgKVCache(pool.Name, kvCacheTotal/float64(podTotalCount)) + metrics.RecordInferencePoolAvgQueueSize(pool.Name, float64(queueTotal/podTotalCount)) +} diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go new file mode 100644 index 00000000..f76c2e8c --- /dev/null +++ b/pkg/epp/backend/metrics/pod_metrics.go @@ -0,0 +1,129 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package metrics + +import ( + "context" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + fetchMetricsTimeout = 5 * time.Second +) + +type podMetrics struct { + pod unsafe.Pointer // stores a *Pod + metrics unsafe.Pointer // stores a *Metrics + pmc PodMetricsClient + ds Datastore + interval time.Duration + + parentCtx context.Context + once sync.Once // ensure the StartRefreshLoop is only called once. 
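+	// done is closed by StopRefreshLoop to terminate the refresh goroutine.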
+ done chan struct{} + + logger logr.Logger +} + +type PodMetricsClient interface { + FetchMetrics(ctx context.Context, pod *Pod, existing *Metrics, port int32) (*Metrics, error) +} + +func (pm *podMetrics) GetPod() *Pod { + return (*Pod)(atomic.LoadPointer(&pm.pod)) +} + +func (pm *podMetrics) GetMetrics() *Metrics { + return (*Metrics)(atomic.LoadPointer(&pm.metrics)) +} + +func (pm *podMetrics) UpdatePod(in *corev1.Pod) { + atomic.StorePointer(&pm.pod, unsafe.Pointer(toInternalPod(in))) +} + +func toInternalPod(in *corev1.Pod) *Pod { + return &Pod{ + NamespacedName: types.NamespacedName{ + Name: in.Name, + Namespace: in.Namespace, + }, + Address: in.Status.PodIP, + } +} + +// start starts a goroutine exactly once to periodically update metrics. The goroutine will be +// stopped either when stop() is called, or the parentCtx is cancelled. +func (pm *podMetrics) startRefreshLoop() { + pm.once.Do(func() { + go func() { + pm.logger.V(logutil.DEFAULT).Info("Starting refresher", "pod", pm.GetPod()) + for { + select { + case <-pm.done: + return + case <-pm.parentCtx.Done(): + return + default: + } + + err := pm.refreshMetrics() + if err != nil { + pm.logger.V(logutil.TRACE).Error(err, "Failed to refresh metrics", "pod", pm.GetPod()) + } + + time.Sleep(pm.interval) + } + }() + }) +} + +func (pm *podMetrics) refreshMetrics() error { + pool, err := pm.ds.PoolGet() + if err != nil { + // No inference pool or not initialize. + return err + } + ctx, cancel := context.WithTimeout(context.Background(), fetchMetricsTimeout) + defer cancel() + updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics(), pool.Spec.TargetPortNumber) + if err != nil { + // As refresher is running in the background, it's possible that the pod is deleted but + // the refresh goroutine doesn't read the done channel yet. In this case, we just return nil. + // The refresher will be stopped after this interval. + return nil + } + updated.UpdateTime = time.Now() + + pm.logger.V(logutil.TRACE).Info("Refreshed metrics", "updated", updated) + + atomic.StorePointer(&pm.metrics, unsafe.Pointer(updated)) + return nil +} + +func (pm *podMetrics) StopRefreshLoop() { + pm.logger.V(logutil.DEFAULT).Info("Stopping refresher", "pod", pm.GetPod()) + close(pm.done) +} diff --git a/pkg/epp/backend/metrics/pod_metrics_test.go b/pkg/epp/backend/metrics/pod_metrics_test.go new file mode 100644 index 00000000..cf6698ca --- /dev/null +++ b/pkg/epp/backend/metrics/pod_metrics_test.go @@ -0,0 +1,96 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +package metrics + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" +) + +var ( + pod1 = &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: "default", + }, + } + initial = &Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + } + updated = &Metrics{ + WaitingQueueSize: 9999, + KVCacheUsagePercent: 0.99, + MaxActiveModels: 99, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + } +) + +func TestMetricsRefresh(t *testing.T) { + ctx := context.Background() + pmc := &FakePodMetricsClient{} + pmf := NewPodMetricsFactory(pmc, time.Millisecond) + + // The refresher is initialized with empty metrics. + pm := pmf.NewPodMetrics(ctx, pod1, &fakeDataStore{}) + + namespacedName := types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace} + // Use SetRes to simulate an update of metrics from the pod. + // Verify that the metrics are updated. + pmc.SetRes(map[types.NamespacedName]*Metrics{namespacedName: initial}) + condition := func(collect *assert.CollectT) { + assert.True(collect, cmp.Equal(pm.GetMetrics(), initial, cmpopts.IgnoreFields(Metrics{}, "UpdateTime"))) + } + assert.EventuallyWithT(t, condition, time.Second, time.Millisecond) + + // Stop the loop, and simulate metric update again, this time the PodMetrics won't get the + // new update. + pm.StopRefreshLoop() + pmc.SetRes(map[types.NamespacedName]*Metrics{namespacedName: updated}) + // Still expect the same condition (no metrics update). + assert.EventuallyWithT(t, condition, time.Second, time.Millisecond) +} + +type fakeDataStore struct{} + +func (f *fakeDataStore) PoolGet() (*v1alpha2.InferencePool, error) { + return &v1alpha2.InferencePool{Spec: v1alpha2.InferencePoolSpec{TargetPortNumber: 8000}}, nil +} +func (f *fakeDataStore) PodGetAll() []PodMetrics { + // Not implemented. + return nil +} +func (f *fakeDataStore) PodList(func(PodMetrics) bool) []PodMetrics { + // Not implemented. + return nil +} diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go new file mode 100644 index 00000000..cdbdb2ce --- /dev/null +++ b/pkg/epp/backend/metrics/types.go @@ -0,0 +1,114 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package metrics is a library to interact with backend metrics. 
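+// The typical wiring is a PodMetricsFactory that creates PodMetrics instances which refresh
+// themselves in the background via a PodMetricsClient. A minimal sketch (client and interval
+// are illustrative):
+//
+//	pmf := NewPodMetricsFactory(&FakePodMetricsClient{}, time.Second)
+//	pm := pmf.NewPodMetrics(ctx, pod, ds)
+//	defer pm.StopRefreshLoop()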
+package metrics + +import ( + "context" + "fmt" + "sync" + "time" + "unsafe" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +func NewPodMetricsFactory(pmc PodMetricsClient, refreshMetricsInterval time.Duration) *PodMetricsFactory { + return &PodMetricsFactory{ + pmc: pmc, + refreshMetricsInterval: refreshMetricsInterval, + } +} + +type PodMetricsFactory struct { + pmc PodMetricsClient + refreshMetricsInterval time.Duration +} + +func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics { + pm := &podMetrics{ + pod: unsafe.Pointer(toInternalPod(in)), + metrics: unsafe.Pointer(newMetrics()), + pmc: f.pmc, + ds: ds, + interval: f.refreshMetricsInterval, + parentCtx: parentCtx, + once: sync.Once{}, + done: make(chan struct{}), + logger: log.FromContext(parentCtx), + } + pm.startRefreshLoop() + return pm +} + +type PodMetrics interface { + GetPod() *Pod + GetMetrics() *Metrics + UpdatePod(*corev1.Pod) + StopRefreshLoop() +} + +type Pod struct { + NamespacedName types.NamespacedName + Address string +} + +type Metrics struct { + // ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU. + ActiveModels map[string]int + // MaxActiveModels is the maximum number of models that can be loaded to GPU. + MaxActiveModels int + RunningQueueSize int + WaitingQueueSize int + KVCacheUsagePercent float64 + KvCacheMaxTokenCapacity int + + // UpdateTime record the last time when the metrics were updated. + UpdateTime time.Time +} + +func newMetrics() *Metrics { + return &Metrics{ + ActiveModels: make(map[string]int), + } +} + +func (m *Metrics) String() string { + if m == nil { + return "" + } + return fmt.Sprintf("%+v", *m) +} + +func (m *Metrics) Clone() *Metrics { + cm := make(map[string]int, len(m.ActiveModels)) + for k, v := range m.ActiveModels { + cm[k] = v + } + clone := &Metrics{ + ActiveModels: cm, + MaxActiveModels: m.MaxActiveModels, + RunningQueueSize: m.RunningQueueSize, + WaitingQueueSize: m.WaitingQueueSize, + KVCacheUsagePercent: m.KVCacheUsagePercent, + KvCacheMaxTokenCapacity: m.KvCacheMaxTokenCapacity, + UpdateTime: m.UpdateTime, + } + return clone +} diff --git a/pkg/epp/backend/provider.go b/pkg/epp/backend/provider.go deleted file mode 100644 index 959f3e0c..00000000 --- a/pkg/epp/backend/provider.go +++ /dev/null @@ -1,183 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package backend - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/go-logr/logr" - "go.uber.org/multierr" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -const ( - fetchMetricsTimeout = 5 * time.Second -) - -func NewProvider(pmc PodMetricsClient, datastore datastore.Datastore) *Provider { - p := &Provider{ - pmc: pmc, - datastore: datastore, - } - return p -} - -// Provider provides backend pods and information such as metrics. -type Provider struct { - pmc PodMetricsClient - datastore datastore.Datastore -} - -type PodMetricsClient interface { - FetchMetrics(ctx context.Context, existing *datastore.PodMetrics, port int32) (*datastore.PodMetrics, error) -} - -func (p *Provider) Init(ctx context.Context, refreshMetricsInterval, refreshPrometheusMetricsInterval time.Duration) error { - // periodically refresh metrics - logger := log.FromContext(ctx) - go func() { - for { - select { - case <-ctx.Done(): - logger.V(logutil.DEFAULT).Info("Shutting down metrics prober") - return - default: - time.Sleep(refreshMetricsInterval) - if err := p.refreshMetricsOnce(logger); err != nil { - logger.V(logutil.DEFAULT).Error(err, "Failed to refresh metrics") - } - } - } - }() - - // Periodically flush prometheus metrics for inference pool - go func() { - for { - select { - case <-ctx.Done(): - logger.V(logutil.DEFAULT).Info("Shutting down prometheus metrics thread") - return - default: - time.Sleep(refreshPrometheusMetricsInterval) - p.flushPrometheusMetricsOnce(logger) - } - } - }() - - // Periodically print out the pods and metrics for DEBUGGING. 
- if logger := logger.V(logutil.DEBUG); logger.Enabled() { - go func() { - for { - select { - case <-ctx.Done(): - logger.V(logutil.DEFAULT).Info("Shutting down metrics logger thread") - return - default: - time.Sleep(5 * time.Second) - logger.Info("Current Pods and metrics gathered", "metrics", p.datastore.PodGetAll()) - } - } - }() - } - - return nil -} - -func (p *Provider) refreshMetricsOnce(logger logr.Logger) error { - loggerTrace := logger.V(logutil.TRACE) - pool, _ := p.datastore.PoolGet() - if pool == nil { - loggerTrace.Info("No inference pool or not initialized") - return nil - } - ctx, cancel := context.WithTimeout(context.Background(), fetchMetricsTimeout) - defer cancel() - start := time.Now() - defer func() { - d := time.Since(start) - // TODO: add a metric instead of logging - loggerTrace.Info("Metrics refreshed", "duration", d) - }() - - var wg sync.WaitGroup - errCh := make(chan error) - processOnePod := func(key, value any) bool { - loggerTrace.Info("Pod and metric being processed", "pod", key, "metric", value) - existing := value.(*datastore.PodMetrics) - wg.Add(1) - go func() { - defer wg.Done() - updated, err := p.pmc.FetchMetrics(ctx, existing, pool.Spec.TargetPortNumber) - if err != nil { - errCh <- fmt.Errorf("failed to parse metrics from %s: %v", existing.NamespacedName, err) - return - } - p.datastore.PodUpdateMetricsIfExist(updated.NamespacedName, &updated.Metrics) - loggerTrace.Info("Updated metrics for pod", "pod", updated.NamespacedName, "metrics", updated.Metrics) - }() - return true - } - p.datastore.PodRange(processOnePod) - - // Wait for metric collection for all pods to complete and close the error channel in a - // goroutine so this is unblocking, allowing the code to proceed to the error collection code - // below. - // Note we couldn't use a buffered error channel with a size because the size of the podMetrics - // sync.Map is unknown beforehand. - go func() { - wg.Wait() - close(errCh) - }() - - var errs error - for err := range errCh { - errs = multierr.Append(errs, err) - } - return errs -} - -func (p *Provider) flushPrometheusMetricsOnce(logger logr.Logger) { - pool, _ := p.datastore.PoolGet() - if pool == nil { - // No inference pool or not initialize. - return - } - - var kvCacheTotal float64 - var queueTotal int - - podMetrics := p.datastore.PodGetAll() - logger.V(logutil.VERBOSE).Info("Flushing Prometheus Metrics", "ReadyPods", len(podMetrics)) - if len(podMetrics) == 0 { - return - } - - for _, pod := range podMetrics { - kvCacheTotal += pod.KVCacheUsagePercent - queueTotal += pod.WaitingQueueSize - } - - podTotalCount := len(podMetrics) - metrics.RecordInferencePoolAvgKVCache(pool.Name, kvCacheTotal/float64(podTotalCount)) - metrics.RecordInferencePoolAvgQueueSize(pool.Name, float64(queueTotal/podTotalCount)) -} diff --git a/pkg/epp/backend/provider_test.go b/pkg/epp/backend/provider_test.go deleted file mode 100644 index 12994723..00000000 --- a/pkg/epp/backend/provider_test.go +++ /dev/null @@ -1,151 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package backend - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/stretchr/testify/assert" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" -) - -var ( - pod1 = &datastore.PodMetrics{ - Pod: datastore.Pod{ - NamespacedName: types.NamespacedName{ - Name: "pod1", - }, - }, - } - pod1WithMetrics = &datastore.PodMetrics{ - Pod: pod1.Pod, - Metrics: datastore.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - } - pod2 = &datastore.PodMetrics{ - Pod: datastore.Pod{ - NamespacedName: types.NamespacedName{ - Name: "pod2", - }, - }, - } - pod2WithMetrics = &datastore.PodMetrics{ - Pod: pod2.Pod, - Metrics: datastore.Metrics{ - WaitingQueueSize: 1, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo1": 1, - "bar1": 1, - }, - }, - } - - inferencePool = &v1alpha2.InferencePool{ - Spec: v1alpha2.InferencePoolSpec{ - TargetPortNumber: 8000, - }, - } -) - -func TestProvider(t *testing.T) { - tests := []struct { - name string - pmc PodMetricsClient - storePods []*datastore.PodMetrics - want []*datastore.PodMetrics - }{ - { - name: "Probing metrics success", - pmc: &FakePodMetricsClient{ - Res: map[types.NamespacedName]*datastore.PodMetrics{ - pod1.NamespacedName: pod1WithMetrics, - pod2.NamespacedName: pod2WithMetrics, - }, - }, - storePods: []*datastore.PodMetrics{pod1, pod2}, - want: []*datastore.PodMetrics{pod1WithMetrics, pod2WithMetrics}, - }, - { - name: "Only pods in the datastore are probed", - pmc: &FakePodMetricsClient{ - Res: map[types.NamespacedName]*datastore.PodMetrics{ - pod1.NamespacedName: pod1WithMetrics, - pod2.NamespacedName: pod2WithMetrics, - }, - }, - storePods: []*datastore.PodMetrics{pod1}, - want: []*datastore.PodMetrics{pod1WithMetrics}, - }, - { - name: "Probing metrics error", - pmc: &FakePodMetricsClient{ - Err: map[types.NamespacedName]error{ - pod2.NamespacedName: errors.New("injected error"), - }, - Res: map[types.NamespacedName]*datastore.PodMetrics{ - pod1.NamespacedName: pod1WithMetrics, - }, - }, - storePods: []*datastore.PodMetrics{pod1, pod2}, - want: []*datastore.PodMetrics{ - pod1WithMetrics, - // Failed to fetch pod2 metrics so it remains the default values. 
- { - Pod: datastore.Pod{NamespacedName: pod2.NamespacedName}, - Metrics: datastore.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0, - MaxActiveModels: 0, - }, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ds := datastore.NewFakeDatastore(test.storePods, nil, inferencePool) - p := NewProvider(test.pmc, ds) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - _ = p.Init(ctx, time.Millisecond, time.Millisecond) - assert.EventuallyWithT(t, func(t *assert.CollectT) { - metrics := ds.PodGetAll() - diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(func(a, b *datastore.PodMetrics) bool { - return a.String() < b.String() - })) - assert.Equal(t, "", diff, "Unexpected diff (+got/-want)") - }, 5*time.Second, time.Millisecond) - }) - } -} diff --git a/pkg/epp/backend/vllm/metrics.go b/pkg/epp/backend/vllm/metrics.go index 5b36b930..f83326eb 100644 --- a/pkg/epp/backend/vllm/metrics.go +++ b/pkg/epp/backend/vllm/metrics.go @@ -30,7 +30,7 @@ import ( "github.com/prometheus/common/expfmt" "go.uber.org/multierr" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -57,15 +57,16 @@ type PodMetricsClientImpl struct{} // FetchMetrics fetches metrics from a given pod. func (p *PodMetricsClientImpl) FetchMetrics( ctx context.Context, - existing *datastore.PodMetrics, + pod *metrics.Pod, + existing *metrics.Metrics, port int32, -) (*datastore.PodMetrics, error) { +) (*metrics.Metrics, error) { logger := log.FromContext(ctx) loggerDefault := logger.V(logutil.DEFAULT) // Currently the metrics endpoint is hard-coded, which works with vLLM. // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16): Consume this from InferencePool config. 
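+	// For example, an address of "10.0.0.1" and target port 8000 yield "http://10.0.0.1:8000/metrics" (illustrative values).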
- url := "http://" + existing.Address + ":" + strconv.Itoa(int(port)) + "/metrics" + url := "http://" + pod.Address + ":" + strconv.Itoa(int(port)) + "/metrics" req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -74,16 +75,16 @@ func (p *PodMetricsClientImpl) FetchMetrics( } resp, err := http.DefaultClient.Do(req) if err != nil { - loggerDefault.Error(err, "Failed to fetch metrics", "pod", existing.NamespacedName) - return nil, fmt.Errorf("failed to fetch metrics from %s: %w", existing.NamespacedName, err) + loggerDefault.Error(err, "Failed to fetch metrics", "pod", pod.NamespacedName) + return nil, fmt.Errorf("failed to fetch metrics from %s: %w", pod.NamespacedName, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - loggerDefault.Error(nil, "Unexpected status code returned", "pod", existing.NamespacedName, "statusCode", resp.StatusCode) - return nil, fmt.Errorf("unexpected status code from %s: %v", existing.NamespacedName, resp.StatusCode) + loggerDefault.Error(nil, "Unexpected status code returned", "pod", pod.NamespacedName, "statusCode", resp.StatusCode) + return nil, fmt.Errorf("unexpected status code from %s: %v", pod.NamespacedName, resp.StatusCode) } parser := expfmt.TextParser{} @@ -100,8 +101,8 @@ func (p *PodMetricsClientImpl) FetchMetrics( func promToPodMetrics( logger logr.Logger, metricFamilies map[string]*dto.MetricFamily, - existing *datastore.PodMetrics, -) (*datastore.PodMetrics, error) { + existing *metrics.Metrics, +) (*metrics.Metrics, error) { var errs error updated := existing.Clone() runningQueueSize, err := getLatestMetric(logger, metricFamilies, RunningQueueSizeMetricName) diff --git a/pkg/epp/backend/vllm/metrics_test.go b/pkg/epp/backend/vllm/metrics_test.go index 12aac1a1..5555bd26 100644 --- a/pkg/epp/backend/vllm/metrics_test.go +++ b/pkg/epp/backend/vllm/metrics_test.go @@ -23,7 +23,7 @@ import ( dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -31,11 +31,11 @@ func TestPromToPodMetrics(t *testing.T) { logger := logutil.NewTestLogger() testCases := []struct { - name string - metricFamilies map[string]*dto.MetricFamily - expectedMetrics *datastore.Metrics - expectedErr error - initialPodMetrics *datastore.PodMetrics + name string + metricFamilies map[string]*dto.MetricFamily + initialMetrics *metrics.Metrics + expectedMetrics *metrics.Metrics + expectedErr error }{ { name: "all metrics available", @@ -123,7 +123,7 @@ func TestPromToPodMetrics(t *testing.T) { }, }, }, - expectedMetrics: &datastore.Metrics{ + expectedMetrics: &metrics.Metrics{ RunningQueueSize: 15, WaitingQueueSize: 25, KVCacheUsagePercent: 0.9, @@ -133,8 +133,8 @@ func TestPromToPodMetrics(t *testing.T) { }, MaxActiveModels: 2, }, - initialPodMetrics: &datastore.PodMetrics{}, - expectedErr: nil, + initialMetrics: &metrics.Metrics{}, + expectedErr: nil, }, { name: "invalid max lora", @@ -222,7 +222,7 @@ func TestPromToPodMetrics(t *testing.T) { }, }, }, - expectedMetrics: &datastore.Metrics{ + expectedMetrics: &metrics.Metrics{ RunningQueueSize: 15, WaitingQueueSize: 25, KVCacheUsagePercent: 0.9, @@ -232,18 +232,18 @@ func TestPromToPodMetrics(t *testing.T) { }, MaxActiveModels: 0, }, - initialPodMetrics: &datastore.PodMetrics{}, - 
expectedErr: errors.New("strconv.Atoi: parsing '2a': invalid syntax"), + initialMetrics: &metrics.Metrics{}, + expectedErr: errors.New("strconv.Atoi: parsing '2a': invalid syntax"), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - updated, err := promToPodMetrics(logger, tc.metricFamilies, tc.initialPodMetrics) + updated, err := promToPodMetrics(logger, tc.metricFamilies, tc.initialMetrics) if tc.expectedErr != nil { assert.Error(t, err) } else { assert.NoError(t, err) - assert.Equal(t, tc.expectedMetrics, &updated.Metrics) + assert.Equal(t, tc.expectedMetrics, updated) } }) } diff --git a/pkg/epp/controller/inferencemodel_reconciler_test.go b/pkg/epp/controller/inferencemodel_reconciler_test.go index 2ac5bb1e..cd1ff1fb 100644 --- a/pkg/epp/controller/inferencemodel_reconciler_test.go +++ b/pkg/epp/controller/inferencemodel_reconciler_test.go @@ -19,6 +19,7 @@ package controller import ( "context" "testing" + "time" "github.com/google/go-cmp/cmp" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -29,6 +30,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -189,12 +191,16 @@ func TestInferenceModelReconciler(t *testing.T) { WithObjects(initObjs...). WithIndex(&v1alpha2.InferenceModel{}, datastore.ModelNameIndexKey, indexInferenceModelsByModelName). Build() - - datastore := datastore.NewFakeDatastore(nil, test.modelsInStore, pool) + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := datastore.NewDatastore(t.Context(), pmf) + for _, m := range test.modelsInStore { + ds.ModelSetIfOlder(m) + } + ds.PoolSet(pool) reconciler := &InferenceModelReconciler{ Client: fakeClient, Record: record.NewFakeRecorder(10), - Datastore: datastore, + Datastore: ds, PoolNamespacedName: types.NamespacedName{Name: pool.Name, Namespace: pool.Namespace}, } if test.incomingReq == nil { @@ -211,11 +217,11 @@ func TestInferenceModelReconciler(t *testing.T) { t.Errorf("Unexpected result diff (+got/-want): %s", diff) } - if len(test.wantModels) != len(datastore.ModelGetAll()) { - t.Errorf("Unexpected; want: %d, got:%d", len(test.wantModels), len(datastore.ModelGetAll())) + if len(test.wantModels) != len(ds.ModelGetAll()) { + t.Errorf("Unexpected; want: %d, got:%d", len(test.wantModels), len(ds.ModelGetAll())) } - if diff := diffStore(datastore, diffStoreParams{wantPool: pool, wantModels: test.wantModels}); diff != "" { + if diff := diffStore(ds, diffStoreParams{wantPool: pool, wantModels: test.wantModels}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } diff --git a/pkg/epp/controller/inferencepool_reconciler.go b/pkg/epp/controller/inferencepool_reconciler.go index 880aec8c..c92d4ecc 100644 --- a/pkg/epp/controller/inferencepool_reconciler.go +++ b/pkg/epp/controller/inferencepool_reconciler.go @@ -80,7 +80,7 @@ func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool * // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need // to resync the whole pool: remove pods in the store that don't match the new selector and add // the ones that may have existed already to the store. 
- c.Datastore.PodResyncAll(ctx, c.Client) + c.Datastore.PodResyncAll(ctx, c.Client, newPool) } } diff --git a/pkg/epp/controller/inferencepool_reconciler_test.go b/pkg/epp/controller/inferencepool_reconciler_test.go index f35b8dc0..27c4238e 100644 --- a/pkg/epp/controller/inferencepool_reconciler_test.go +++ b/pkg/epp/controller/inferencepool_reconciler_test.go @@ -19,6 +19,7 @@ package controller import ( "context" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -30,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -92,7 +94,8 @@ func TestInferencePoolReconciler(t *testing.T) { req := ctrl.Request{NamespacedName: namespacedName} ctx := context.Background() - datastore := datastore.NewDatastore() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + datastore := datastore.NewDatastore(ctx, pmf) inferencePoolReconciler := &InferencePoolReconciler{PoolNamespacedName: namespacedName, Client: fakeClient, Datastore: datastore} // Step 1: Inception, only ready pods matching pool1 are added to the store. @@ -167,7 +170,7 @@ func diffStore(datastore datastore.Datastore, params diffStoreParams) string { } gotPods := []string{} for _, pm := range datastore.PodGetAll() { - gotPods = append(gotPods, pm.NamespacedName.Name) + gotPods = append(gotPods, pm.GetPod().NamespacedName.Name) } if diff := cmp.Diff(params.wantPods, gotPods, cmpopts.SortSlices(func(a, b string) bool { return a < b })); diff != "" { return "pods:" + diff diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index a6c897c2..046561e4 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -27,6 +27,7 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -39,7 +40,8 @@ type PodReconciler struct { func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { logger := log.FromContext(ctx) - if !c.Datastore.PoolHasSynced() { + pool, err := c.Datastore.PoolGet() + if err != nil { logger.V(logutil.TRACE).Info("Skipping reconciling Pod because the InferencePool is not available yet") // When the inferencePool is initialized it lists the appropriate pods and populates the datastore, so no need to requeue. 
return ctrl.Result{}, nil @@ -57,7 +59,7 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R return ctrl.Result{}, err } - c.updateDatastore(logger, pod) + c.updateDatastore(logger, pod, pool) return ctrl.Result{}, nil } @@ -67,13 +69,13 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(c) } -func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { +func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod, pool *v1alpha2.InferencePool) { namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} if !pod.DeletionTimestamp.IsZero() || !c.Datastore.PoolLabelsMatch(pod.Labels) || !podIsReady(pod) { logger.V(logutil.DEBUG).Info("Pod removed or not added", "name", namespacedName) c.Datastore.PodDelete(namespacedName) } else { - if c.Datastore.PodUpdateOrAddIfNotExist(pod) { + if c.Datastore.PodUpdateOrAddIfNotExist(pod, pool) { logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) } else { logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go index 7534ac0f..e4cb0b62 100644 --- a/pkg/epp/controller/pod_reconciler_test.go +++ b/pkg/epp/controller/pod_reconciler_test.go @@ -19,10 +19,12 @@ package controller import ( "context" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" @@ -30,129 +32,138 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) var ( - basePod1 = &datastore.PodMetrics{Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}, Address: "address-1"}} - basePod2 = &datastore.PodMetrics{Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}, Address: "address-2"}} - basePod3 = &datastore.PodMetrics{Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}, Address: "address-3"}} - basePod11 = &datastore.PodMetrics{Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}, Address: "address-11"}} + basePod1 = &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-1"}} + basePod2 = &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2"}, Status: corev1.PodStatus{PodIP: "address-2"}} + basePod3 = &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod3"}, Status: corev1.PodStatus{PodIP: "address-3"}} + basePod11 = &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-11"}} + pmc = &backendmetrics.FakePodMetricsClient{} + pmf = backendmetrics.NewPodMetricsFactory(pmc, time.Second) ) func TestPodReconciler(t *testing.T) { tests := []struct { - name string - datastore datastore.Datastore - incomingPod *corev1.Pod - wantPods []datastore.Pod - req *ctrl.Request + name string + pool *v1alpha2.InferencePool + existingPods []*corev1.Pod + incomingPod *corev1.Pod + wantPods []*corev1.Pod + req *ctrl.Request }{ { - name: "Add new pod", - 
datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ + name: "Add new pod", + existingPods: []*corev1.Pod{basePod1, basePod2}, + pool: &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ "some-key": "some-val", }, }, - }), - incomingPod: utiltest.MakePod(basePod3.NamespacedName.Name). + }, + incomingPod: utiltest.FromBase(basePod3). Labels(map[string]string{"some-key": "some-val"}). - IP(basePod3.Address). ReadyCondition().ObjRef(), - wantPods: []datastore.Pod{basePod1.Pod, basePod2.Pod, basePod3.Pod}, + wantPods: []*corev1.Pod{basePod1, basePod2, basePod3}, }, { - name: "Update pod1 address", - datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ + name: "Update pod1 address", + existingPods: []*corev1.Pod{basePod1, basePod2}, + pool: &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ "some-key": "some-val", }, }, - }), - incomingPod: utiltest.MakePod(basePod11.NamespacedName.Name). + }, + incomingPod: utiltest.FromBase(basePod11). Labels(map[string]string{"some-key": "some-val"}). - IP(basePod11.Address). ReadyCondition().ObjRef(), - wantPods: []datastore.Pod{basePod11.Pod, basePod2.Pod}, + wantPods: []*corev1.Pod{basePod11, basePod2}, }, { - name: "Delete pod with DeletionTimestamp", - datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ + name: "Delete pod with DeletionTimestamp", + existingPods: []*corev1.Pod{basePod1, basePod2}, + pool: &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ "some-key": "some-val", }, }, - }), - incomingPod: utiltest.MakePod("pod1"). + }, + incomingPod: utiltest.FromBase(basePod1). Labels(map[string]string{"some-key": "some-val"}). DeletionTimestamp(). ReadyCondition().ObjRef(), - wantPods: []datastore.Pod{basePod2.Pod}, + wantPods: []*corev1.Pod{basePod2}, }, { - name: "Delete notfound pod", - datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ + name: "Delete notfound pod", + existingPods: []*corev1.Pod{basePod1, basePod2}, + pool: &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ "some-key": "some-val", }, }, - }), + }, req: &ctrl.Request{NamespacedName: types.NamespacedName{Name: "pod1"}}, - wantPods: []datastore.Pod{basePod2.Pod}, + wantPods: []*corev1.Pod{basePod2}, }, { - name: "New pod, not ready, valid selector", - datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ + name: "New pod, not ready, valid selector", + existingPods: []*corev1.Pod{basePod1, basePod2}, + pool: &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ "some-key": "some-val", }, }, - }), - incomingPod: utiltest.MakePod("pod3"). + }, + incomingPod: utiltest.FromBase(basePod3). 
Labels(map[string]string{"some-key": "some-val"}).ObjRef(), - wantPods: []datastore.Pod{basePod1.Pod, basePod2.Pod}, + wantPods: []*corev1.Pod{basePod1, basePod2}, }, { - name: "Remove pod that does not match selector", - datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ + name: "Remove pod that does not match selector", + existingPods: []*corev1.Pod{basePod1, basePod2}, + pool: &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ "some-key": "some-val", }, }, - }), - incomingPod: utiltest.MakePod("pod1"). + }, + incomingPod: utiltest.FromBase(basePod1). Labels(map[string]string{"some-wrong-key": "some-val"}). ReadyCondition().ObjRef(), - wantPods: []datastore.Pod{basePod2.Pod}, + wantPods: []*corev1.Pod{basePod2}, }, { - name: "Remove pod that is not ready", - datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ + name: "Remove pod that is not ready", + existingPods: []*corev1.Pod{basePod1, basePod2}, + pool: &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ "some-key": "some-val", }, }, - }), - incomingPod: utiltest.MakePod("pod1"). + }, + incomingPod: utiltest.FromBase(basePod1). Labels(map[string]string{"some-wrong-key": "some-val"}). ReadyCondition().ObjRef(), - wantPods: []datastore.Pod{basePod2.Pod}, + wantPods: []*corev1.Pod{basePod2}, }, } for _, test := range tests { @@ -169,24 +180,28 @@ func TestPodReconciler(t *testing.T) { WithObjects(initialObjects...). Build() - podReconciler := &PodReconciler{Client: fakeClient, Datastore: test.datastore} - namespacedName := types.NamespacedName{Name: test.incomingPod.Name, Namespace: test.incomingPod.Namespace} + // Configure the initial state of the datastore. 
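+			// Each added pod also starts a background refresh loop backed by the shared FakePodMetricsClient.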
+ store := datastore.NewDatastore(t.Context(), pmf) + store.PoolSet(test.pool) + for _, pod := range test.existingPods { + store.PodUpdateOrAddIfNotExist(pod, pool) + } + + podReconciler := &PodReconciler{Client: fakeClient, Datastore: store} if test.req == nil { + namespacedName := types.NamespacedName{Name: test.incomingPod.Name, Namespace: test.incomingPod.Namespace} test.req = &ctrl.Request{NamespacedName: namespacedName} } if _, err := podReconciler.Reconcile(context.Background(), *test.req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - var gotPods []datastore.Pod - test.datastore.PodRange(func(k, v any) bool { - pod := v.(*datastore.PodMetrics) - if v != nil { - gotPods = append(gotPods, pod.Pod) - } - return true - }) - if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b datastore.Pod) bool { return a.NamespacedName.String() < b.NamespacedName.String() })) { + var gotPods []*corev1.Pod + for _, pm := range store.PodGetAll() { + pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}} + gotPods = append(gotPods, pod) + } + if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) { t.Errorf("got (%v) != want (%v);", gotPods, test.wantPods) } }) diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index c7050437..af31da42 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -30,6 +30,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -57,56 +58,40 @@ type Datastore interface { ModelGetAll() []*v1alpha2.InferenceModel // PodMetrics operations - PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool - PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool - PodGet(namespacedName types.NamespacedName) *PodMetrics + // PodGetAll returns all pods and metrics, including fresh and stale. + PodGetAll() []backendmetrics.PodMetrics + // PodList lists pods matching the given predicate. + PodList(func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool PodDelete(namespacedName types.NamespacedName) - PodResyncAll(ctx context.Context, ctrlClient client.Client) - PodGetAll() []*PodMetrics - PodDeleteAll() // This is only for testing. - PodRange(f func(key, value any) bool) + PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) // Clears the store state, happens when the pool gets deleted. Clear() } -func NewDatastore() Datastore { +func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFactory) *datastore { store := &datastore{ + parentCtx: parentCtx, poolAndModelsMu: sync.RWMutex{}, models: make(map[string]*v1alpha2.InferenceModel), pods: &sync.Map{}, - } - return store -} - -// Used for test only -func NewFakeDatastore(pods []*PodMetrics, models []*v1alpha2.InferenceModel, pool *v1alpha2.InferencePool) Datastore { - store := NewDatastore() - - for _, pod := range pods { - // Making a copy since in tests we may use the same global PodMetric across tests. 
- p := *pod - store.(*datastore).pods.Store(pod.NamespacedName, &p) - } - - for _, m := range models { - store.ModelSetIfOlder(m) - } - - if pool != nil { - store.(*datastore).pool = pool + pmf: pmf, } return store } type datastore struct { + // parentCtx controls the lifecycle of the background metrics goroutines that spawn up by the datastore. + parentCtx context.Context // poolAndModelsMu is used to synchronize access to pool and the models map. poolAndModelsMu sync.RWMutex pool *v1alpha2.InferencePool // key: InferenceModel.Spec.ModelName, value: *InferenceModel models map[string]*v1alpha2.InferenceModel - // key: types.NamespacedName, value: *PodMetrics + // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map + pmf *backendmetrics.PodMetricsFactory } func (ds *datastore) Clear() { @@ -227,68 +212,44 @@ func (ds *datastore) ModelGetAll() []*v1alpha2.InferenceModel { } // /// Pods/endpoints APIs /// -func (ds *datastore) PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool { - if val, ok := ds.pods.Load(namespacedName); ok { - existing := val.(*PodMetrics) - existing.Metrics = *m - return true - } - return false -} -func (ds *datastore) PodGet(namespacedName types.NamespacedName) *PodMetrics { - val, ok := ds.pods.Load(namespacedName) - if ok { - return val.(*PodMetrics) - } - return nil +func (ds *datastore) PodGetAll() []backendmetrics.PodMetrics { + return ds.PodList(func(backendmetrics.PodMetrics) bool { return true }) } -func (ds *datastore) PodGetAll() []*PodMetrics { - res := []*PodMetrics{} +func (ds *datastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics { + res := []backendmetrics.PodMetrics{} fn := func(k, v any) bool { - res = append(res, v.(*PodMetrics)) + pm := v.(backendmetrics.PodMetrics) + if predicate(pm) { + res = append(res, pm) + } return true } ds.pods.Range(fn) return res } -func (ds *datastore) PodRange(f func(key, value any) bool) { - ds.pods.Range(f) -} - -func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { - ds.pods.Delete(namespacedName) -} - -func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { - new := &PodMetrics{ - Pod: Pod{ - NamespacedName: types.NamespacedName{ - Name: pod.Name, - Namespace: pod.Namespace, - }, - Address: pod.Status.PodIP, - }, - Metrics: Metrics{ - ActiveModels: make(map[string]int), - }, +func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool { + namespacedName := types.NamespacedName{ + Name: pod.Name, + Namespace: pod.Namespace, } - existing, ok := ds.pods.Load(new.NamespacedName) + var pm backendmetrics.PodMetrics + existing, ok := ds.pods.Load(namespacedName) if !ok { - ds.pods.Store(new.NamespacedName, new) - return true + pm = ds.pmf.NewPodMetrics(ds.parentCtx, pod, ds) + ds.pods.Store(namespacedName, pm) + } else { + pm = existing.(backendmetrics.PodMetrics) } - // Update pod properties if anything changed. - existing.(*PodMetrics).Pod = new.Pod - return false + pm.UpdatePod(pod) + return ok } -func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client) { - // Pool must exist to invoke this function. 
- pool, _ := ds.PoolGet() +func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) { + logger := log.FromContext(ctx) podList := &corev1.PodList{} if err := ctrlClient.List(ctx, podList, &client.ListOptions{ LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector), @@ -301,24 +262,34 @@ func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client) activePods := make(map[string]bool) for _, pod := range podList.Items { if podIsReady(&pod) { + namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} activePods[pod.Name] = true - ds.PodUpdateOrAddIfNotExist(&pod) + if ds.PodUpdateOrAddIfNotExist(&pod, pool) { + logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) + } else { + logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) + } } } // Remove pods that don't belong to the pool or not ready any more. deleteFn := func(k, v any) bool { - pm := v.(*PodMetrics) - if exist := activePods[pm.NamespacedName.Name]; !exist { - ds.pods.Delete(pm.NamespacedName) + pm := v.(backendmetrics.PodMetrics) + if exist := activePods[pm.GetPod().NamespacedName.Name]; !exist { + logger.V(logutil.VERBOSE).Info("Removing pod", "pod", pm.GetPod()) + ds.PodDelete(pm.GetPod().NamespacedName) } return true } ds.pods.Range(deleteFn) } -func (ds *datastore) PodDeleteAll() { - ds.pods.Clear() +func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { + v, ok := ds.pods.LoadAndDelete(namespacedName) + if ok { + pmr := v.(backendmetrics.PodMetrics) + pmr.StopRefreshLoop() + } } func selectorFromInferencePoolSelector(selector map[v1alpha2.LabelKey]v1alpha2.LabelValue) labels.Selector { diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 8fb269bc..f60a4cc9 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -17,13 +17,19 @@ limitations under the License. 
package datastore import ( + "context" + "errors" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -66,7 +72,8 @@ func TestPool(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - datastore := NewDatastore() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + datastore := NewDatastore(context.Background(), pmf) datastore.PoolSet(tt.inferencePool) gotPool, gotErr := datastore.PoolGet() if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" { @@ -197,7 +204,12 @@ func TestModel(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ds := NewFakeDatastore(nil, test.existingModels, nil) + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := NewDatastore(t.Context(), pmf) + for _, m := range test.existingModels { + ds.ModelSetIfOlder(m) + } + gotOpResult := test.op(ds) if gotOpResult != test.wantOpResult { t.Errorf("Unexpected operation result, want: %v, got: %v", test.wantOpResult, gotOpResult) @@ -317,3 +329,119 @@ func TestRandomWeightedDraw(t *testing.T) { func pointer(v int32) *int32 { return &v } + +var ( + pod1 = &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + }, + } + pod1Metrics = &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + } + pod2 = &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod2", + }, + } + pod2Metrics = &backendmetrics.Metrics{ + WaitingQueueSize: 1, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo1": 1, + "bar1": 1, + }, + } + pod1NamespacedName = types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace} + pod2NamespacedName = types.NamespacedName{Name: pod2.Name, Namespace: pod2.Namespace} + inferencePool = &v1alpha2.InferencePool{ + Spec: v1alpha2.InferencePoolSpec{ + TargetPortNumber: 8000, + }, + } +) + +func TestMetrics(t *testing.T) { + tests := []struct { + name string + pmc backendmetrics.PodMetricsClient + storePods []*corev1.Pod + want []*backendmetrics.Metrics + }{ + { + name: "Probing metrics success", + pmc: &backendmetrics.FakePodMetricsClient{ + Res: map[types.NamespacedName]*backendmetrics.Metrics{ + pod1NamespacedName: pod1Metrics, + pod2NamespacedName: pod2Metrics, + }, + }, + storePods: []*corev1.Pod{pod1, pod2}, + want: []*backendmetrics.Metrics{pod1Metrics, pod2Metrics}, + }, + { + name: "Only pods in are probed", + pmc: &backendmetrics.FakePodMetricsClient{ + Res: map[types.NamespacedName]*backendmetrics.Metrics{ + pod1NamespacedName: pod1Metrics, + pod2NamespacedName: pod2Metrics, + }, + }, + storePods: []*corev1.Pod{pod1}, + want: []*backendmetrics.Metrics{pod1Metrics}, + }, + { + name: "Probing metrics error", + pmc: &backendmetrics.FakePodMetricsClient{ + Err: map[types.NamespacedName]error{ + pod2NamespacedName: errors.New("injected error"), + }, + Res: map[types.NamespacedName]*backendmetrics.Metrics{ + 
pod1NamespacedName: pod1Metrics, + }, + }, + storePods: []*corev1.Pod{pod1, pod2}, + want: []*backendmetrics.Metrics{ + pod1Metrics, + // Failed to fetch pod2 metrics so it remains the default values. + { + ActiveModels: map[string]int{}, + WaitingQueueSize: 0, + KVCacheUsagePercent: 0, + MaxActiveModels: 0, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pmf := backendmetrics.NewPodMetricsFactory(test.pmc, time.Millisecond) + ds := NewDatastore(ctx, pmf) + ds.PoolSet(inferencePool) + for _, pod := range test.storePods { + ds.PodUpdateOrAddIfNotExist(pod, inferencePool) + } + assert.EventuallyWithT(t, func(t *assert.CollectT) { + got := ds.PodGetAll() + metrics := []*backendmetrics.Metrics{} + for _, one := range got { + metrics = append(metrics, one.GetMetrics()) + } + diff := cmp.Diff(test.want, metrics, cmpopts.IgnoreFields(backendmetrics.Metrics{}, "UpdateTime"), cmpopts.SortSlices(func(a, b *backendmetrics.Metrics) bool { + return a.String() < b.String() + })) + assert.Equal(t, "", diff, "Unexpected diff (+got/-want)") + }, 5*time.Second, time.Millisecond) + }) + } +} diff --git a/pkg/epp/datastore/types.go b/pkg/epp/datastore/types.go deleted file mode 100644 index 8cfcf1d1..00000000 --- a/pkg/epp/datastore/types.go +++ /dev/null @@ -1,71 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package datastore is a library to interact with backend model servers such as probing metrics. -package datastore - -import ( - "fmt" - - "k8s.io/apimachinery/pkg/types" -) - -type Pod struct { - NamespacedName types.NamespacedName - Address string -} - -type Metrics struct { - // ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU. - ActiveModels map[string]int - // MaxActiveModels is the maximum number of models that can be loaded to GPU. 
- MaxActiveModels int - RunningQueueSize int - WaitingQueueSize int - KVCacheUsagePercent float64 - KvCacheMaxTokenCapacity int -} - -type PodMetrics struct { - Pod - Metrics -} - -func (pm *PodMetrics) String() string { - return fmt.Sprintf("Pod: %+v; Address: %+v; Metrics: %+v", pm.NamespacedName, pm.Address, pm.Metrics) -} - -func (pm *PodMetrics) Clone() *PodMetrics { - cm := make(map[string]int, len(pm.ActiveModels)) - for k, v := range pm.ActiveModels { - cm[k] = v - } - clone := &PodMetrics{ - Pod: Pod{ - NamespacedName: pm.NamespacedName, - Address: pm.Address, - }, - Metrics: Metrics{ - ActiveModels: cm, - MaxActiveModels: pm.MaxActiveModels, - RunningQueueSize: pm.RunningQueueSize, - WaitingQueueSize: pm.WaitingQueueSize, - KVCacheUsagePercent: pm.KVCacheUsagePercent, - KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity, - }, - } - return clone -} diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 20271913..12afe4d7 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -94,10 +94,11 @@ func (s *Server) HandleRequestBody( loggerVerbose.Info("Updated request body marshalled", "body", string(requestBody)) } - targetPod, err := s.scheduler.Schedule(ctx, llmReq) + target, err := s.scheduler.Schedule(ctx, llmReq) if err != nil { return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} } + targetPod := target.GetPod() logger.V(logutil.DEFAULT).Info("Request handled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index bbdbe83e..be882fc7 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -26,6 +26,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" @@ -56,7 +57,7 @@ type Server struct { } type Scheduler interface { - Schedule(ctx context.Context, b *scheduling.LLMRequest) (targetPod datastore.PodMetrics, err error) + Schedule(ctx context.Context, b *scheduling.LLMRequest) (targetPod backendmetrics.PodMetrics, err error) } func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { diff --git a/pkg/epp/handlers/streamingserver.go b/pkg/epp/handlers/streamingserver.go index 821dd989..c8de7bb7 100644 --- a/pkg/epp/handlers/streamingserver.go +++ b/pkg/epp/handlers/streamingserver.go @@ -347,10 +347,11 @@ func (s *StreamingServer) HandleRequestBody( loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes)) } - targetPod, err := s.scheduler.Schedule(ctx, llmReq) + target, err := s.scheduler.Schedule(ctx, llmReq) if err != nil { return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} } + targetPod := target.GetPod() // Insert target endpoint to instruct Envoy to route requests to the specified target pod. 
// Attach the port number diff --git a/pkg/epp/scheduling/filter.go b/pkg/epp/scheduling/filter.go index d3c22673..cee683c5 100644 --- a/pkg/epp/scheduling/filter.go +++ b/pkg/epp/scheduling/filter.go @@ -23,13 +23,13 @@ import ( "time" "github.com/go-logr/logr" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) type Filter interface { Name() string - Filter(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) + Filter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) } // filter applies current filterFunc, and then recursively applies next filters depending success or @@ -59,7 +59,7 @@ func (f *filter) Name() string { return f.name } -func (f *filter) Filter(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) { +func (f *filter) Filter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { loggerTrace := logger.V(logutil.TRACE) loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods)) @@ -92,12 +92,12 @@ func (f *filter) Filter(logger logr.Logger, req *LLMRequest, pods []*datastore.P } // filterFunc filters a set of input pods to a subset. -type filterFunc func(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) +type filterFunc func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. func toFilterFunc(pp podPredicate) filterFunc { - return func(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) { - filtered := []*datastore.PodMetrics{} + return func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { + filtered := []backendmetrics.PodMetrics{} for _, pod := range pods { pass := pp(req, pod) if pass { @@ -118,30 +118,30 @@ func toFilterFunc(pp podPredicate) filterFunc { // the least one as it gives more choices for the next filter, which on aggregate gave better // results. // TODO: Compare this strategy with other strategies such as top K. 
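Since Filter, filterFunc, and podPredicate now operate on the backendmetrics.PodMetrics interface, custom filters read metrics through GetMetrics() instead of embedded fields. A hypothetical extra predicate, written in package scheduling alongside the code above and chained via toFilterFunc like the existing ones; the name and threshold are illustrative and not part of this change:

package scheduling

import (
	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
)

// lowKVCachePredicate keeps only pods whose KV cache utilization is below the given threshold.
// Illustrative only; it mirrors how the predicates in this file consume the new interface.
func lowKVCachePredicate(threshold float64) podPredicate {
	return func(_ *LLMRequest, pod backendmetrics.PodMetrics) bool {
		return pod.GetMetrics().KVCacheUsagePercent < threshold
	}
}

// Example wiring, following the existing pattern:
//   &filter{name: "low KV cache", filter: toFilterFunc(lowKVCachePredicate(0.8))}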
-func leastQueuingFilterFunc(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) { +func leastQueuingFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { min := math.MaxInt max := 0 - filtered := []*datastore.PodMetrics{} + filtered := []backendmetrics.PodMetrics{} for _, pod := range pods { - if pod.WaitingQueueSize <= min { - min = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize <= min { + min = pod.GetMetrics().WaitingQueueSize } - if pod.WaitingQueueSize >= max { - max = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize >= max { + max = pod.GetMetrics().WaitingQueueSize } } for _, pod := range pods { - if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) { + if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) { filtered = append(filtered, pod) } } return filtered, nil } -func lowQueueingPodPredicate(_ *LLMRequest, pod *datastore.PodMetrics) bool { - return pod.WaitingQueueSize < queueingThresholdLoRA +func lowQueueingPodPredicate(_ *LLMRequest, pod backendmetrics.PodMetrics) bool { + return pod.GetMetrics().WaitingQueueSize < queueingThresholdLoRA } // leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range @@ -150,22 +150,22 @@ func lowQueueingPodPredicate(_ *LLMRequest, pod *datastore.PodMetrics) bool { // should consider them all instead of the absolute minimum one. This worked better than picking the // least one as it gives more choices for the next filter, which on aggregate gave better results. // TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) { +func leastKVCacheFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { min := math.MaxFloat64 var max float64 = 0 - filtered := []*datastore.PodMetrics{} + filtered := []backendmetrics.PodMetrics{} for _, pod := range pods { - if pod.KVCacheUsagePercent <= min { - min = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent <= min { + min = pod.GetMetrics().KVCacheUsagePercent } - if pod.KVCacheUsagePercent >= max { - max = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent >= max { + max = pod.GetMetrics().KVCacheUsagePercent } } for _, pod := range pods { - if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { + if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { filtered = append(filtered, pod) } } @@ -173,16 +173,16 @@ func leastKVCacheFilterFunc(logger logr.Logger, req *LLMRequest, pods []*datasto } // podPredicate is a filter function to check whether a pod is desired. -type podPredicate func(req *LLMRequest, pod *datastore.PodMetrics) bool +type podPredicate func(req *LLMRequest, pod backendmetrics.PodMetrics) bool // We consider serving an adapter low cost it the adapter is active in the model server, or the // model server has room to load the adapter. The lowLoRACostPredicate ensures weak affinity by // spreading the load of a LoRA adapter across multiple pods, avoiding "pinning" all requests to // a single pod. 
This gave good performance in our initial benchmarking results in the scenario // where # of lora slots > # of lora adapters. -func lowLoRACostPredicate(req *LLMRequest, pod *datastore.PodMetrics) bool { - _, ok := pod.ActiveModels[req.ResolvedTargetModel] - return ok || len(pod.ActiveModels) < pod.MaxActiveModels +func lowLoRACostPredicate(req *LLMRequest, pod backendmetrics.PodMetrics) bool { + _, ok := pod.GetMetrics().ActiveModels[req.ResolvedTargetModel] + return ok || len(pod.GetMetrics().ActiveModels) < pod.GetMetrics().MaxActiveModels } // loRASoftAffinityPredicate implements a pod selection strategy that prioritizes pods @@ -201,18 +201,18 @@ func lowLoRACostPredicate(req *LLMRequest, pod *datastore.PodMetrics) bool { // Returns: // - Filtered slice of pod metrics based on affinity and availability // - Error if any issues occur during filtering -func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) { +func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { // Pre-allocate slices with estimated capacity - filtered_affinity := make([]*datastore.PodMetrics, 0, len(pods)) - filtered_available := make([]*datastore.PodMetrics, 0, len(pods)) + filtered_affinity := make([]backendmetrics.PodMetrics, 0, len(pods)) + filtered_available := make([]backendmetrics.PodMetrics, 0, len(pods)) // Categorize pods based on affinity and availability for _, pod := range pods { - if _, exists := pod.ActiveModels[req.ResolvedTargetModel]; exists { + if _, exists := pod.GetMetrics().ActiveModels[req.ResolvedTargetModel]; exists { filtered_affinity = append(filtered_affinity, pod) - } else if len(pod.ActiveModels) < pod.MaxActiveModels { + } else if len(pod.GetMetrics().ActiveModels) < pod.GetMetrics().MaxActiveModels { filtered_available = append(filtered_available, pod) } } @@ -237,12 +237,12 @@ func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []*datasto return filtered_available, nil } -func criticalRequestPredicate(req *LLMRequest, _ *datastore.PodMetrics) bool { +func criticalRequestPredicate(req *LLMRequest, _ backendmetrics.PodMetrics) bool { return req.Critical } func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate { - return func(req *LLMRequest, pod *datastore.PodMetrics) bool { - return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold + return func(req *LLMRequest, pod backendmetrics.PodMetrics) bool { + return pod.GetMetrics().WaitingQueueSize <= queueThreshold && pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold } } diff --git a/pkg/epp/scheduling/filter_test.go b/pkg/epp/scheduling/filter_test.go index f76cece9..62ffe7f2 100644 --- a/pkg/epp/scheduling/filter_test.go +++ b/pkg/epp/scheduling/filter_test.go @@ -23,7 +23,7 @@ import ( "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -33,14 +33,14 @@ func TestFilter(t *testing.T) { tests := []struct { name string req *LLMRequest - input []*datastore.PodMetrics - output []*datastore.PodMetrics + input []*backendmetrics.FakePodMetrics + output []*backendmetrics.FakePodMetrics err bool filter *filter }{ { 
name: "simple filter without successor, failure", - filter: &filter{filter: func(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) { + filter: &filter{filter: func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { return nil, errors.New("filter error") }}, err: true, @@ -55,10 +55,10 @@ func TestFilter(t *testing.T) { }, // pod2 will be picked because it has relatively low queue size, with the requested // model being active, and has low KV cache. - input: []*datastore.PodMetrics{ + input: []*backendmetrics.FakePodMetrics{ { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, MaxActiveModels: 2, @@ -69,8 +69,8 @@ func TestFilter(t *testing.T) { }, }, { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, MaxActiveModels: 2, @@ -81,8 +81,8 @@ func TestFilter(t *testing.T) { }, }, { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, MaxActiveModels: 2, @@ -92,10 +92,10 @@ func TestFilter(t *testing.T) { }, }, }, - output: []*datastore.PodMetrics{ + output: []*backendmetrics.FakePodMetrics{ { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, MaxActiveModels: 2, @@ -116,10 +116,10 @@ func TestFilter(t *testing.T) { Critical: false, }, // pod1 will be picked because it has capacity for the sheddable request. 
- input: []*datastore.PodMetrics{ + input: []*backendmetrics.FakePodMetrics{ { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, MaxActiveModels: 2, @@ -130,8 +130,8 @@ func TestFilter(t *testing.T) { }, }, { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, MaxActiveModels: 2, @@ -142,8 +142,8 @@ func TestFilter(t *testing.T) { }, }, { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, MaxActiveModels: 2, @@ -153,10 +153,10 @@ func TestFilter(t *testing.T) { }, }, }, - output: []*datastore.PodMetrics{ + output: []*backendmetrics.FakePodMetrics{ { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, MaxActiveModels: 2, @@ -178,10 +178,10 @@ func TestFilter(t *testing.T) { }, // All pods have higher KV cache thant the threshold, so the sheddable request will be // dropped. - input: []*datastore.PodMetrics{ + input: []*backendmetrics.FakePodMetrics{ { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.9, MaxActiveModels: 2, @@ -192,8 +192,8 @@ func TestFilter(t *testing.T) { }, }, { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.85, MaxActiveModels: 2, @@ -204,8 +204,8 @@ func TestFilter(t *testing.T) { }, }, { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.85, MaxActiveModels: 2, @@ -215,19 +215,19 @@ func TestFilter(t *testing.T) { }, }, }, - output: []*datastore.PodMetrics{}, + output: []*backendmetrics.FakePodMetrics{}, err: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.filter.Filter(logger, test.req, test.input) + got, err := test.filter.Filter(logger, test.req, toInterface(test.input)) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + if diff := cmp.Diff(test.output, toStruct(got)); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -241,44 +241,44 @@ func TestFilterFunc(t *testing.T) { name string f filterFunc req *LLMRequest - input []*datastore.PodMetrics - output []*datastore.PodMetrics + input []*backendmetrics.FakePodMetrics 
+ output []*backendmetrics.FakePodMetrics err bool }{ { name: "least queuing empty input", f: leastQueuingFilterFunc, - input: []*datastore.PodMetrics{}, - output: []*datastore.PodMetrics{}, + input: []*backendmetrics.FakePodMetrics{}, + output: []*backendmetrics.FakePodMetrics{}, }, { name: "least queuing", f: leastQueuingFilterFunc, - input: []*datastore.PodMetrics{ + input: []*backendmetrics.FakePodMetrics{ { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, }, { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, }, }, }, - output: []*datastore.PodMetrics{ + output: []*backendmetrics.FakePodMetrics{ { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, }, @@ -287,37 +287,37 @@ func TestFilterFunc(t *testing.T) { { name: "least kv cache empty input", f: leastKVCacheFilterFunc, - input: []*datastore.PodMetrics{}, - output: []*datastore.PodMetrics{}, + input: []*backendmetrics.FakePodMetrics{}, + output: []*backendmetrics.FakePodMetrics{}, }, { name: "least kv cache", f: leastKVCacheFilterFunc, - input: []*datastore.PodMetrics{ + input: []*backendmetrics.FakePodMetrics{ { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, }, { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 1.0, }, }, }, - output: []*datastore.PodMetrics{ + output: []*backendmetrics.FakePodMetrics{ { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, }, @@ -326,32 +326,32 @@ func TestFilterFunc(t *testing.T) { { name: "noQueueAndLessThanKVCacheThresholdPredicate", f: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)), - input: []*datastore.PodMetrics{ + input: []*backendmetrics.FakePodMetrics{ { // This pod should be returned. - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, }, }, { // Queue is non zero, despite low kv cache, should not return. - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.3, }, }, { // High kv cache despite zero queue, should not return - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 1.0, }, }, }, - output: []*datastore.PodMetrics{ + output: []*backendmetrics.FakePodMetrics{ { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, }, @@ -365,10 +365,10 @@ func TestFilterFunc(t *testing.T) { Model: "model", ResolvedTargetModel: "model", }, - input: []*datastore.PodMetrics{ + input: []*backendmetrics.FakePodMetrics{ // ActiveModels include input model, should be returned. { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{ "model": 1, @@ -377,7 +377,7 @@ func TestFilterFunc(t *testing.T) { }, // Input model is not active, however the server has room to load another adapter. 
{ - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{ "another-model": 1, @@ -386,7 +386,7 @@ func TestFilterFunc(t *testing.T) { }, // Input is not active, and the server has reached max active models. { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{ "foo": 1, @@ -395,9 +395,9 @@ func TestFilterFunc(t *testing.T) { }, }, }, - output: []*datastore.PodMetrics{ + output: []*backendmetrics.FakePodMetrics{ { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{ "model": 1, @@ -405,7 +405,7 @@ func TestFilterFunc(t *testing.T) { }, }, { - Metrics: datastore.Metrics{ + Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{ "another-model": 1, @@ -418,12 +418,12 @@ func TestFilterFunc(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.f(logger, test.req, test.input) + got, err := test.f(logger, test.req, toInterface(test.input)) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + if diff := cmp.Diff(test.output, toStruct(got)); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -449,10 +449,10 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { } // Test setup: One affinity pod and one available pod - pods := []*datastore.PodMetrics{ + pods := []*backendmetrics.FakePodMetrics{ { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "affinity-pod"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "affinity-pod"}}, + Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{ testAffinityModel: 1, @@ -460,8 +460,8 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, { - Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "available-pod"}}, - Metrics: datastore.Metrics{ + Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "available-pod"}}, + Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, ActiveModels: map[string]int{}, }, @@ -476,7 +476,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { // This test should work with whatever value is set there expectedAffinityPercent := loraAffinityThreshold * 100 for i := 0; i < numIterations; i++ { - result, err := loRASoftAffinityFilter(logger, req, pods) + result, err := loRASoftAffinityFilter(logger, req, toInterface(pods)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -487,7 +487,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { } // Identify if the returned pod is the affinity pod or available pod - if _, exists := result[0].ActiveModels[testAffinityModel]; exists { + if _, exists := result[0].GetMetrics().ActiveModels[testAffinityModel]; exists { affinityCount++ } else { availableCount++ @@ -519,3 +519,22 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { actualAvailablePercent, availableLowerBound, availableUpperBound) } } + +func toInterface(input []*backendmetrics.FakePodMetrics) []backendmetrics.PodMetrics { + output := []backendmetrics.PodMetrics{} + for _, i := range input { + output = append(output, i) + } + return output +} + +func toStruct(input []backendmetrics.PodMetrics) []*backendmetrics.FakePodMetrics { + if input == nil { + return nil + } + output := []*backendmetrics.FakePodMetrics{} + 
for _, i := range input { + output = append(output, i.(*backendmetrics.FakePodMetrics)) + } + return output +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index bdddd972..82410787 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -24,6 +24,7 @@ import ( "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -97,9 +98,9 @@ var ( // request to make room for critical requests. nextOnFailure: &filter{ name: "drop request", - filter: func(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) { + filter: func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { logger.V(logutil.DEFAULT).Info("Request dropped", "request", req) - return []*datastore.PodMetrics{}, errutil.Error{ + return []backendmetrics.PodMetrics{}, errutil.Error{ Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", } }, @@ -120,16 +121,16 @@ type Scheduler struct { } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod datastore.PodMetrics, err error) { +func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backendmetrics.PodMetrics, err error) { logger := log.FromContext(ctx).WithValues("request", req) podMetrics := s.datastore.PodGetAll() logger.V(logutil.VERBOSE).Info("Scheduling a request", "metrics", podMetrics) pods, err := s.filter.Filter(logger, req, podMetrics) if err != nil || len(pods) == 0 { - return datastore.PodMetrics{}, fmt.Errorf( + return nil, fmt.Errorf( "failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) } logger.V(logutil.VERBOSE).Info("Selecting a random pod from the candidates", "candidatePods", pods) i := rand.Intn(len(pods)) - return *pods[i], nil + return pods[i], nil } diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 5b8269c1..a6c9f1d3 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -31,7 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" tlsutil "sigs.k8s.io/gateway-api-inference-extension/internal/tls" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/controller" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" @@ -45,13 +45,15 @@ type ExtProcServerRunner struct { DestinationEndpointHintKey string PoolName string PoolNamespace string - RefreshMetricsInterval time.Duration - RefreshPrometheusMetricsInterval time.Duration Datastore datastore.Datastore - Provider *backend.Provider SecureServing bool CertPath string UseStreaming bool + RefreshPrometheusMetricsInterval time.Duration + + // This should only be used in tests. We won't need this once we don't inject metrics in the tests. 
+ // TODO:(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/432) Cleanup + TestPodMetricsClient *backendmetrics.FakePodMetricsClient } // Default values for CLI flags in main @@ -73,8 +75,6 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { DestinationEndpointHintMetadataNamespace: DefaultDestinationEndpointHintMetadataNamespace, PoolName: DefaultPoolName, PoolNamespace: DefaultPoolNamespace, - RefreshMetricsInterval: DefaultRefreshMetricsInterval, - RefreshPrometheusMetricsInterval: DefaultRefreshPrometheusMetricsInterval, SecureServing: DefaultSecureServing, // Datastore can be assigned later. } @@ -121,12 +121,7 @@ func (r *ExtProcServerRunner) SetupWithManager(ctx context.Context, mgr ctrl.Man // The runnable implements LeaderElectionRunnable with leader election disabled. func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { return runnable.NoLeaderElection(manager.RunnableFunc(func(ctx context.Context) error { - // Initialize backend provider - if err := r.Provider.Init(ctx, r.RefreshMetricsInterval, r.RefreshPrometheusMetricsInterval); err != nil { - logger.Error(err, "Failed to initialize backend provider") - return err - } - + backendmetrics.StartMetricsLogger(ctx, r.Datastore, r.RefreshPrometheusMetricsInterval) var srv *grpc.Server if r.SecureServing { var cert tls.Certificate diff --git a/pkg/epp/test/benchmark/benchmark.go b/pkg/epp/test/benchmark/benchmark.go deleted file mode 100644 index 67783480..00000000 --- a/pkg/epp/test/benchmark/benchmark.go +++ /dev/null @@ -1,145 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "context" - "flag" - "fmt" - "os" - "time" - - "github.com/bojand/ghz/printer" - "github.com/bojand/ghz/runner" - "github.com/go-logr/logr" - "github.com/jhump/protoreflect/desc" - uberzap "go.uber.org/zap" - "google.golang.org/protobuf/proto" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/controller-runtime/pkg/log/zap" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/test" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -var ( - svrAddr = flag.String("server_address", fmt.Sprintf("localhost:%d", runserver.DefaultGrpcPort), "Address of the ext proc server") - totalRequests = flag.Int("total_requests", 100000, "number of requests to be sent for load test") - // Flags when running a local ext proc server. 
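With the Provider removed, tests no longer push metrics into the datastore directly; they hand a FakePodMetricsClient to the factory (the TestPodMetricsClient hook above) and change its responses while the per-pod refresh loops are running. A condensed sketch of that wiring, assuming NewDatastore returns the exported Datastore interface:

package example

import (
	"context"
	"time"

	"k8s.io/apimachinery/pkg/types"
	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
)

// newTestDatastore builds a datastore whose refresh loops scrape the fake client every 10ms.
func newTestDatastore(ctx context.Context) (datastore.Datastore, *backendmetrics.FakePodMetricsClient) {
	fake := &backendmetrics.FakePodMetricsClient{}
	pmf := backendmetrics.NewPodMetricsFactory(fake, 10*time.Millisecond)
	return datastore.NewDatastore(ctx, pmf), fake
}

// seedMetrics swaps in per-pod responses; the running refresh loops pick them up on the
// next tick, which is why the tests below poll with assert.EventuallyWithT.
func seedMetrics(fake *backendmetrics.FakePodMetricsClient) {
	fake.SetRes(map[types.NamespacedName]*backendmetrics.Metrics{
		{Name: "pod-0", Namespace: "default"}: {WaitingQueueSize: 3, KVCacheUsagePercent: 0.2},
	})
}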
- numFakePods = flag.Int("num_fake_pods", 200, "number of fake pods when running a local ext proc server") - numModelsPerPod = flag.Int("num_models_per_pod", 5, "number of fake models per pod when running a local ext proc server") - localServer = flag.Bool("local_server", true, "whether to start a local ext proc server") - refreshPodsInterval = flag.Duration("refreshPodsInterval", 10*time.Second, "interval to refresh pods") - refreshMetricsInterval = flag.Duration("refreshMetricsInterval", 50*time.Millisecond, "interval to refresh metrics via polling pods") - refreshPrometheusMetricsInterval = flag.Duration("refreshPrometheusMetricsInterval", 5*time.Second, "interval to flush prometheus metrics") -) - -const ( - port = runserver.DefaultGrpcPort -) - -func main() { - if err := run(); err != nil { - os.Exit(1) - } -} - -func run() error { - opts := zap.Options{ - Development: true, - } - opts.BindFlags(flag.CommandLine) - flag.Parse() - logger := zap.New(zap.UseFlagOptions(&opts), zap.RawZapOpts(uberzap.AddCaller())) - ctx := log.IntoContext(context.Background(), logger) - - if *localServer { - test.StartExtProc(ctx, port, *refreshPodsInterval, *refreshMetricsInterval, *refreshPrometheusMetricsInterval, fakePods(), fakeModels()) - time.Sleep(time.Second) // wait until server is up - logger.Info("Server started") - } - - report, err := runner.Run( - "envoy.service.ext_proc.v3.ExternalProcessor.Process", - *svrAddr, - runner.WithInsecure(true), - runner.WithBinaryDataFunc(generateRequestFunc(logger)), - runner.WithTotalRequests(uint(*totalRequests)), - ) - if err != nil { - logger.Error(err, "Runner failed") - return err - } - - printer := printer.ReportPrinter{ - Out: os.Stdout, - Report: report, - } - - printer.Print("summary") - return nil -} - -func generateRequestFunc(logger logr.Logger) func(mtd *desc.MethodDescriptor, callData *runner.CallData) []byte { - return func(mtd *desc.MethodDescriptor, callData *runner.CallData) []byte { - numModels := *numFakePods * (*numModelsPerPod) - req := test.GenerateRequest(logger, "hello", modelName(int(callData.RequestNumber)%numModels)) - data, err := proto.Marshal(req) - if err != nil { - logutil.Fatal(logger, err, "Failed to marshal request", "request", req) - } - return data - } -} - -func fakeModels() map[string]*v1alpha2.InferenceModel { - models := map[string]*v1alpha2.InferenceModel{} - for i := range *numFakePods { - for j := range *numModelsPerPod { - m := modelName(i*(*numModelsPerPod) + j) - models[m] = &v1alpha2.InferenceModel{Spec: v1alpha2.InferenceModelSpec{ModelName: m}} - } - } - - return models -} - -func fakePods() []*datastore.PodMetrics { - pms := make([]*datastore.PodMetrics, 0, *numFakePods) - for i := 0; i < *numFakePods; i++ { - pms = append(pms, test.FakePodMetrics(i, fakeMetrics(i))) - } - - return pms -} - -// fakeMetrics adds numModelsPerPod number of adapters to the pod metrics. -func fakeMetrics(podNumber int) datastore.Metrics { - metrics := datastore.Metrics{ - ActiveModels: make(map[string]int), - } - for i := 0; i < *numModelsPerPod; i++ { - metrics.ActiveModels[modelName(podNumber*(*numModelsPerPod)+i)] = 0 - } - return metrics -} - -func modelName(i int) string { - return fmt.Sprintf("adapter-%v", i) -} diff --git a/pkg/epp/test/utils.go b/pkg/epp/test/utils.go deleted file mode 100644 index b18b0919..00000000 --- a/pkg/epp/test/utils.go +++ /dev/null @@ -1,126 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package test - -import ( - "context" - "encoding/json" - "fmt" - "net" - "time" - - extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/go-logr/logr" - "google.golang.org/grpc" - "google.golang.org/grpc/reflection" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" - utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" -) - -func StartExtProc( - ctx context.Context, - port int, - refreshPodsInterval, refreshMetricsInterval, refreshPrometheusMetricsInterval time.Duration, - pods []*datastore.PodMetrics, - models map[string]*v1alpha2.InferenceModel, -) *grpc.Server { - logger := log.FromContext(ctx) - pms := make(map[types.NamespacedName]*datastore.PodMetrics) - for _, pod := range pods { - pms[pod.NamespacedName] = pod - } - pmc := &backend.FakePodMetricsClient{Res: pms} - datastore := datastore.NewDatastore() - for _, m := range models { - datastore.ModelSetIfOlder(m) - } - for _, pm := range pods { - pod := utiltesting.MakePod(pm.NamespacedName.Name). - Namespace(pm.NamespacedName.Namespace). - ReadyCondition(). - IP(pm.Address). - ObjRef() - datastore.PodUpdateOrAddIfNotExist(pod) - datastore.PodUpdateMetricsIfExist(pm.NamespacedName, &pm.Metrics) - } - pp := backend.NewProvider(pmc, datastore) - if err := pp.Init(ctx, refreshMetricsInterval, refreshPrometheusMetricsInterval); err != nil { - logutil.Fatal(logger, err, "Failed to initialize") - } - return startExtProc(logger, port, datastore) -} - -// startExtProc starts an extProc server with fake pods. 
-func startExtProc(logger logr.Logger, port int, datastore datastore.Datastore) *grpc.Server { - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) - if err != nil { - logutil.Fatal(logger, err, "Failed to listen", "port", port) - } - - s := grpc.NewServer() - - extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(scheduling.NewScheduler(datastore), "", "target-pod", datastore)) - - logger.Info("gRPC server starting", "port", port) - reflection.Register(s) - go func() { - err := s.Serve(lis) - if err != nil { - logutil.Fatal(logger, err, "Ext-proc failed with the err") - } - }() - return s -} - -func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.ProcessingRequest { - j := map[string]interface{}{ - "model": model, - "prompt": prompt, - "max_tokens": 100, - "temperature": 0, - } - - llmReq, err := json.Marshal(j) - if err != nil { - logutil.Fatal(logger, err, "Failed to unmarshal LLM request") - } - req := &extProcPb.ProcessingRequest{ - Request: &extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: llmReq}, - }, - } - return req -} - -func FakePodMetrics(index int, metrics datastore.Metrics) *datastore.PodMetrics { - address := fmt.Sprintf("192.168.1.%d", index+1) - pod := datastore.PodMetrics{ - Pod: datastore.Pod{ - NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index), Namespace: "default"}, - Address: address, - }, - Metrics: metrics, - } - return &pod -} diff --git a/pkg/epp/util/testing/request.go b/pkg/epp/util/testing/request.go new file mode 100644 index 00000000..fe9a0d08 --- /dev/null +++ b/pkg/epp/util/testing/request.go @@ -0,0 +1,45 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "encoding/json" + + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/go-logr/logr" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.ProcessingRequest { + j := map[string]interface{}{ + "model": model, + "prompt": prompt, + "max_tokens": 100, + "temperature": 0, + } + + llmReq, err := json.Marshal(j) + if err != nil { + logutil.Fatal(logger, err, "Failed to unmarshal LLM request") + } + req := &extProcPb.ProcessingRequest{ + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: llmReq}, + }, + } + return req +} diff --git a/pkg/epp/util/testing/wrappers.go b/pkg/epp/util/testing/wrappers.go index 2693734f..c4018631 100644 --- a/pkg/epp/util/testing/wrappers.go +++ b/pkg/epp/util/testing/wrappers.go @@ -27,6 +27,12 @@ type PodWrapper struct { corev1.Pod } +func FromBase(pod *corev1.Pod) *PodWrapper { + return &PodWrapper{ + Pod: *pod, + } +} + // MakePod creates a wrapper for a Pod. 
func MakePod(podName string) *PodWrapper { return &PodWrapper{ diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 765449f3..c5e7c10a 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -55,12 +55,11 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" - extprocutils "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/test" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" "sigs.k8s.io/yaml" @@ -83,7 +82,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { tests := []struct { name string req *extProcPb.ProcessingRequest - pods []*datastore.PodMetrics + pods map[backendmetrics.Pod]*backendmetrics.Metrics wantHeaders []*configPb.HeaderValueOption wantMetadata *structpb.Struct wantBody []byte @@ -93,21 +92,21 @@ func TestKubeInferenceModelRequest(t *testing.T) { }{ { name: "select lower queue and kv cache, no active lora", - req: extprocutils.GenerateRequest(logger, "test1", "my-model"), + req: utiltesting.GenerateRequest(logger, "test1", "my-model"), // pod-1 will be picked because it has relatively low queue size and low KV cache. - pods: []*datastore.PodMetrics{ - extprocutils.FakePodMetrics(0, datastore.Metrics{ + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { WaitingQueueSize: 3, KVCacheUsagePercent: 0.2, - }), - extprocutils.FakePodMetrics(1, datastore.Metrics{ + }, + fakePod(1): { WaitingQueueSize: 0, KVCacheUsagePercent: 0.1, - }), - extprocutils.FakePodMetrics(2, datastore.Metrics{ + }, + fakePod(2): { WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, - }), + }, }, wantHeaders: []*configPb.HeaderValueOption{ { @@ -134,34 +133,34 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, { name: "select active lora, low queue", - req: extprocutils.GenerateRequest(logger, "test2", "sql-lora"), + req: utiltesting.GenerateRequest(logger, "test2", "sql-lora"), // pod-1 will be picked because it has relatively low queue size, with the requested // model being active, and has low KV cache. 
- pods: []*datastore.PodMetrics{ - extprocutils.FakePodMetrics(0, datastore.Metrics{ + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ "foo": 1, "bar": 1, }, - }), - extprocutils.FakePodMetrics(1, datastore.Metrics{ + }, + fakePod(1): { WaitingQueueSize: 0, KVCacheUsagePercent: 0.1, ActiveModels: map[string]int{ "foo": 1, "sql-lora-1fdg2": 1, }, - }), - extprocutils.FakePodMetrics(2, datastore.Metrics{ + }, + fakePod(2): { WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ "foo": 1, "bar": 1, }, - }), + }, }, wantHeaders: []*configPb.HeaderValueOption{ { @@ -188,34 +187,34 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, { name: "select no lora despite active model, avoid excessive queue size", - req: extprocutils.GenerateRequest(logger, "test3", "sql-lora"), + req: utiltesting.GenerateRequest(logger, "test3", "sql-lora"), // pod-2 will be picked despite it NOT having the requested model being active // as it's above the affinity for queue size. Also is critical, so we should // still honor request despite all queues > 5 - pods: []*datastore.PodMetrics{ - extprocutils.FakePodMetrics(0, datastore.Metrics{ + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ "foo": 1, "bar": 1, }, - }), - extprocutils.FakePodMetrics(1, datastore.Metrics{ + }, + fakePod(1): { WaitingQueueSize: 200, KVCacheUsagePercent: 0.1, ActiveModels: map[string]int{ "foo": 1, "sql-lora-1fdg2": 1, }, - }), - extprocutils.FakePodMetrics(2, datastore.Metrics{ + }, + fakePod(2): { WaitingQueueSize: 6, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ "foo": 1, }, - }), + }, }, wantHeaders: []*configPb.HeaderValueOption{ { @@ -242,11 +241,11 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, { name: "noncritical and all models past threshold, shed request", - req: extprocutils.GenerateRequest(logger, "test4", "sql-lora-sheddable"), + req: utiltesting.GenerateRequest(logger, "test4", "sql-lora-sheddable"), // no pods will be picked as all models are either above kv threshold, // queue threshold, or both. 
- pods: []*datastore.PodMetrics{ - extprocutils.FakePodMetrics(0, datastore.Metrics{ + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { WaitingQueueSize: 6, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ @@ -254,23 +253,23 @@ func TestKubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, - }), - extprocutils.FakePodMetrics(1, datastore.Metrics{ + }, + fakePod(1): { WaitingQueueSize: 0, KVCacheUsagePercent: 0.85, ActiveModels: map[string]int{ "foo": 1, "sql-lora-1fdg3": 1, }, - }), - extprocutils.FakePodMetrics(2, datastore.Metrics{ + }, + fakePod(2): { WaitingQueueSize: 10, KVCacheUsagePercent: 0.9, ActiveModels: map[string]int{ "foo": 1, "sql-lora-1fdg3": 1, }, - }), + }, }, wantHeaders: []*configPb.HeaderValueOption{}, wantMetadata: &structpb.Struct{}, @@ -285,10 +284,10 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, { name: "noncritical, but one server has capacity, do not shed", - req: extprocutils.GenerateRequest(logger, "test5", "sql-lora-sheddable"), + req: utiltesting.GenerateRequest(logger, "test5", "sql-lora-sheddable"), // pod 0 will be picked as all other models are above threshold - pods: []*datastore.PodMetrics{ - extprocutils.FakePodMetrics(0, datastore.Metrics{ + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ @@ -296,23 +295,23 @@ func TestKubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, - }), - extprocutils.FakePodMetrics(1, datastore.Metrics{ + }, + fakePod(1): { WaitingQueueSize: 0, KVCacheUsagePercent: 0.85, ActiveModels: map[string]int{ "foo": 1, "sql-lora-1fdg3": 1, }, - }), - extprocutils.FakePodMetrics(2, datastore.Metrics{ + }, + fakePod(2): { WaitingQueueSize: 10, KVCacheUsagePercent: 0.9, ActiveModels: map[string]int{ "foo": 1, "sql-lora-1fdg3": 1, }, - }), + }, }, wantHeaders: []*configPb.HeaderValueOption{ { @@ -391,12 +390,13 @@ func TestKubeInferenceModelRequest(t *testing.T) { } } -func setUpHermeticServer(t *testing.T, podMetrics []*datastore.PodMetrics) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { - pms := make(map[types.NamespacedName]*datastore.PodMetrics) - for _, pm := range podMetrics { - pms[pm.NamespacedName] = pm +func setUpHermeticServer(t *testing.T, podAndMetrics map[backendmetrics.Pod]*backendmetrics.Metrics) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { + // Reconfigure the TestPodMetricsClient. + res := map[types.NamespacedName]*backendmetrics.Metrics{} + for pod, metrics := range podAndMetrics { + res[pod.NamespacedName] = metrics } - pmc := &backend.FakePodMetricsClient{Res: pms} + serverRunner.TestPodMetricsClient.SetRes(res) serverCtx, stopServer := context.WithCancel(context.Background()) @@ -405,27 +405,26 @@ func setUpHermeticServer(t *testing.T, podMetrics []*datastore.PodMetrics) (clie "app": "vllm-llama2-7b-pool", } - for _, pm := range podMetrics { - pod := utiltesting.MakePod(pm.NamespacedName.Name). - Namespace(pm.NamespacedName.Namespace). + for pod := range podAndMetrics { + pod := utiltesting.MakePod(pod.NamespacedName.Name). + Namespace(pod.NamespacedName.Namespace). ReadyCondition(). Labels(podLabels). - IP(pm.Address). + IP(pod.Address). Complete(). 
ObjRef() copy := pod.DeepCopy() if err := k8sClient.Create(context.Background(), copy); err != nil { - logutil.Fatal(logger, err, "Failed to create pod", "pod", pm.NamespacedName) + logutil.Fatal(logger, err, "Failed to create pod", "pod", pod) } // since no pod controllers deployed in fake environment, we manually update pod status copy.Status = pod.Status if err := k8sClient.Status().Update(context.Background(), copy); err != nil { - logutil.Fatal(logger, err, "Failed to update pod status", "pod", pm.NamespacedName) + logutil.Fatal(logger, err, "Failed to update pod status", "pod", pod) } } - serverRunner.Provider = backend.NewProvider(pmc, serverRunner.Datastore) go func() { if err := serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil { logutil.Fatal(logger, err, "Failed to start ext-proc server") @@ -434,7 +433,7 @@ func setUpHermeticServer(t *testing.T, podMetrics []*datastore.PodMetrics) (clie // check if all pods are synced to datastore assert.EventuallyWithT(t, func(t *assert.CollectT) { - assert.Len(t, serverRunner.Datastore.PodGetAll(), len(podMetrics), "Datastore not synced") + assert.Len(t, serverRunner.Datastore.PodGetAll(), len(podAndMetrics), "Datastore not synced") }, 10*time.Second, time.Second) address := fmt.Sprintf("localhost:%v", port) @@ -455,12 +454,12 @@ func setUpHermeticServer(t *testing.T, podMetrics []*datastore.PodMetrics) (clie stopServer() // clear created pods - for _, pm := range podMetrics { - pod := utiltesting.MakePod(pm.NamespacedName.Name). - Namespace(pm.NamespacedName.Namespace).Complete().ObjRef() + for pod := range podAndMetrics { + pod := utiltesting.MakePod(pod.NamespacedName.Name). + Namespace(pod.NamespacedName.Namespace).Complete().ObjRef() if err := k8sClient.Delete(context.Background(), pod); err != nil { - logutil.Fatal(logger, err, "Failed to delete pod", "pod", pm.NamespacedName) + logutil.Fatal(logger, err, "Failed to delete pod", "pod", fakePod) } } // wait a little until the goroutines actually exit @@ -468,6 +467,13 @@ func setUpHermeticServer(t *testing.T, podMetrics []*datastore.PodMetrics) (clie } } +func fakePod(index int) backendmetrics.Pod { + return backendmetrics.Pod{ + NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index), Namespace: "default"}, + Address: fmt.Sprintf("192.168.1.%d", index+1), + } +} + // Sets up a test environment and returns the runner struct func BeforeSuit(t *testing.T) func() { // Set up mock k8s API Client @@ -503,9 +509,11 @@ func BeforeSuit(t *testing.T) func() { } serverRunner = runserver.NewDefaultExtProcServerRunner() + serverRunner.TestPodMetricsClient = &backendmetrics.FakePodMetricsClient{} + pmf := backendmetrics.NewPodMetricsFactory(serverRunner.TestPodMetricsClient, 10*time.Millisecond) // Adjust from defaults serverRunner.PoolName = "vllm-llama2-7b-pool" - serverRunner.Datastore = datastore.NewDatastore() + serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf) serverRunner.SecureServing = false if err := serverRunner.SetupWithManager(context.Background(), mgr); err != nil {