Skip to content

Commit dfe8d9c

Browse files
authored
scheduling changes for lora affinity load balancing (#423)
* scheduling changes for lora affinity load balancing * refactor unit tests, address comments * restore vllm deployment manifest * update README for model server protocol to add waiting lora adapters * remove unused variables * removed unused func * fix model protocol readme * fix hermetic test for select active lora, low queue * update comment in metrics.go in vllm backend * add filter test TestLoRASoftAffinityDistribution * restore vllm manifest * update unit test
1 parent 6118534 commit dfe8d9c

File tree

6 files changed

+196
-24
lines changed

6 files changed

+196
-24
lines changed

docs/proposals/003-model-server-protocol/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ The model server MUST expose the following LoRA adapter metrics via the same Pro
4747
requested adapter. Example: `"max_lora": "8"`.
4848
* `running_lora_adapters`: A comma separated list of adapters that are currently loaded in GPU
4949
memory and ready to serve requests. Example: `"running_lora_adapters": "adapter1, adapter2"`
50+
* `waiting_lora_adapters`: A comma separated list of adapters that are waiting to be served. Example: `"waiting_lora_adapters": "adapter1, adapter2"`

pkg/epp/backend/vllm/metrics.go

+42-3
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,13 @@ import (
3434
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3535
)
3636

37+
// Metric names used in the vLLM metrics implementation.
38+
// Refer to the protocol doc for more details:
39+
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/docs/proposals/003-model-server-protocol
3740
const (
3841
LoraRequestInfoMetricName = "vllm:lora_requests_info"
3942
LoraRequestInfoRunningAdaptersMetricName = "running_lora_adapters"
43+
LoraRequestInfoWaitingAdaptersMetricName = "waiting_lora_adapters"
4044
LoraRequestInfoMaxAdaptersMetricName = "max_lora"
4145
// TODO: Replace these with the num_tokens_running/waiting below once we add those to the fork.
4246
RunningQueueSizeMetricName = "vllm:num_requests_running"
@@ -45,8 +49,7 @@ const (
4549
RunningQueueSizeMetricName = "vllm:num_tokens_running"
4650
WaitingQueueSizeMetricName = "vllm:num_tokens_waiting"
4751
*/
48-
KVCacheUsagePercentMetricName = "vllm:gpu_cache_usage_perc"
49-
KvCacheMaxTokenCapacityMetricName = "vllm:gpu_cache_max_token_capacity"
52+
KVCacheUsagePercentMetricName = "vllm:gpu_cache_usage_perc"
5053
)
5154

5255
type PodMetricsClientImpl struct{}
@@ -138,6 +141,14 @@ func promToPodMetrics(
138141
}
139142
}
140143
}
144+
if label.GetName() == LoraRequestInfoWaitingAdaptersMetricName {
145+
if label.GetValue() != "" {
146+
adapterList := strings.Split(label.GetValue(), ",")
147+
for _, adapter := range adapterList {
148+
updated.ActiveModels[adapter] = 0
149+
}
150+
}
151+
}
141152
if label.GetName() == LoraRequestInfoMaxAdaptersMetricName {
142153
if label.GetValue() != "" {
143154
updated.MaxActiveModels, err = strconv.Atoi(label.GetValue())
@@ -163,14 +174,42 @@ func getLatestLoraMetric(logger logr.Logger, metricFamilies map[string]*dto.Metr
163174
logger.V(logutil.DEFAULT).Error(nil, "Metric family not found", "name", LoraRequestInfoMetricName)
164175
return nil, time.Time{}, fmt.Errorf("metric family %q not found", LoraRequestInfoMetricName)
165176
}
166-
var latestTs float64
177+
167178
var latest *dto.Metric
179+
var latestTs float64
180+
181+
// Iterate over all metrics in the family.
168182
for _, m := range loraRequests.GetMetric() {
183+
var running, waiting string
184+
// Read the label values for running and waiting adapters.
185+
for _, lp := range m.GetLabel() {
186+
switch lp.GetName() {
187+
case LoraRequestInfoRunningAdaptersMetricName:
188+
running = lp.GetValue()
189+
case LoraRequestInfoWaitingAdaptersMetricName:
190+
waiting = lp.GetValue()
191+
}
192+
}
193+
194+
// Ignore metrics with both labels empty. This happens when there are no running or waiting requests on
195+
// the server, in this case it is best to use the last set of active adapters.
196+
if running == "" && waiting == "" {
197+
continue
198+
}
199+
200+
// Select the metric with the latest creation timestamp.
169201
if m.GetGauge().GetValue() > latestTs {
170202
latestTs = m.GetGauge().GetValue()
171203
latest = m
172204
}
173205
}
206+
207+
if latest == nil {
208+
logger.V(logutil.TRACE).Info("Metric value Empty", "value", latest, "metric", LoraRequestInfoMetricName)
209+
return nil, time.Time{}, nil
210+
}
211+
212+
// Convert the gauge value (creation timestamp) to time.Time.
174213
return latest, time.Unix(0, int64(latestTs*1000)), nil
175214
}
176215

pkg/epp/scheduling/filter.go

+52-9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package scheduling
1919
import (
2020
"errors"
2121
"math"
22+
"math/rand"
23+
"time"
2224

2325
"github.com/go-logr/logr"
2426
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
@@ -183,18 +185,59 @@ func lowLoRACostPredicate(req *LLMRequest, pod *datastore.PodMetrics) bool {
183185
return ok || len(pod.ActiveModels) < pod.MaxActiveModels
184186
}
185187

186-
// loRAAffinityPredicate is a filter function to check whether a pod has affinity to the lora requested.
187-
func loRAAffinityPredicate(req *LLMRequest, pod *datastore.PodMetrics) bool {
188-
_, ok := pod.ActiveModels[req.ResolvedTargetModel]
189-
return ok
190-
}
188+
// loRASoftAffinityPredicate implements a pod selection strategy that prioritizes pods
189+
// with existing LoRA model affinity while allowing for load balancing through randomization.
190+
//
191+
// The function works by:
192+
// 1. Separating pods into two groups: those with target model affinity and those with available capacity
193+
// 2. Using a probability threshold to sometimes select from non-affinity pods to enable load balancing
194+
// 3. Falling back to whatever group has pods if one group is empty
195+
//
196+
// Parameters:
197+
// - logger: Logger interface for diagnostic output
198+
// - req: LLM request containing the resolved target model
199+
// - pods: Slice of pod metrics to filter
200+
//
201+
// Returns:
202+
// - Filtered slice of pod metrics based on affinity and availability
203+
// - Error if any issues occur during filtering
204+
func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []*datastore.PodMetrics) ([]*datastore.PodMetrics, error) {
205+
206+
// Pre-allocate slices with estimated capacity
207+
filtered_affinity := make([]*datastore.PodMetrics, 0, len(pods))
208+
filtered_available := make([]*datastore.PodMetrics, 0, len(pods))
209+
210+
// Categorize pods based on affinity and availability
211+
for _, pod := range pods {
212+
213+
if _, exists := pod.ActiveModels[req.ResolvedTargetModel]; exists {
214+
filtered_affinity = append(filtered_affinity, pod)
215+
} else if len(pod.ActiveModels) < pod.MaxActiveModels {
216+
filtered_available = append(filtered_available, pod)
217+
}
218+
}
219+
220+
// Use crypto/rand for better randomization in production environments
221+
randSource := rand.NewSource(time.Now().UnixNano())
222+
randGen := rand.New(randSource)
223+
224+
// If both groups have pods, use probability to select which group to return
225+
if len(filtered_affinity) > 0 && len(filtered_available) > 0 {
226+
if randGen.Float64() < loraAffinityThreshold {
227+
return filtered_affinity, nil
228+
}
229+
return filtered_available, nil
230+
}
231+
232+
// Return whichever group has pods
233+
if len(filtered_affinity) > 0 {
234+
return filtered_affinity, nil
235+
}
191236

192-
// canAcceptNewLoraPredicate is a filter function to check whether a pod has room to load the adapter.
193-
func canAcceptNewLoraPredicate(req *LLMRequest, pod *datastore.PodMetrics) bool {
194-
return len(pod.ActiveModels) < pod.MaxActiveModels
237+
return filtered_available, nil
195238
}
196239

197-
func criticalRequestPredicate(req *LLMRequest, pod *datastore.PodMetrics) bool {
240+
func criticalRequestPredicate(req *LLMRequest, _ *datastore.PodMetrics) bool {
198241
return req.Critical
199242
}
200243

pkg/epp/scheduling/filter_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -429,3 +429,93 @@ func TestFilterFunc(t *testing.T) {
429429
})
430430
}
431431
}
432+
433+
// TestLoRASoftAffinityDistribution tests that the loRASoftAffinityFilter function
434+
// properly distributes requests according to the loraAffinityThreshold
435+
func TestLoRASoftAffinityDistribution(t *testing.T) {
436+
logger := logutil.NewTestLogger()
437+
438+
const (
439+
testModelName = "test-model"
440+
testAffinityModel = "test-affinity-model"
441+
numIterations = 10000
442+
tolerancePercent = 5.0 // Allow 5% tolerance from expected distribution
443+
)
444+
445+
// Create a test request and pods
446+
req := &LLMRequest{
447+
Model: testAffinityModel,
448+
ResolvedTargetModel: testAffinityModel,
449+
}
450+
451+
// Test setup: One affinity pod and one available pod
452+
pods := []*datastore.PodMetrics{
453+
{
454+
Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "affinity-pod"}},
455+
Metrics: datastore.Metrics{
456+
MaxActiveModels: 2,
457+
ActiveModels: map[string]int{
458+
testAffinityModel: 1,
459+
},
460+
},
461+
},
462+
{
463+
Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "available-pod"}},
464+
Metrics: datastore.Metrics{
465+
MaxActiveModels: 2,
466+
ActiveModels: map[string]int{},
467+
},
468+
},
469+
}
470+
471+
// Run the filter function multiple times and count the results
472+
affinityCount := 0
473+
availableCount := 0
474+
475+
// Use the actual loraAffinityThreshold as defined in the original code
476+
// This test should work with whatever value is set there
477+
expectedAffinityPercent := loraAffinityThreshold * 100
478+
for i := 0; i < numIterations; i++ {
479+
result, err := loRASoftAffinityFilter(logger, req, pods)
480+
if err != nil {
481+
t.Fatalf("Unexpected error: %v", err)
482+
}
483+
484+
// Check which type of pod was returned
485+
if len(result) != 1 {
486+
t.Fatalf("Expected exactly one pod in result, got %d", len(result))
487+
}
488+
489+
// Identify if the returned pod is the affinity pod or available pod
490+
if _, exists := result[0].ActiveModels[testAffinityModel]; exists {
491+
affinityCount++
492+
} else {
493+
availableCount++
494+
}
495+
}
496+
497+
// Calculate the actual percentages
498+
actualAffinityPercent := float64(affinityCount) / float64(numIterations) * 100
499+
actualAvailablePercent := float64(availableCount) / float64(numIterations) * 100
500+
501+
// Check if the distribution matches expected threshold within tolerance
502+
affinityLowerBound := expectedAffinityPercent - tolerancePercent
503+
affinityUpperBound := expectedAffinityPercent + tolerancePercent
504+
505+
availableLowerBound := actualAvailablePercent - tolerancePercent
506+
availableUpperBound := actualAvailablePercent + tolerancePercent
507+
508+
t.Logf("Distribution results over %d iterations:", numIterations)
509+
t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, loraAffinityThreshold)
510+
t.Logf("Actual affinity percent: %.2f%% (%d out of %d)", actualAffinityPercent, affinityCount, numIterations)
511+
t.Logf("Actual available percent: %.2f%% (%d out of %d)", actualAvailablePercent, availableCount, numIterations)
512+
513+
if actualAffinityPercent < affinityLowerBound || actualAffinityPercent > affinityUpperBound {
514+
t.Errorf("Affinity selection percent %.2f%% outside expected range %.2f%% to %.2f%%",
515+
actualAffinityPercent, affinityLowerBound, affinityUpperBound)
516+
}
517+
if actualAvailablePercent < availableLowerBound || actualAvailablePercent > availableUpperBound {
518+
t.Errorf("Availability selection percent %.2f%% outside expected range %.2f%% to %.2f%%",
519+
actualAvailablePercent, availableLowerBound, availableUpperBound)
520+
}
521+
}

