@@ -10,136 +10,187 @@ import (
10
10
"github.com/go-logr/logr"
11
11
corev1 "k8s.io/api/core/v1"
12
12
"k8s.io/apimachinery/pkg/labels"
13
+ "k8s.io/apimachinery/pkg/types"
13
14
"sigs.k8s.io/controller-runtime/pkg/client"
14
15
"sigs.k8s.io/controller-runtime/pkg/log"
15
16
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
16
17
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
17
18
)
18
19
19
- func NewK8sDataStore (options ... K8sDatastoreOption ) * K8sDatastore {
20
- store := & K8sDatastore {
21
- poolMu : sync.RWMutex {},
22
- InferenceModels : & sync.Map {},
23
- pods : & sync.Map {},
24
- }
25
- for _ , opt := range options {
26
- opt (store )
20
+ // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
21
+ type Datastore interface {
22
+ // InferencePool operations
23
+ PoolSet (pool * v1alpha1.InferencePool )
24
+ PoolGet () (* v1alpha1.InferencePool , error )
25
+ PoolHasSynced () bool
26
+ PoolLabelsMatch (podLabels map [string ]string ) bool
27
+
28
+ // InferenceModel operations
29
+ ModelSet (infModel * v1alpha1.InferenceModel )
30
+ ModelGet (modelName string ) (returnModel * v1alpha1.InferenceModel )
31
+ ModelDelete (modelName string )
32
+
33
+ // PodMetrics operations
34
+ PodAddIfNotExist (pod * corev1.Pod ) bool
35
+ PodUpdateMetricsIfExist (pm * PodMetrics )
36
+ PodGet (namespacedName types.NamespacedName ) (* PodMetrics , bool )
37
+ PodDelete (namespacedName types.NamespacedName )
38
+ PodFlush (ctx context.Context , ctrlClient client.Client )
39
+ PodGetAll () []* PodMetrics
40
+ PodRange (f func (key , value any ) bool )
41
+ PodDeleteAll () // This is only for testing.
42
+ }
43
+
44
+ func NewDatastore () Datastore {
45
+ store := & datastore {
46
+ poolMu : sync.RWMutex {},
47
+ models : & sync.Map {},
48
+ pods : & sync.Map {},
27
49
}
28
50
return store
29
51
}
30
52
31
- // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
32
- type K8sDatastore struct {
53
+ type datastore struct {
33
54
// poolMu is used to synchronize access to the inferencePool.
34
- poolMu sync.RWMutex
35
- inferencePool * v1alpha1.InferencePool
36
- InferenceModels * sync.Map
37
- pods * sync.Map
38
- }
39
-
40
- type K8sDatastoreOption func (* K8sDatastore )
41
-
42
- // WithPods can be used in tests to override the pods.
43
- func WithPods (pods []* PodMetrics ) K8sDatastoreOption {
44
- return func (store * K8sDatastore ) {
45
- store .pods = & sync.Map {}
46
- for _ , pod := range pods {
47
- store .pods .Store (pod .Pod , true )
48
- }
49
- }
55
+ poolMu sync.RWMutex
56
+ pool * v1alpha1.InferencePool
57
+ models * sync.Map
58
+ // key: types.NamespacedName, value: *PodMetrics
59
+ pods * sync.Map
50
60
}
51
61
52
- func (ds * K8sDatastore ) setInferencePool (pool * v1alpha1.InferencePool ) {
62
+ // /// InferencePool APIs ///
63
+ func (ds * datastore ) PoolSet (pool * v1alpha1.InferencePool ) {
53
64
ds .poolMu .Lock ()
54
65
defer ds .poolMu .Unlock ()
55
- ds .inferencePool = pool
66
+ ds .pool = pool
56
67
}
57
68
58
- func (ds * K8sDatastore ) getInferencePool () (* v1alpha1.InferencePool , error ) {
69
+ func (ds * datastore ) PoolGet () (* v1alpha1.InferencePool , error ) {
59
70
ds .poolMu .RLock ()
60
71
defer ds .poolMu .RUnlock ()
61
- if ! ds .HasSynced () {
72
+ if ! ds .PoolHasSynced () {
62
73
return nil , errors .New ("InferencePool is not initialized in data store" )
63
74
}
64
- return ds .inferencePool , nil
75
+ return ds .pool , nil
65
76
}
66
77
67
- func (ds * K8sDatastore ) GetPodIPs () []string {
68
- var ips []string
69
- ds .pods .Range (func (name , pod any ) bool {
70
- ips = append (ips , pod .(* corev1.Pod ).Status .PodIP )
71
- return true
72
- })
73
- return ips
78
+ func (ds * datastore ) PoolHasSynced () bool {
79
+ ds .poolMu .RLock ()
80
+ defer ds .poolMu .RUnlock ()
81
+ return ds .pool != nil
82
+ }
83
+
84
+ func (ds * datastore ) PoolLabelsMatch (podLabels map [string ]string ) bool {
85
+ poolSelector := selectorFromInferencePoolSelector (ds .pool .Spec .Selector )
86
+ podSet := labels .Set (podLabels )
87
+ return poolSelector .Matches (podSet )
74
88
}
75
89
76
- func (s * K8sDatastore ) FetchModelData (modelName string ) (returnModel * v1alpha1.InferenceModel ) {
77
- infModel , ok := s .InferenceModels .Load (modelName )
90
+ // /// InferenceModel APIs ///
91
+ func (ds * datastore ) ModelSet (infModel * v1alpha1.InferenceModel ) {
92
+ ds .models .Store (infModel .Spec .ModelName , infModel )
93
+ }
94
+
95
+ func (ds * datastore ) ModelGet (modelName string ) (returnModel * v1alpha1.InferenceModel ) {
96
+ infModel , ok := ds .models .Load (modelName )
78
97
if ok {
79
98
returnModel = infModel .(* v1alpha1.InferenceModel )
80
99
}
81
100
return
82
101
}
83
102
84
- // HasSynced returns true if InferencePool is set in the data store.
85
- func (ds * K8sDatastore ) HasSynced () bool {
86
- ds .poolMu .RLock ()
87
- defer ds .poolMu .RUnlock ()
88
- return ds .inferencePool != nil
103
+ func (ds * datastore ) ModelDelete (modelName string ) {
104
+ ds .models .Delete (modelName )
89
105
}
90
106
91
- func RandomWeightedDraw (logger logr.Logger , model * v1alpha1.InferenceModel , seed int64 ) string {
92
- var weights int32
93
-
94
- source := rand .NewSource (rand .Int63 ())
95
- if seed > 0 {
96
- source = rand .NewSource (seed )
97
- }
98
- r := rand .New (source )
99
- for _ , model := range model .Spec .TargetModels {
100
- weights += * model .Weight
107
+ // /// Pods/endpoints APIs ///
108
+ func (ds * datastore ) PodUpdateMetricsIfExist (pm * PodMetrics ) {
109
+ if val , ok := ds .pods .Load (pm .NamespacedName ); ok {
110
+ existing := val .(* PodMetrics )
111
+ existing .Metrics = pm .Metrics
101
112
}
102
- logger .V (logutil .TRACE ).Info ("Weights for model computed" , "model" , model .Name , "weights" , weights )
103
- randomVal := r .Int31n (weights )
104
- for _ , model := range model .Spec .TargetModels {
105
- if randomVal < * model .Weight {
106
- return model .Name
107
- }
108
- randomVal -= * model .Weight
113
+ }
114
+
115
+ func (ds * datastore ) PodGet (namespacedName types.NamespacedName ) (* PodMetrics , bool ) {
116
+ val , ok := ds .pods .Load (namespacedName )
117
+ if ok {
118
+ return val .(* PodMetrics ), true
109
119
}
110
- return ""
120
+ return nil , false
111
121
}
112
122
113
- func IsCritical (model * v1alpha1.InferenceModel ) bool {
114
- if model .Spec .Criticality != nil && * model .Spec .Criticality == v1alpha1 .Critical {
123
+ func (ds * datastore ) PodGetAll () []* PodMetrics {
124
+ res := []* PodMetrics {}
125
+ fn := func (k , v any ) bool {
126
+ res = append (res , v .(* PodMetrics ))
115
127
return true
116
128
}
117
- return false
129
+ ds .pods .Range (fn )
130
+ return res
118
131
}
119
132
120
- func (ds * K8sDatastore ) LabelsMatch (podLabels map [string ]string ) bool {
121
- poolSelector := selectorFromInferencePoolSelector (ds .inferencePool .Spec .Selector )
122
- podSet := labels .Set (podLabels )
123
- return poolSelector .Matches (podSet )
133
+ func (ds * datastore ) PodRange (f func (key , value any ) bool ) {
134
+ ds .pods .Range (f )
135
+ }
136
+
137
+ func (ds * datastore ) PodDelete (namespacedName types.NamespacedName ) {
138
+ ds .pods .Delete (namespacedName )
139
+ }
140
+
141
+ func (ds * datastore ) PodAddIfNotExist (pod * corev1.Pod ) bool {
142
+ // new pod, add to the store for probing
143
+ pool , _ := ds .PoolGet ()
144
+ new := & PodMetrics {
145
+ NamespacedName : types.NamespacedName {
146
+ Name : pod .Name ,
147
+ Namespace : pod .Namespace ,
148
+ },
149
+ Address : pod .Status .PodIP + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber )),
150
+ Metrics : Metrics {
151
+ ActiveModels : make (map [string ]int ),
152
+ },
153
+ }
154
+ if _ , ok := ds .pods .Load (new .NamespacedName ); ! ok {
155
+ ds .pods .Store (new .NamespacedName , new )
156
+ return true
157
+ }
158
+ return false
124
159
}
125
160
126
- func (ds * K8sDatastore ) flushPodsAndRefetch (ctx context.Context , ctrlClient client.Client , newServerPool * v1alpha1.InferencePool ) {
161
+ func (ds * datastore ) PodFlush (ctx context.Context , ctrlClient client.Client ) {
162
+ // Pool must exist to invoke this function.
163
+ pool , _ := ds .PoolGet ()
127
164
podList := & corev1.PodList {}
128
165
if err := ctrlClient .List (ctx , podList , & client.ListOptions {
129
- LabelSelector : selectorFromInferencePoolSelector (newServerPool .Spec .Selector ),
130
- Namespace : newServerPool .Namespace ,
166
+ LabelSelector : selectorFromInferencePoolSelector (pool .Spec .Selector ),
167
+ Namespace : pool .Namespace ,
131
168
}); err != nil {
132
169
log .FromContext (ctx ).V (logutil .DEFAULT ).Error (err , "Failed to list clients" )
170
+ return
133
171
}
134
- ds .pods .Clear ()
135
172
136
- for _ , k8sPod := range podList .Items {
137
- pod := Pod {
138
- Name : k8sPod .Name ,
139
- Address : k8sPod .Status .PodIP + ":" + strconv .Itoa (int (newServerPool .Spec .TargetPortNumber )),
173
+ activePods := make (map [string ]bool )
174
+ for _ , pod := range podList .Items {
175
+ if podIsReady (& pod ) {
176
+ activePods [pod .Name ] = true
177
+ ds .PodAddIfNotExist (& pod )
140
178
}
141
- ds .pods .Store (pod , true )
142
179
}
180
+
181
+ // Remove pods that don't exist or not ready any more.
182
+ deleteFn := func (k , v any ) bool {
183
+ pm := v .(* PodMetrics )
184
+ if exist := activePods [pm .NamespacedName .Name ]; ! exist {
185
+ ds .pods .Delete (pm .NamespacedName )
186
+ }
187
+ return true
188
+ }
189
+ ds .pods .Range (deleteFn )
190
+ }
191
+
192
+ func (ds * datastore ) PodDeleteAll () {
193
+ ds .pods .Clear ()
143
194
}
144
195
145
196
func selectorFromInferencePoolSelector (selector map [v1alpha1.LabelKey ]v1alpha1.LabelValue ) labels.Selector {
@@ -153,3 +204,32 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha1.LabelKey]v1alpha1.LabelV
153
204
}
154
205
return outMap
155
206
}
207
+
208
+ func RandomWeightedDraw (logger logr.Logger , model * v1alpha1.InferenceModel , seed int64 ) string {
209
+ var weights int32
210
+
211
+ source := rand .NewSource (rand .Int63 ())
212
+ if seed > 0 {
213
+ source = rand .NewSource (seed )
214
+ }
215
+ r := rand .New (source )
216
+ for _ , model := range model .Spec .TargetModels {
217
+ weights += * model .Weight
218
+ }
219
+ logger .V (logutil .TRACE ).Info ("Weights for model computed" , "model" , model .Name , "weights" , weights )
220
+ randomVal := r .Int31n (weights )
221
+ for _ , model := range model .Spec .TargetModels {
222
+ if randomVal < * model .Weight {
223
+ return model .Name
224
+ }
225
+ randomVal -= * model .Weight
226
+ }
227
+ return ""
228
+ }
229
+
230
+ func IsCritical (model * v1alpha1.InferenceModel ) bool {
231
+ if model .Spec .Criticality != nil && * model .Spec .Criticality == v1alpha1 .Critical {
232
+ return true
233
+ }
234
+ return false
235
+ }
0 commit comments