Skip to content

Commit f9113b3

Browse files
committed
Fixed the provider test and covered the pool deletion events.
1 parent 37f17e8 commit f9113b3

File tree

5 files changed

+127
-54
lines changed

5 files changed

+127
-54
lines changed

Diff for: pkg/ext-proc/backend/datastore.go

+14-3
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@ type Datastore interface {
3535
PodUpdateMetricsIfExist(pm *PodMetrics)
3636
PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool)
3737
PodDelete(namespacedName types.NamespacedName)
38-
PodFlush(ctx context.Context, ctrlClient client.Client)
38+
PodFlushAll(ctx context.Context, ctrlClient client.Client)
3939
PodGetAll() []*PodMetrics
40-
PodRange(f func(key, value any) bool)
4140
PodDeleteAll() // This is only for testing.
41+
PodRange(f func(key, value any) bool)
42+
43+
// Clears the store state, happens when the pool gets deleted.
44+
Clear()
4245
}
4346

4447
func NewDatastore() Datastore {
@@ -59,6 +62,14 @@ type datastore struct {
5962
pods *sync.Map
6063
}
6164

65+
func (ds *datastore) Clear() {
66+
ds.poolMu.Lock()
67+
defer ds.poolMu.Unlock()
68+
ds.pool = nil
69+
ds.models.Clear()
70+
ds.pods.Clear()
71+
}
72+
6273
// /// InferencePool APIs ///
6374
func (ds *datastore) PoolSet(pool *v1alpha1.InferencePool) {
6475
ds.poolMu.Lock()
@@ -158,7 +169,7 @@ func (ds *datastore) PodAddIfNotExist(pod *corev1.Pod) bool {
158169
return false
159170
}
160171

161-
func (ds *datastore) PodFlush(ctx context.Context, ctrlClient client.Client) {
172+
func (ds *datastore) PodFlushAll(ctx context.Context, ctrlClient client.Client) {
162173
// Pool must exist to invoke this function.
163174
pool, _ := ds.PoolGet()
164175
podList := &corev1.PodList{}

Diff for: pkg/ext-proc/backend/inferencepool_reconciler.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"reflect"
66

7+
"k8s.io/apimachinery/pkg/api/errors"
78
"k8s.io/apimachinery/pkg/runtime"
89
"k8s.io/apimachinery/pkg/types"
910
"k8s.io/client-go/tools/record"
@@ -36,12 +37,19 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques
3637
loggerDefault.Info("Reconciling InferencePool", "name", req.NamespacedName)
3738

3839
serverPool := &v1alpha1.InferencePool{}
40+
3941
if err := c.Get(ctx, req.NamespacedName, serverPool); err != nil {
42+
if errors.IsNotFound(err) {
43+
loggerDefault.Info("InferencePool not found. Clearing the datastore", "name", req.NamespacedName)
44+
c.Datastore.Clear()
45+
return ctrl.Result{}, nil
46+
}
4047
loggerDefault.Error(err, "Unable to get InferencePool", "name", req.NamespacedName)
4148
return ctrl.Result{}, err
42-
43-
// TODO: Handle InferencePool deletions. Need to flush the datastore.
44-
// TODO: Handle port updates, podMetrics should not be storing that as part of the address.
49+
} else if !serverPool.DeletionTimestamp.IsZero() {
50+
loggerDefault.Info("InferencePool is marked for deletion. Clearing the datastore", "name", req.NamespacedName)
51+
c.Datastore.Clear()
52+
return ctrl.Result{}, nil
4553
}
4654

4755
c.updateDatastore(ctx, serverPool)
@@ -55,7 +63,7 @@ func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool *
5563
c.Datastore.PoolSet(newPool)
5664
if oldPool == nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) {
5765
logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "target", klog.KMetadata(&newPool.ObjectMeta))
58-
c.Datastore.PodFlush(ctx, c.Client)
66+
c.Datastore.PodFlushAll(ctx, c.Client)
5967
}
6068
}
6169

Diff for: pkg/ext-proc/backend/inferencepool_reconciler_test.go

+37-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ var (
2626
Name: "pool1",
2727
Namespace: "pool1-ns",
2828
},
29-
Spec: v1alpha1.InferencePoolSpec{Selector: selector_v1},
29+
Spec: v1alpha1.InferencePoolSpec{
30+
Selector: selector_v1,
31+
TargetPortNumber: 8080,
32+
},
3033
}
3134
pool2 = &v1alpha1.InferencePool{
3235
ObjectMeta: metav1.ObjectMeta{
@@ -48,6 +51,9 @@ var (
4851
)
4952

5053
func TestReconcile_InferencePoolReconciler(t *testing.T) {
54+
// The best practice is to use table-driven tests, however in this scenario it seems
55+
// more logical to do a single test with steps that depend on each other.
56+
5157
// Set up the scheme.
5258
scheme := runtime.NewScheme()
5359
_ = clientgoscheme.AddToScheme(scheme)
@@ -63,7 +69,7 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) {
6369
WithObjects(initialObjects...).
6470
Build()
6571

66-
// Create a request for the existing resource.
72+
// Create a request for the existing resource.
6773
namespacedName := types.NamespacedName{Name: pool1.Name, Namespace: pool1.Namespace}
6874
req := ctrl.Request{NamespacedName: namespacedName}
6975
ctx := context.Background()
@@ -103,6 +109,35 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) {
103109
if diff := diffPool(datastore, newPool1, []string{"pod5"}); diff != "" {
104110
t.Errorf("Unexpected diff (+got/-want): %s", diff)
105111
}
112+
113+
// Step 4: update the pool port
114+
if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil {
115+
t.Errorf("Unexpected pool get error: %v", err)
116+
}
117+
newPool1.Spec.TargetPortNumber = 9090
118+
if err := fakeClient.Update(ctx, newPool1, &client.UpdateOptions{}); err != nil {
119+
t.Errorf("Unexpected pool update error: %v", err)
120+
}
121+
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
122+
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
123+
}
124+
if diff := diffPool(datastore, newPool1, []string{"pod5"}); diff != "" {
125+
t.Errorf("Unexpected diff (+got/-want): %s", diff)
126+
}
127+
128+
// Step 5: delete the pool to trigger a datastore clear
129+
if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil {
130+
t.Errorf("Unexpected pool get error: %v", err)
131+
}
132+
if err := fakeClient.Delete(ctx, newPool1, &client.DeleteOptions{}); err != nil {
133+
t.Errorf("Unexpected pool delete error: %v", err)
134+
}
135+
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
136+
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
137+
}
138+
if diff := diffPool(datastore, nil, []string{}); diff != "" {
139+
t.Errorf("Unexpected diff (+got/-want): %s", diff)
140+
}
106141
}
107142

