Skip to content

Commit b1fed6c

Browse files
authored
make ModelName immutable and fix model weight (#427)
* make ModelName immutable and fix model weight * Fix ut
1 parent 0d08a07 commit b1fed6c

File tree

4 files changed

+35
-5
lines changed

4 files changed

+35
-5
lines changed

api/v1alpha2/inferencemodel_types.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ type InferenceModelSpec struct {
7171
//
7272
// +kubebuilder:validation:MaxLength=256
7373
// +kubebuilder:validation:Required
74+
// +kubebuilder:validation:XValidation:rule="self == oldSelf",message="modelName is immutable"
7475
ModelName string `json:"modelName"`
7576

7677
// Criticality defines how important it is to serve the model compared to other models referencing the same pool.
@@ -175,7 +176,7 @@ type TargetModel struct {
175176
// Conversely weights are optional, so long as ALL targetModels do not specify a weight.
176177
//
177178
// +optional
178-
// +kubebuilder:validation:Minimum=0
179+
// +kubebuilder:validation:Minimum=1
179180
// +kubebuilder:validation:Maximum=1000000
180181
Weight *int32 `json:"weight,omitempty"`
181182
}

config/crd/bases/inference.networking.x-k8s.io_inferencemodels.yaml

+4-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ spec:
8282
an error will be returned specifying that no valid target model is found.
8383
maxLength: 256
8484
type: string
85+
x-kubernetes-validations:
86+
- message: modelName is immutable
87+
rule: self == oldSelf
8588
poolRef:
8689
description: PoolRef is a reference to the inference pool, the pool
8790
must exist in the same namespace.
@@ -143,7 +146,7 @@ spec:
143146
Conversely weights are optional, so long as ALL targetModels do not specify a weight.
144147
format: int32
145148
maximum: 1000000
146-
minimum: 0
149+
minimum: 1
147150
type: integer
148151
required:
149152
- name

pkg/epp/datastore/datastore.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -334,18 +334,25 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha2.LabelKey]v1alpha2.LabelV
334334
}
335335

336336
func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string {
337-
var weights int32
338-
339337
source := rand.NewSource(rand.Int63())
340338
if seed > 0 {
341339
source = rand.NewSource(seed)
342340
}
343341
r := rand.New(source)
342+
343+
// all the weight values are nil, then we should return random model name
344+
if model.Spec.TargetModels[0].Weight == nil {
345+
index := r.Int31n(int32(len(model.Spec.TargetModels)))
346+
return model.Spec.TargetModels[index].Name
347+
}
348+
349+
var weights int32
344350
for _, model := range model.Spec.TargetModels {
345351
weights += *model.Weight
346352
}
347353
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
348354
randomVal := r.Int31n(weights)
355+
// TODO: optimize this without using loop
349356
for _, model := range model.Spec.TargetModels {
350357
if randomVal < *model.Weight {
351358
return model.Name

pkg/epp/datastore/datastore_test.go

+20-1
Original file line numberDiff line numberDiff line change
@@ -280,14 +280,33 @@ func TestRandomWeightedDraw(t *testing.T) {
280280
},
281281
want: "v1.1",
282282
},
283+
{
284+
name: "weighted distribution with weight unset",
285+
model: &v1alpha2.InferenceModel{
286+
Spec: v1alpha2.InferenceModelSpec{
287+
TargetModels: []v1alpha2.TargetModel{
288+
{
289+
Name: "canary",
290+
},
291+
{
292+
Name: "v1.1",
293+
},
294+
{
295+
Name: "v1",
296+
},
297+
},
298+
},
299+
},
300+
want: "canary",
301+
},
283302
}
284303
var seedVal int64 = 420
285304
for _, test := range tests {
286305
t.Run(test.name, func(t *testing.T) {
287306
for range 10000 {
288307
model := RandomWeightedDraw(logger, test.model, seedVal)
289308
if model != test.want {
290-
t.Errorf("Model returned!: %v", model)
309+
t.Errorf("Model returned: %v != %v", model, test.want)
291310
break
292311
}
293312
}

0 commit comments

Comments
 (0)