Skip to content

Commit dc24927

Browse files
ahg-grramkumar1
authored and committed
Consolidating all storage behind datastore (kubernetes-sigs#350)
* Removed the intermediate cache in provider, and consolidating all storage behind datastore.
* Fixed the provider test and covered the pool deletion events.
* Don't store the port number with the pods
* Address pod ip address updates
* rename PodFlushAll to PodResyncAll
* Addressed first round of comments
* Addressed more comments
* Adding a comment
1 parent 23cd81b commit dc24927

26 files changed

+935
-656
lines changed

pkg/ext-proc/backend/datastore.go

+176-80
Original file line number | Diff line number | Diff line change
@@ -4,142 +4,209 @@ import (
44
"context"
55
"errors"
66
"math/rand"
7-
"strconv"
87
"sync"
98

109
"github.com/go-logr/logr"
1110
corev1 "k8s.io/api/core/v1"
1211
"k8s.io/apimachinery/pkg/labels"
12+
"k8s.io/apimachinery/pkg/types"
1313
"sigs.k8s.io/controller-runtime/pkg/client"
1414
"sigs.k8s.io/controller-runtime/pkg/log"
1515
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
1616
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
1717
)
1818

19-
func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore {
20-
store := &K8sDatastore{
21-
poolMu: sync.RWMutex{},
22-
InferenceModels: &sync.Map{},
23-
pods: &sync.Map{},
24-
}
25-
for _, opt := range options {
26-
opt(store)
19+
// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
20+
type Datastore interface {
21+
// InferencePool operations
22+
PoolSet(pool *v1alpha1.InferencePool)
23+
PoolGet() (*v1alpha1.InferencePool, error)
24+
PoolHasSynced() bool
25+
PoolLabelsMatch(podLabels map[string]string) bool
26+
27+
// InferenceModel operations
28+
ModelSet(infModel *v1alpha1.InferenceModel)
29+
ModelGet(modelName string) (*v1alpha1.InferenceModel, bool)
30+
ModelDelete(modelName string)
31+
32+
// PodMetrics operations
33+
PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool
34+
PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool
35+
PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool)
36+
PodDelete(namespacedName types.NamespacedName)
37+
PodResyncAll(ctx context.Context, ctrlClient client.Client)
38+
PodGetAll() []*PodMetrics
39+
PodDeleteAll() // This is only for testing.
40+
PodRange(f func(key, value any) bool)
41+
42+
// Clears the store state, happens when the pool gets deleted.
43+
Clear()
44+
}
45+
46+
func NewDatastore() Datastore {
47+
store := &datastore{
48+
poolMu: sync.RWMutex{},
49+
models: &sync.Map{},
50+
pods: &sync.Map{},
2751
}
2852
return store
2953
}
3054

31-
// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
32-
type K8sDatastore struct {
55+
type datastore struct {
3356
// poolMu is used to synchronize access to the inferencePool.
34-
poolMu sync.RWMutex
35-
inferencePool *v1alpha1.InferencePool
36-
InferenceModels *sync.Map
37-
pods *sync.Map
57+
poolMu sync.RWMutex
58+
pool *v1alpha1.InferencePool
59+
models *sync.Map
60+
// key: types.NamespacedName, value: *PodMetrics
61+
pods *sync.Map
3862
}
3963

40-
type K8sDatastoreOption func(*K8sDatastore)
41-
42-
// WithPods can be used in tests to override the pods.
43-
func WithPods(pods []*PodMetrics) K8sDatastoreOption {
44-
return func(store *K8sDatastore) {
45-
store.pods = &sync.Map{}
46-
for _, pod := range pods {
47-
store.pods.Store(pod.Pod, true)
48-
}
49-
}
64+
func (ds *datastore) Clear() {
65+
ds.poolMu.Lock()
66+
defer ds.poolMu.Unlock()
67+
ds.pool = nil
68+
ds.models.Clear()
69+
ds.pods.Clear()
5070
}
5171

52-
func (ds *K8sDatastore) setInferencePool(pool *v1alpha1.InferencePool) {
72+
// /// InferencePool APIs ///
73+
func (ds *datastore) PoolSet(pool *v1alpha1.InferencePool) {
5374
ds.poolMu.Lock()
5475
defer ds.poolMu.Unlock()
55-
ds.inferencePool = pool
76+
ds.pool = pool
5677
}
5778

58-
func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) {
79+
func (ds *datastore) PoolGet() (*v1alpha1.InferencePool, error) {
5980
ds.poolMu.RLock()
6081
defer ds.poolMu.RUnlock()
61-
if !ds.HasSynced() {
82+
if !ds.PoolHasSynced() {
6283
return nil, errors.New("InferencePool is not initialized in data store")
6384
}
64-
return ds.inferencePool, nil
85+
return ds.pool, nil
6586
}
6687

67-
func (ds *K8sDatastore) GetPodIPs() []string {
68-
var ips []string
69-
ds.pods.Range(func(name, pod any) bool {
70-
ips = append(ips, pod.(*corev1.Pod).Status.PodIP)
71-
return true
72-
})
73-
return ips
88+
func (ds *datastore) PoolHasSynced() bool {
89+
ds.poolMu.RLock()
90+
defer ds.poolMu.RUnlock()
91+
return ds.pool != nil
92+
}
93+
94+
func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool {
95+
poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector)
96+
podSet := labels.Set(podLabels)
97+
return poolSelector.Matches(podSet)
7498
}
7599

76-
func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.InferenceModel) {
77-
infModel, ok := s.InferenceModels.Load(modelName)
100+
// /// InferenceModel APIs ///
101+
func (ds *datastore) ModelSet(infModel *v1alpha1.InferenceModel) {
102+
ds.models.Store(infModel.Spec.ModelName, infModel)
103+
}
104+
105+
func (ds *datastore) ModelGet(modelName string) (*v1alpha1.InferenceModel, bool) {
106+
infModel, ok := ds.models.Load(modelName)
78107
if ok {
79-
returnModel = infModel.(*v1alpha1.InferenceModel)
108+
return infModel.(*v1alpha1.InferenceModel), true
80109
}
81-
return
110+
return nil, false
82111
}
83112

84-
// HasSynced returns true if InferencePool is set in the data store.
85-
func (ds *K8sDatastore) HasSynced() bool {
86-
ds.poolMu.RLock()
87-
defer ds.poolMu.RUnlock()
88-
return ds.inferencePool != nil
113+
func (ds *datastore) ModelDelete(modelName string) {
114+
ds.models.Delete(modelName)
89115
}
90116

91-
func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string {
92-
var weights int32
93-
94-
source := rand.NewSource(rand.Int63())
95-
if seed > 0 {
96-
source = rand.NewSource(seed)
97-
}
98-
r := rand.New(source)
99-
for _, model := range model.Spec.TargetModels {
100-
weights += *model.Weight
117+
// /// Pods/endpoints APIs ///
118+
func (ds *datastore) PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool {
119+
if val, ok := ds.pods.Load(namespacedName); ok {
120+
existing := val.(*PodMetrics)
121+
existing.Metrics = *m
122+
return true
101123
}
102-
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
103-
randomVal := r.Int31n(weights)
104-
for _, model := range model.Spec.TargetModels {
105-
if randomVal < *model.Weight {
106-
return model.Name
107-
}
108-
randomVal -= *model.Weight
124+
return false
125+
}
126+
127+
func (ds *datastore) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) {
128+
val, ok := ds.pods.Load(namespacedName)
129+
if ok {
130+
return val.(*PodMetrics), true
109131
}
110-
return ""
132+
return nil, false
111133
}
112134

113-
func IsCritical(model *v1alpha1.InferenceModel) bool {
114-
if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha1.Critical {
135+
func (ds *datastore) PodGetAll() []*PodMetrics {
136+
res := []*PodMetrics{}
137+
fn := func(k, v any) bool {
138+
res = append(res, v.(*PodMetrics))
115139
return true
116140
}
117-
return false
141+
ds.pods.Range(fn)
142+
return res
118143
}
119144

120-
func (ds *K8sDatastore) LabelsMatch(podLabels map[string]string) bool {
121-
poolSelector := selectorFromInferencePoolSelector(ds.inferencePool.Spec.Selector)
122-
podSet := labels.Set(podLabels)
123-
return poolSelector.Matches(podSet)
145+
func (ds *datastore) PodRange(f func(key, value any) bool) {
146+
ds.pods.Range(f)
147+
}
148+
149+
func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
150+
ds.pods.Delete(namespacedName)
151+
}
152+
153+
func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool {
154+
new := &PodMetrics{
155+
Pod: Pod{
156+
NamespacedName: types.NamespacedName{
157+
Name: pod.Name,
158+
Namespace: pod.Namespace,
159+
},
160+
Address: pod.Status.PodIP,
161+
},
162+
Metrics: Metrics{
163+
ActiveModels: make(map[string]int),
164+
},
165+
}
166+
existing, ok := ds.pods.Load(new.NamespacedName)
167+
if !ok {
168+
ds.pods.Store(new.NamespacedName, new)
169+
return true
170+
}
171+
172+
// Update pod properties if anything changed.
173+
existing.(*PodMetrics).Pod = new.Pod
174+
return false
124175
}
125176

126-
func (ds *K8sDatastore) flushPodsAndRefetch(ctx context.Context, ctrlClient client.Client, newServerPool *v1alpha1.InferencePool) {
177+
func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client) {
178+
// Pool must exist to invoke this function.
179+
pool, _ := ds.PoolGet()
127180
podList := &corev1.PodList{}
128181
if err := ctrlClient.List(ctx, podList, &client.ListOptions{
129-
LabelSelector: selectorFromInferencePoolSelector(newServerPool.Spec.Selector),
130-
Namespace: newServerPool.Namespace,
182+
LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector),
183+
Namespace: pool.Namespace,
131184
}); err != nil {
132185
log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "Failed to list clients")
186+
return
133187
}
134-
ds.pods.Clear()
135188

136-
for _, k8sPod := range podList.Items {
137-
pod := Pod{
138-
Name: k8sPod.Name,
139-
Address: k8sPod.Status.PodIP + ":" + strconv.Itoa(int(newServerPool.Spec.TargetPortNumber)),
189+
activePods := make(map[string]bool)
190+
for _, pod := range podList.Items {
191+
if podIsReady(&pod) {
192+
activePods[pod.Name] = true
193+
ds.PodUpdateOrAddIfNotExist(&pod)
140194
}
141-
ds.pods.Store(pod, true)
142195
}
196+
197+
// Remove pods that don't exist or not ready any more.
198+
deleteFn := func(k, v any) bool {
199+
pm := v.(*PodMetrics)
200+
if exist := activePods[pm.NamespacedName.Name]; !exist {
201+
ds.pods.Delete(pm.NamespacedName)
202+
}
203+
return true
204+
}
205+
ds.pods.Range(deleteFn)
206+
}
207+
208+
func (ds *datastore) PodDeleteAll() {
209+
ds.pods.Clear()
143210
}
144211

145212
func selectorFromInferencePoolSelector(selector map[v1alpha1.LabelKey]v1alpha1.LabelValue) labels.Selector {
@@ -153,3 +220,32 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha1.LabelKey]v1alpha1.LabelV
153220
}
154221
return outMap
155222
}
223+
224+
func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string {
225+
var weights int32
226+
227+
source := rand.NewSource(rand.Int63())
228+
if seed > 0 {
229+
source = rand.NewSource(seed)
230+
}
231+
r := rand.New(source)
232+
for _, model := range model.Spec.TargetModels {
233+
weights += *model.Weight
234+
}
235+
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
236+
randomVal := r.Int31n(weights)
237+
for _, model := range model.Spec.TargetModels {
238+
if randomVal < *model.Weight {
239+
return model.Name
240+
}
241+
randomVal -= *model.Weight
242+
}
243+
return ""
244+
}
245+
246+
func IsCritical(model *v1alpha1.InferenceModel) bool {
247+
if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha1.Critical {
248+
return true
249+
}
250+
return false
251+
}

pkg/ext-proc/backend/datastore_test.go

+3-3
Original file line number | Diff line number | Diff line change
@@ -32,13 +32,13 @@ func TestHasSynced(t *testing.T) {
3232
}
3333
for _, tt := range tests {
3434
t.Run(tt.name, func(t *testing.T) {
35-
datastore := NewK8sDataStore()
35+
datastore := NewDatastore()
3636
// Set the inference pool
3737
if tt.inferencePool != nil {
38-
datastore.setInferencePool(tt.inferencePool)
38+
datastore.PoolSet(tt.inferencePool)
3939
}
4040
// Check if the data store has been initialized
41-
hasSynced := datastore.HasSynced()
41+
hasSynced := datastore.PoolHasSynced()
4242
if hasSynced != tt.hasSynced {
4343
t.Errorf("IsInitialized() = %v, want %v", hasSynced, tt.hasSynced)
4444
}

pkg/ext-proc/backend/fake.go

+7-6
Original file line number | Diff line number | Diff line change
@@ -3,22 +3,23 @@ package backend
33
import (
44
"context"
55

6+
"k8s.io/apimachinery/pkg/types"
67
"sigs.k8s.io/controller-runtime/pkg/log"
78
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
89
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
910
)
1011

1112
type FakePodMetricsClient struct {
12-
Err map[Pod]error
13-
Res map[Pod]*PodMetrics
13+
Err map[types.NamespacedName]error
14+
Res map[types.NamespacedName]*PodMetrics
1415
}
1516

16-
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existing *PodMetrics) (*PodMetrics, error) {
17-
if err, ok := f.Err[pod]; ok {
17+
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, existing *PodMetrics) (*PodMetrics, error) {
18+
if err, ok := f.Err[existing.NamespacedName]; ok {
1819
return nil, err
1920
}
20-
log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "pod", pod, "existing", existing, "new", f.Res[pod])
21-
return f.Res[pod], nil
21+
log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "existing", existing, "new", f.Res[existing.NamespacedName])
22+
return f.Res[existing.NamespacedName], nil
2223
}
2324

2425
type FakeDataStore struct {

0 commit comments

Comments
 (0)