108143
func diffPool(datastore Datastore, wantPool *v1alpha1.InferencePool, wantPods []string) string {

Diff for: pkg/ext-proc/backend/provider_test.go

+57-29
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package backend
22

33
import (
4+
"context"
5+
"errors"
46
"sync"
57
"testing"
8+
"time"
69

710
"github.com/google/go-cmp/cmp"
811
"github.com/google/go-cmp/cmp/cmpopts"
12+
"github.com/stretchr/testify/assert"
913
"k8s.io/apimachinery/pkg/types"
10-
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
1114
)
1215

1316
var (
@@ -42,20 +45,15 @@ var (
4245
)
4346

4447
func TestProvider(t *testing.T) {
45-
logger := logutil.NewTestLogger()
46-
4748
tests := []struct {
4849
name string
4950
pmc PodMetricsClient
5051
datastore Datastore
5152
want []*PodMetrics
5253
}{
5354
{
54-
name: "Fetch metrics error",
55+
name: "Probing metrics success",
5556
pmc: &FakePodMetricsClient{
56-
// Err: map[string]error{
57-
// pod2.Name: errors.New("injected error"),
58-
// },
5957
Res: map[types.NamespacedName]*PodMetrics{
6058
pod1.NamespacedName: pod1,
6159
pod2.NamespacedName: pod2,
@@ -67,42 +65,72 @@ func TestProvider(t *testing.T) {
6765
want: []*PodMetrics{
6866
pod1,
6967
pod2,
70-
// // Failed to fetch pod2 metrics so it remains the default values.
71-
// {
72-
// Name: "pod2",
73-
// Metrics: Metrics{
74-
// WaitingQueueSize: 0,
75-
// KVCacheUsagePercent: 0,
76-
// MaxActiveModels: 0,
77-
// ActiveModels: map[string]int{},
78-
// },
79-
// },
68+
},
69+
},
70+
{
71+
name: "Only pods in the datastore are probed",
72+
pmc: &FakePodMetricsClient{
73+
Res: map[types.NamespacedName]*PodMetrics{
74+
pod1.NamespacedName: pod1,
75+
pod2.NamespacedName: pod2,
76+
},
77+
},
78+
datastore: &datastore{
79+
pods: populateMap(pod1),
80+
},
81+
want: []*PodMetrics{
82+
pod1,
83+
},
84+
},
85+
{
86+
name: "Probing metrics error",
87+
pmc: &FakePodMetricsClient{
88+
Err: map[types.NamespacedName]error{
89+
pod2.NamespacedName: errors.New("injected error"),
90+
},
91+
Res: map[types.NamespacedName]*PodMetrics{
92+
pod1.NamespacedName: pod1,
93+
},
94+
},
95+
datastore: &datastore{
96+
pods: populateMap(pod1, pod2),
97+
},
98+
want: []*PodMetrics{
99+
pod1,
100+
// Failed to fetch pod2 metrics so it remains the default values.
101+
{
102+
NamespacedName: pod2.NamespacedName,
103+
Metrics: Metrics{
104+
WaitingQueueSize: 0,
105+
KVCacheUsagePercent: 0,
106+
MaxActiveModels: 0,
107+
},
108+
},
80109
},
81110
},
82111
}
83112

84113
for _, test := range tests {
85114
t.Run(test.name, func(t *testing.T) {
86115
p := NewProvider(test.pmc, test.datastore)
87-
// if err := p.refreshMetricsOnce(logger); err != nil {
88-
// t.Fatalf("Unexpected error: %v", err)
89-
// }
90-
_ = p.refreshMetricsOnce(logger)
91-
metrics := test.datastore.PodGetAll()
92-
lessFunc := func(a, b *PodMetrics) bool {
93-
return a.String() < b.String()
94-
}
95-
if diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(lessFunc)); diff != "" {
96-
t.Errorf("Unexpected output (-want +got): %v", diff)
97-
}
116+
ctx, cancel := context.WithCancel(context.Background())
117+
defer cancel()
118+
_ = p.Init(ctx, time.Millisecond, time.Millisecond)
119+
assert.EventuallyWithT(t, func(t *assert.CollectT) {
120+
metrics := test.datastore.PodGetAll()
121+
diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(func(a, b *PodMetrics) bool {
122+
return a.String() < b.String()
123+
}))
124+
assert.Equal(t, "", diff, "Unexpected diff (+got/-want)")
125+
}, 5*time.Second, time.Millisecond)
98126
})
99127
}
100128
}
101129

102130
func populateMap(pods ...*PodMetrics) *sync.Map {
103131
newMap := &sync.Map{}
104132
for _, pod := range pods {
105-
newMap.Store(pod.NamespacedName, pod)
133+
newMap.Store(pod.NamespacedName, &PodMetrics{NamespacedName: pod.NamespacedName})
106134
}
107135
return newMap
108136
}

Diff for: test/integration/hermetic_test.go

+7-16
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
1818
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
1919
"github.com/google/go-cmp/cmp"
20+
"github.com/stretchr/testify/assert"
2021
"google.golang.org/grpc"
2122
"google.golang.org/grpc/credentials/insecure"
2223
"google.golang.org/protobuf/testing/protocmp"
@@ -320,7 +321,7 @@ func TestKubeInferenceModelRequest(t *testing.T) {
320321
}
321322

322323
// Set up global k8sclient and extproc server runner with test environment config
323-
cleanup := BeforeSuit()
324+
cleanup := BeforeSuit(t)
324325
defer cleanup()
325326

326327
for _, test := range tests {
@@ -409,7 +410,7 @@ func setUpHermeticServer(podMetrics []*backend.PodMetrics) (client extProcPb.Ext
409410
}
410411

411412
// Sets up a test environment and returns the runner struct
412-
func BeforeSuit() func() {
413+
func BeforeSuit(t *testing.T) func() {
413414
// Set up mock k8s API Client
414415
testEnv = &envtest.Environment{
415416
CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")},
@@ -488,26 +489,16 @@ func BeforeSuit() func() {
488489
}
489490
}
490491

491-
if !blockUntilPoolSyncs(serverRunner.Datastore) {
492-
logutil.Fatal(logger, nil, "Timeout waiting for the pool and models to sync")
493-
}
492+
assert.EventuallyWithT(t, func(t *assert.CollectT) {
493+
synced := serverRunner.Datastore.PoolHasSynced() && serverRunner.Datastore.ModelGet("my-model") != nil
494+
assert.True(t, synced, "Timeout waiting for the pool and models to sync")
495+
}, 10*time.Second, 10*time.Millisecond)
494496

495497
return func() {
496498
_ = testEnv.Stop()
497499
}
498500
}
499501

500-
func blockUntilPoolSyncs(datastore backend.Datastore) bool {
501-
// We really need to move those tests to ginkgo so we can use Eventually...
502-
for i := 1; i < 10; i++ {
503-
if datastore.PoolHasSynced() && datastore.ModelGet("my-model") != nil {
504-
return true
505-
}
506-
time.Sleep(1 * time.Second)
507-
}
508-
return false
509-
}
510-
511502
func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) {
512503
t.Logf("Sending request: %v", req)
513504
if err := client.Send(req); err != nil {

0 commit comments

Comments
 (0)