pkg/epp/scheduling/scheduler.go

+9-11
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@ const (
3636
queueThresholdCritical = 5
3737
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
3838
// the threshold for queued requests to be considered low below which we can prioritize LoRA affinity.
39-
// The value of 50 is arrived heuristicically based on experiments.
40-
queueingThresholdLoRA = 50
39+
// The value of 128 was arrived at heuristically based on experiments.
40+
queueingThresholdLoRA = 128
41+
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
42+
// loraAffinityThreshold indicates the probability with which we prefer a pod with LoRA affinity over a pod without but having room to fit more LoRA adapters.
43+
loraAffinityThreshold = 0.999
4144
)
4245

4346
var (
@@ -54,7 +57,7 @@ var (
5457
filter: leastQueuingFilterFunc,
5558
nextOnSuccessOrFailure: &filter{
5659
name: "low cost LoRA",
57-
filter: toFilterFunc(lowLoRACostPredicate),
60+
filter: loRASoftAffinityFilter,
5861
nextOnSuccessOrFailure: &filter{
5962
name: "least KV cache percent",
6063
filter: leastKVCacheFilterFunc,
@@ -76,14 +79,9 @@ var (
7679
name: "low queueing filter",
7780
filter: toFilterFunc((lowQueueingPodPredicate)),
7881
nextOnSuccess: &filter{
79-
name: "affinity LoRA",
80-
filter: toFilterFunc(loRAAffinityPredicate),
81-
nextOnSuccess: queueAndKVCacheFilter,
82-
nextOnFailure: &filter{
83-
name: "can accept LoRA Adapter",
84-
filter: toFilterFunc(canAcceptNewLoraPredicate),
85-
nextOnSuccessOrFailure: queueAndKVCacheFilter,
86-
},
82+
name: "affinity LoRA",
83+
filter: loRASoftAffinityFilter,
84+
nextOnSuccessOrFailure: queueAndKVCacheFilter,
8785
},
8886
nextOnFailure: queueLoRAAndKVCacheFilter,
8987
}

test/integration/hermetic_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ func TestKubeInferenceModelRequest(t *testing.T) {
158158
KVCacheUsagePercent: 0.2,
159159
ActiveModels: map[string]int{
160160
"foo": 1,
161+
"bar": 1,
161162
},
162163
}),
163164
},
@@ -200,7 +201,7 @@ func TestKubeInferenceModelRequest(t *testing.T) {
200201
},
201202
}),
202203
extprocutils.FakePodMetrics(1, datastore.Metrics{
203-
WaitingQueueSize: 50,
204+
WaitingQueueSize: 200,
204205
KVCacheUsagePercent: 0.1,
205206
ActiveModels: map[string]int{
206207
"foo": 1,

0 commit comments

Comments
 (0)