Skip to content

Consolidating all storage behind datastore #350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 176 additions & 80 deletions pkg/ext-proc/backend/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,142 +4,209 @@ import (
"context"
"errors"
"math/rand"
"strconv"
"sync"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
)

func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore {
store := &K8sDatastore{
poolMu: sync.RWMutex{},
InferenceModels: &sync.Map{},
pods: &sync.Map{},
}
for _, opt := range options {
opt(store)
// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
type Datastore interface {
// InferencePool operations
PoolSet(pool *v1alpha1.InferencePool)
PoolGet() (*v1alpha1.InferencePool, error)
PoolHasSynced() bool
PoolLabelsMatch(podLabels map[string]string) bool

// InferenceModel operations
ModelSet(infModel *v1alpha1.InferenceModel)
ModelGet(modelName string) (*v1alpha1.InferenceModel, bool)
ModelDelete(modelName string)

// PodMetrics operations
PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool
PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool
PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool)
PodDelete(namespacedName types.NamespacedName)
PodResyncAll(ctx context.Context, ctrlClient client.Client)
PodGetAll() []*PodMetrics
PodDeleteAll() // This is only for testing.
PodRange(f func(key, value any) bool)

// Clears the store state, happens when the pool gets deleted.
Clear()
}

// NewDatastore returns an empty Datastore. The InferencePool is unset until
// PoolSet is called, so PoolHasSynced reports false on a fresh store.
func NewDatastore() Datastore {
	return &datastore{
		// poolMu's zero value is ready to use; no explicit initialization needed.
		models: &sync.Map{},
		pods:   &sync.Map{},
	}
}

// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
type K8sDatastore struct {
type datastore struct {
// poolMu is used to synchronize access to the inferencePool.
poolMu sync.RWMutex
inferencePool *v1alpha1.InferencePool
InferenceModels *sync.Map
pods *sync.Map
poolMu sync.RWMutex
pool *v1alpha1.InferencePool
models *sync.Map
// key: types.NamespacedName, value: *PodMetrics
pods *sync.Map
}

type K8sDatastoreOption func(*K8sDatastore)

// WithPods can be used in tests to override the pods.
func WithPods(pods []*PodMetrics) K8sDatastoreOption {
return func(store *K8sDatastore) {
store.pods = &sync.Map{}
for _, pod := range pods {
store.pods.Store(pod.Pod, true)
}
}
// Clear resets the datastore to its initial empty state: the pool reference
// is dropped and both the model and pod caches are emptied. This happens when
// the InferencePool gets deleted.
func (ds *datastore) Clear() {
	ds.poolMu.Lock()
	defer ds.poolMu.Unlock()
	ds.pool = nil
	ds.pods.Clear()
	ds.models.Clear()
}

func (ds *K8sDatastore) setInferencePool(pool *v1alpha1.InferencePool) {
// /// InferencePool APIs ///
// PoolSet stores the given InferencePool as the pool this datastore caches
// data for, under the write lock.
func (ds *datastore) PoolSet(pool *v1alpha1.InferencePool) {
	ds.poolMu.Lock()
	ds.pool = pool
	ds.poolMu.Unlock()
}

func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) {
func (ds *datastore) PoolGet() (*v1alpha1.InferencePool, error) {
ds.poolMu.RLock()
defer ds.poolMu.RUnlock()
if !ds.HasSynced() {
if !ds.PoolHasSynced() {
return nil, errors.New("InferencePool is not initialized in data store")
}
return ds.inferencePool, nil
return ds.pool, nil
}

func (ds *K8sDatastore) GetPodIPs() []string {
var ips []string
ds.pods.Range(func(name, pod any) bool {
ips = append(ips, pod.(*corev1.Pod).Status.PodIP)
return true
})
return ips
// PoolHasSynced reports whether an InferencePool has been stored in the
// datastore (i.e. PoolSet has been called with a non-nil pool).
func (ds *datastore) PoolHasSynced() bool {
	ds.poolMu.RLock()
	defer ds.poolMu.RUnlock()
	synced := ds.pool != nil
	return synced
}

// PoolLabelsMatch reports whether the given pod labels are selected by the
// cached InferencePool's selector. It takes the read lock so the check is
// safe against concurrent PoolSet/Clear, and returns false when no pool has
// been set yet instead of panicking on a nil pool.
func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool {
	ds.poolMu.RLock()
	defer ds.poolMu.RUnlock()
	if ds.pool == nil {
		return false
	}
	poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector)
	return poolSelector.Matches(labels.Set(podLabels))
}

func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.InferenceModel) {
infModel, ok := s.InferenceModels.Load(modelName)
// /// InferenceModel APIs ///
// ModelSet caches the InferenceModel, keyed by its spec's ModelName. A later
// call with the same ModelName overwrites the earlier entry.
func (ds *datastore) ModelSet(infModel *v1alpha1.InferenceModel) {
	key := infModel.Spec.ModelName
	ds.models.Store(key, infModel)
}

func (ds *datastore) ModelGet(modelName string) (*v1alpha1.InferenceModel, bool) {
infModel, ok := ds.models.Load(modelName)
if ok {
returnModel = infModel.(*v1alpha1.InferenceModel)
return infModel.(*v1alpha1.InferenceModel), true
}
return
return nil, false
}

// HasSynced returns true if InferencePool is set in the data store.
func (ds *K8sDatastore) HasSynced() bool {
ds.poolMu.RLock()
defer ds.poolMu.RUnlock()
return ds.inferencePool != nil
// ModelDelete removes the InferenceModel cached under modelName, if present;
// deleting a missing key is a no-op.
func (ds *datastore) ModelDelete(modelName string) {
	ds.models.Delete(modelName)
}

func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string {
var weights int32

source := rand.NewSource(rand.Int63())
if seed > 0 {
source = rand.NewSource(seed)
}
r := rand.New(source)
for _, model := range model.Spec.TargetModels {
weights += *model.Weight
// /// Pods/endpoints APIs ///
func (ds *datastore) PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool {
if val, ok := ds.pods.Load(namespacedName); ok {
existing := val.(*PodMetrics)
existing.Metrics = *m
return true
}
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
randomVal := r.Int31n(weights)
for _, model := range model.Spec.TargetModels {
if randomVal < *model.Weight {
return model.Name
}
randomVal -= *model.Weight
return false
}

func (ds *datastore) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) {
val, ok := ds.pods.Load(namespacedName)
if ok {
return val.(*PodMetrics), true
}
return ""
return nil, false
}

func IsCritical(model *v1alpha1.InferenceModel) bool {
if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha1.Critical {
func (ds *datastore) PodGetAll() []*PodMetrics {
res := []*PodMetrics{}
fn := func(k, v any) bool {
res = append(res, v.(*PodMetrics))
return true
}
return false
ds.pods.Range(fn)
return res
}

func (ds *K8sDatastore) LabelsMatch(podLabels map[string]string) bool {
poolSelector := selectorFromInferencePoolSelector(ds.inferencePool.Spec.Selector)
podSet := labels.Set(podLabels)
return poolSelector.Matches(podSet)
// PodRange iterates over all cached pods, calling f for each
// (key: types.NamespacedName, value: *PodMetrics) pair. Iteration stops
// early if f returns false.
func (ds *datastore) PodRange(f func(key, value any) bool) {
	ds.pods.Range(f)
}

// PodDelete removes the pod's metrics entry from the cache, if present;
// deleting a missing key is a no-op.
func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
	ds.pods.Delete(namespacedName)
}

func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool {
new := &PodMetrics{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(This can be a followup).
Consider making a NewPodMetrics helper function here.

We should hide the internal fields of the PodMetric object, and make helper functions. This will make the PR #223 much easier. Perhaps we need to make PodMetrics an interface instead. I imagine with PR #223 the PodMetrics will need to manage the lifecycle of the refresher goroutines.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets do a follow up on that, I think we also need to rename the object as well. Perhaps we call it Endpoint? wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah renaming sounds good too.

Pod: Pod{
NamespacedName: types.NamespacedName{
Name: pod.Name,
Namespace: pod.Namespace,
},
Address: pod.Status.PodIP,
},
Metrics: Metrics{
ActiveModels: make(map[string]int),
},
}
existing, ok := ds.pods.Load(new.NamespacedName)
if !ok {
ds.pods.Store(new.NamespacedName, new)
return true
}

// Update pod properties if anything changed.
existing.(*PodMetrics).Pod = new.Pod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows us to update pod properties, currently the address only.

return false
}

func (ds *K8sDatastore) flushPodsAndRefetch(ctx context.Context, ctrlClient client.Client, newServerPool *v1alpha1.InferencePool) {
func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client) {
// Pool must exist to invoke this function.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be true now, but I could see this eventually being extracted to a lib to be used in custom EPPs. We may want to think about capturing the error and returning it to protect future callers.

Not a blocking comment for this PR however

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

We can consider making this method taking a pod selector, and have the PoolReconciler send in the pod selector

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the pool will always be required since it has the selector, otherwise we wouldn't know what pods to cache. The other option is to pass in the selector, which I am not sure is better since I view the datastore as a per pool cache.

pool, _ := ds.PoolGet()
podList := &corev1.PodList{}
if err := ctrlClient.List(ctx, podList, &client.ListOptions{
LabelSelector: selectorFromInferencePoolSelector(newServerPool.Spec.Selector),
Namespace: newServerPool.Namespace,
LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector),
Namespace: pool.Namespace,
}); err != nil {
log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "Failed to list clients")
return
}
ds.pods.Clear()

for _, k8sPod := range podList.Items {
pod := Pod{
Name: k8sPod.Name,
Address: k8sPod.Status.PodIP + ":" + strconv.Itoa(int(newServerPool.Spec.TargetPortNumber)),
activePods := make(map[string]bool)
for _, pod := range podList.Items {
if podIsReady(&pod) {
activePods[pod.Name] = true
ds.PodUpdateOrAddIfNotExist(&pod)
}
ds.pods.Store(pod, true)
}

// Remove pods that don't exist or not ready any more.
deleteFn := func(k, v any) bool {
pm := v.(*PodMetrics)
if exist := activePods[pm.NamespacedName.Name]; !exist {
ds.pods.Delete(pm.NamespacedName)
}
return true
}
ds.pods.Range(deleteFn)
}

// PodDeleteAll removes every pod entry from the cache. Per the Datastore
// interface contract, this is intended for testing only.
func (ds *datastore) PodDeleteAll() {
	ds.pods.Clear()
}

func selectorFromInferencePoolSelector(selector map[v1alpha1.LabelKey]v1alpha1.LabelValue) labels.Selector {
Expand All @@ -153,3 +220,32 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha1.LabelKey]v1alpha1.LabelV
}
return outMap
}

// RandomWeightedDraw picks a target model name from the InferenceModel's
// TargetModels with probability proportional to each target's Weight.
// A seed > 0 makes the draw deterministic (useful in tests). It returns ""
// when there are no targets or the total weight is not positive, instead of
// panicking inside rand.Int31n (which panics for n <= 0). Targets with a nil
// Weight are skipped rather than dereferenced.
func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string {
	source := rand.NewSource(rand.Int63())
	if seed > 0 {
		source = rand.NewSource(seed)
	}
	r := rand.New(source)

	// Sum the weights; nil weights contribute nothing.
	var totalWeight int32
	for _, tm := range model.Spec.TargetModels {
		if tm.Weight != nil {
			totalWeight += *tm.Weight
		}
	}
	logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", totalWeight)
	if totalWeight <= 0 {
		// No positive weight to draw from; rand.Int31n would panic.
		return ""
	}

	randomVal := r.Int31n(totalWeight)
	for _, tm := range model.Spec.TargetModels {
		if tm.Weight == nil {
			continue
		}
		if randomVal < *tm.Weight {
			return tm.Name
		}
		randomVal -= *tm.Weight
	}
	return ""
}

// IsCritical reports whether the InferenceModel's Criticality is set to
// Critical. A nil model or unset Criticality is treated as non-critical.
func IsCritical(model *v1alpha1.InferenceModel) bool {
	return model != nil &&
		model.Spec.Criticality != nil &&
		*model.Spec.Criticality == v1alpha1.Critical
}
6 changes: 3 additions & 3 deletions pkg/ext-proc/backend/datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ func TestHasSynced(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
datastore := NewK8sDataStore()
datastore := NewDatastore()
// Set the inference pool
if tt.inferencePool != nil {
datastore.setInferencePool(tt.inferencePool)
datastore.PoolSet(tt.inferencePool)
}
// Check if the data store has been initialized
hasSynced := datastore.HasSynced()
hasSynced := datastore.PoolHasSynced()
if hasSynced != tt.hasSynced {
t.Errorf("IsInitialized() = %v, want %v", hasSynced, tt.hasSynced)
}
Expand Down
13 changes: 7 additions & 6 deletions pkg/ext-proc/backend/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@ package backend
import (
"context"

"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
)

type FakePodMetricsClient struct {
Err map[Pod]error
Res map[Pod]*PodMetrics
Err map[types.NamespacedName]error
Res map[types.NamespacedName]*PodMetrics
}

func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existing *PodMetrics) (*PodMetrics, error) {
if err, ok := f.Err[pod]; ok {
func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, existing *PodMetrics) (*PodMetrics, error) {
if err, ok := f.Err[existing.NamespacedName]; ok {
return nil, err
}
log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "pod", pod, "existing", existing, "new", f.Res[pod])
return f.Res[pod], nil
log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "existing", existing, "new", f.Res[existing.NamespacedName])
return f.Res[existing.NamespacedName], nil
}

type FakeDataStore struct {
Expand Down
Loading