Skip to content

Commit 6f0be81

Browse files
committed
Removed the intermediate cache in the provider and consolidated all storage behind the datastore.
1 parent 6b42ab8 commit 6f0be81

25 files changed

+790
-662
lines changed

Diff for: pkg/ext-proc/backend/datastore.go

+159-79
Original file line numberDiff line numberDiff line change
@@ -10,136 +10,187 @@ import (
1010
"github.com/go-logr/logr"
1111
corev1 "k8s.io/api/core/v1"
1212
"k8s.io/apimachinery/pkg/labels"
13+
"k8s.io/apimachinery/pkg/types"
1314
"sigs.k8s.io/controller-runtime/pkg/client"
1415
"sigs.k8s.io/controller-runtime/pkg/log"
1516
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
1617
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
1718
)
1819

19-
func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore {
20-
store := &K8sDatastore{
21-
poolMu: sync.RWMutex{},
22-
InferenceModels: &sync.Map{},
23-
pods: &sync.Map{},
24-
}
25-
for _, opt := range options {
26-
opt(store)
20+
// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
21+
type Datastore interface {
22+
// InferencePool operations
23+
PoolSet(pool *v1alpha1.InferencePool)
24+
PoolGet() (*v1alpha1.InferencePool, error)
25+
PoolHasSynced() bool
26+
PoolLabelsMatch(podLabels map[string]string) bool
27+
28+
// InferenceModel operations
29+
ModelSet(infModel *v1alpha1.InferenceModel)
30+
ModelGet(modelName string) (returnModel *v1alpha1.InferenceModel)
31+
ModelDelete(modelName string)
32+
33+
// PodMetrics operations
34+
PodAddIfNotExist(pod *corev1.Pod) bool
35+
PodUpdateMetricsIfExist(pm *PodMetrics)
36+
PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool)
37+
PodDelete(namespacedName types.NamespacedName)
38+
PodFlush(ctx context.Context, ctrlClient client.Client)
39+
PodGetAll() []*PodMetrics
40+
PodRange(f func(key, value any) bool)
41+
PodDeleteAll() // This is only for testing.
42+
}
43+
44+
func NewDatastore() Datastore {
45+
store := &datastore{
46+
poolMu: sync.RWMutex{},
47+
models: &sync.Map{},
48+
pods: &sync.Map{},
2749
}
2850
return store
2951
}
3052

31-
// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
32-
type K8sDatastore struct {
53+
type datastore struct {
3354
// poolMu is used to synchronize access to the inferencePool.
34-
poolMu sync.RWMutex
35-
inferencePool *v1alpha1.InferencePool
36-
InferenceModels *sync.Map
37-
pods *sync.Map
38-
}
39-
40-
type K8sDatastoreOption func(*K8sDatastore)
41-
42-
// WithPods can be used in tests to override the pods.
43-
func WithPods(pods []*PodMetrics) K8sDatastoreOption {
44-
return func(store *K8sDatastore) {
45-
store.pods = &sync.Map{}
46-
for _, pod := range pods {
47-
store.pods.Store(pod.Pod, true)
48-
}
49-
}
55+
poolMu sync.RWMutex
56+
pool *v1alpha1.InferencePool
57+
models *sync.Map
58+
// key: types.NamespacedName, value: *PodMetrics
59+
pods *sync.Map
5060
}
5161

52-
func (ds *K8sDatastore) setInferencePool(pool *v1alpha1.InferencePool) {
62+
// /// InferencePool APIs ///
63+
func (ds *datastore) PoolSet(pool *v1alpha1.InferencePool) {
5364
ds.poolMu.Lock()
5465
defer ds.poolMu.Unlock()
55-
ds.inferencePool = pool
66+
ds.pool = pool
5667
}
5768

58-
func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) {
69+
func (ds *datastore) PoolGet() (*v1alpha1.InferencePool, error) {
5970
ds.poolMu.RLock()
6071
defer ds.poolMu.RUnlock()
61-
if !ds.HasSynced() {
72+
if !ds.PoolHasSynced() {
6273
return nil, errors.New("InferencePool is not initialized in data store")
6374
}
64-
return ds.inferencePool, nil
75+
return ds.pool, nil
6576
}
6677

67-
func (ds *K8sDatastore) GetPodIPs() []string {
68-
var ips []string
69-
ds.pods.Range(func(name, pod any) bool {
70-
ips = append(ips, pod.(*corev1.Pod).Status.PodIP)
71-
return true
72-
})
73-
return ips
78+
func (ds *datastore) PoolHasSynced() bool {
79+
ds.poolMu.RLock()
80+
defer ds.poolMu.RUnlock()
81+
return ds.pool != nil
82+
}
83+
84+
func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool {
85+
poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector)
86+
podSet := labels.Set(podLabels)
87+
return poolSelector.Matches(podSet)
7488
}
7589

76-
func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.InferenceModel) {
77-
infModel, ok := s.InferenceModels.Load(modelName)
90+
// /// InferenceModel APIs ///
91+
func (ds *datastore) ModelSet(infModel *v1alpha1.InferenceModel) {
92+
ds.models.Store(infModel.Spec.ModelName, infModel)
93+
}
94+
95+
func (ds *datastore) ModelGet(modelName string) (returnModel *v1alpha1.InferenceModel) {
96+
infModel, ok := ds.models.Load(modelName)
7897
if ok {
7998
returnModel = infModel.(*v1alpha1.InferenceModel)
8099
}
81100
return
82101
}
83102

84-
// HasSynced returns true if InferencePool is set in the data store.
85-
func (ds *K8sDatastore) HasSynced() bool {
86-
ds.poolMu.RLock()
87-
defer ds.poolMu.RUnlock()
88-
return ds.inferencePool != nil
103+
func (ds *datastore) ModelDelete(modelName string) {
104+
ds.models.Delete(modelName)
89105
}
90106

91-
func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string {
92-
var weights int32
93-
94-
source := rand.NewSource(rand.Int63())
95-
if seed > 0 {
96-
source = rand.NewSource(seed)
97-
}
98-
r := rand.New(source)
99-
for _, model := range model.Spec.TargetModels {
100-
weights += *model.Weight
107+
// /// Pods/endpoints APIs ///
108+
func (ds *datastore) PodUpdateMetricsIfExist(pm *PodMetrics) {
109+
if val, ok := ds.pods.Load(pm.NamespacedName); ok {
110+
existing := val.(*PodMetrics)
111+
existing.Metrics = pm.Metrics
101112
}
102-
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
103-
randomVal := r.Int31n(weights)
104-
for _, model := range model.Spec.TargetModels {
105-
if randomVal < *model.Weight {
106-
return model.Name
107-
}
108-
randomVal -= *model.Weight
113+
}
114+
115+
func (ds *datastore) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) {
116+
val, ok := ds.pods.Load(namespacedName)
117+
if ok {
118+
return val.(*PodMetrics), true
109119
}
110-
return ""
120+
return nil, false
111121
}
112122

113-
func IsCritical(model *v1alpha1.InferenceModel) bool {
114-
if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha1.Critical {
123+
func (ds *datastore) PodGetAll() []*PodMetrics {
124+
res := []*PodMetrics{}
125+
fn := func(k, v any) bool {
126+
res = append(res, v.(*PodMetrics))
115127
return true
116128
}
117-
return false
129+
ds.pods.Range(fn)
130+
return res
118131
}
119132

120-
func (ds *K8sDatastore) LabelsMatch(podLabels map[string]string) bool {
121-
poolSelector := selectorFromInferencePoolSelector(ds.inferencePool.Spec.Selector)
122-
podSet := labels.Set(podLabels)
123-
return poolSelector.Matches(podSet)
133+
func (ds *datastore) PodRange(f func(key, value any) bool) {
134+
ds.pods.Range(f)
135+
}
136+
137+
func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
138+
ds.pods.Delete(namespacedName)
139+
}
140+
141+
func (ds *datastore) PodAddIfNotExist(pod *corev1.Pod) bool {
142+
// new pod, add to the store for probing
143+
pool, _ := ds.PoolGet()
144+
new := &PodMetrics{
145+
NamespacedName: types.NamespacedName{
146+
Name: pod.Name,
147+
Namespace: pod.Namespace,
148+
},
149+
Address: pod.Status.PodIP + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)),
150+
Metrics: Metrics{
151+
ActiveModels: make(map[string]int),
152+
},
153+
}
154+
if _, ok := ds.pods.Load(new.NamespacedName); !ok {
155+
ds.pods.Store(new.NamespacedName, new)
156+
return true
157+
}
158+
return false
124159
}
125160

126-
func (ds *K8sDatastore) flushPodsAndRefetch(ctx context.Context, ctrlClient client.Client, newServerPool *v1alpha1.InferencePool) {
161+
func (ds *datastore) PodFlush(ctx context.Context, ctrlClient client.Client) {
162+
// Pool must exist to invoke this function.
163+
pool, _ := ds.PoolGet()
127164
podList := &corev1.PodList{}
128165
if err := ctrlClient.List(ctx, podList, &client.ListOptions{
129-
LabelSelector: selectorFromInferencePoolSelector(newServerPool.Spec.Selector),
130-
Namespace: newServerPool.Namespace,
166+
LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector),
167+
Namespace: pool.Namespace,
131168
}); err != nil {
132169
log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "Failed to list clients")
170+
return
133171
}
134-
ds.pods.Clear()
135172

136-
for _, k8sPod := range podList.Items {
137-
pod := Pod{
138-
Name: k8sPod.Name,
139-
Address: k8sPod.Status.PodIP + ":" + strconv.Itoa(int(newServerPool.Spec.TargetPortNumber)),
173+
activePods := make(map[string]bool)
174+
for _, pod := range podList.Items {
175+
if podIsReady(&pod) {
176+
activePods[pod.Name] = true
177+
ds.PodAddIfNotExist(&pod)
140178
}
141-
ds.pods.Store(pod, true)
142179
}
180+
181+
// Remove pods that no longer exist or are no longer ready.
182+
deleteFn := func(k, v any) bool {
183+
pm := v.(*PodMetrics)
184+
if exist := activePods[pm.NamespacedName.Name]; !exist {
185+
ds.pods.Delete(pm.NamespacedName)
186+
}
187+
return true
188+
}
189+
ds.pods.Range(deleteFn)
190+
}
191+
192+
func (ds *datastore) PodDeleteAll() {
193+
ds.pods.Clear()
143194
}
144195

145196
func selectorFromInferencePoolSelector(selector map[v1alpha1.LabelKey]v1alpha1.LabelValue) labels.Selector {
@@ -153,3 +204,32 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha1.LabelKey]v1alpha1.LabelV
153204
}
154205
return outMap
155206
}
207+
208+
func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string {
209+
var weights int32
210+
211+
source := rand.NewSource(rand.Int63())
212+
if seed > 0 {
213+
source = rand.NewSource(seed)
214+
}
215+
r := rand.New(source)
216+
for _, model := range model.Spec.TargetModels {
217+
weights += *model.Weight
218+
}
219+
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
220+
randomVal := r.Int31n(weights)
221+
for _, model := range model.Spec.TargetModels {
222+
if randomVal < *model.Weight {
223+
return model.Name
224+
}
225+
randomVal -= *model.Weight
226+
}
227+
return ""
228+
}
229+
230+
func IsCritical(model *v1alpha1.InferenceModel) bool {
231+
if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha1.Critical {
232+
return true
233+
}
234+
return false
235+
}

Diff for: pkg/ext-proc/backend/datastore_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ func TestHasSynced(t *testing.T) {
3232
}
3333
for _, tt := range tests {
3434
t.Run(tt.name, func(t *testing.T) {
35-
datastore := NewK8sDataStore()
35+
datastore := NewDatastore()
3636
// Set the inference pool
3737
if tt.inferencePool != nil {
38-
datastore.setInferencePool(tt.inferencePool)
38+
datastore.PoolSet(tt.inferencePool)
3939
}
4040
// Check if the data store has been initialized
41-
hasSynced := datastore.HasSynced()
41+
hasSynced := datastore.PoolHasSynced()
4242
if hasSynced != tt.hasSynced {
4343
t.Errorf("IsInitialized() = %v, want %v", hasSynced, tt.hasSynced)
4444
}

Diff for: pkg/ext-proc/backend/fake.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@ package backend
33
import (
44
"context"
55

6+
"k8s.io/apimachinery/pkg/types"
67
"sigs.k8s.io/controller-runtime/pkg/log"
78
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
89
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
910
)
1011

1112
type FakePodMetricsClient struct {
12-
Err map[Pod]error
13-
Res map[Pod]*PodMetrics
13+
Err map[types.NamespacedName]error
14+
Res map[types.NamespacedName]*PodMetrics
1415
}
1516

16-
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existing *PodMetrics) (*PodMetrics, error) {
17-
if err, ok := f.Err[pod]; ok {
17+
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, existing *PodMetrics) (*PodMetrics, error) {
18+
if err, ok := f.Err[existing.NamespacedName]; ok {
1819
return nil, err
1920
}
20-
log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "pod", pod, "existing", existing, "new", f.Res[pod])
21-
return f.Res[pod], nil
21+
log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "existing", existing, "new", f.Res[existing.NamespacedName])
22+
return f.Res[existing.NamespacedName], nil
2223
}
2324

2425
type FakeDataStore struct {

Diff for: pkg/ext-proc/backend/inferencemodel_reconciler.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type InferenceModelReconciler struct {
1919
client.Client
2020
Scheme *runtime.Scheme
2121
Record record.EventRecorder
22-
Datastore *K8sDatastore
22+
Datastore Datastore
2323
PoolNamespacedName types.NamespacedName
2424
}
2525

@@ -36,14 +36,14 @@ func (c *InferenceModelReconciler) Reconcile(ctx context.Context, req ctrl.Reque
3636
if err := c.Get(ctx, req.NamespacedName, infModel); err != nil {
3737
if errors.IsNotFound(err) {
3838
loggerDefault.Info("InferenceModel not found. Removing from datastore since object must be deleted", "name", req.NamespacedName)
39-
c.Datastore.InferenceModels.Delete(infModel.Spec.ModelName)
39+
c.Datastore.ModelDelete(infModel.Spec.ModelName)
4040
return ctrl.Result{}, nil
4141
}
4242
loggerDefault.Error(err, "Unable to get InferenceModel", "name", req.NamespacedName)
4343
return ctrl.Result{}, err
4444
} else if !infModel.DeletionTimestamp.IsZero() {
4545
loggerDefault.Info("InferenceModel is marked for deletion. Removing from datastore", "name", req.NamespacedName)
46-
c.Datastore.InferenceModels.Delete(infModel.Spec.ModelName)
46+
c.Datastore.ModelDelete(infModel.Spec.ModelName)
4747
return ctrl.Result{}, nil
4848
}
4949

@@ -57,12 +57,12 @@ func (c *InferenceModelReconciler) updateDatastore(logger logr.Logger, infModel
5757
if infModel.Spec.PoolRef.Name == c.PoolNamespacedName.Name {
5858
loggerDefault.Info("Updating datastore", "poolRef", infModel.Spec.PoolRef, "serverPoolName", c.PoolNamespacedName)
5959
loggerDefault.Info("Adding/Updating InferenceModel", "modelName", infModel.Spec.ModelName)
60-
c.Datastore.InferenceModels.Store(infModel.Spec.ModelName, infModel)
60+
c.Datastore.ModelSet(infModel)
6161
return
6262
}
6363
loggerDefault.Info("Removing/Not adding InferenceModel", "modelName", infModel.Spec.ModelName)
6464
// If we get here, the model is not relevant to this pool, so remove it.
65-
c.Datastore.InferenceModels.Delete(infModel.Spec.ModelName)
65+
c.Datastore.ModelDelete(infModel.Spec.ModelName)
6666
}
6767

6868
func (c *InferenceModelReconciler) SetupWithManager(mgr ctrl.Manager) error {

0 commit comments

Comments
 (0)