Commit d122a6c

Add priority based scheduling (#25)
* Add priority based scheduling
* Use the least kv cache for sheddable requests when there is capacity
1 parent 18bc3a2 commit d122a6c
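The scheduling behavior added here distinguishes critical from sheddable (non-critical) requests: a sheddable request is only admitted when some pod has queue and KV cache headroom, and among qualifying pods the one with the least KV cache usage is preferred, per the commit message; when no pod qualifies, the request is rejected and the server returns HTTP 429. The snippet below is a minimal, self-contained sketch of that admission decision, not the committed code: the podMetrics type, the scheduleSheddable helper, and the threshold values are illustrative assumptions (the real predicates live in pkg/ext-proc/scheduling/filter.go further down in this diff).

// Hypothetical sketch of the sheddable-request admission decision described in
// this commit. Names and thresholds are illustrative, not the committed code.
package main

import (
	"errors"
	"fmt"
)

type podMetrics struct {
	Name                string
	WaitingQueueSize    int
	KVCacheUsagePercent float64
}

var errNoCapacity = errors.New("no capacity for sheddable request")

// scheduleSheddable admits a sheddable request only if some pod is under both
// thresholds, and among those picks the pod with the least KV cache usage.
func scheduleSheddable(pods []podMetrics, queueThreshold int, kvCacheThreshold float64) (podMetrics, error) {
	var best podMetrics
	found := false
	for _, p := range pods {
		if p.WaitingQueueSize > queueThreshold || p.KVCacheUsagePercent > kvCacheThreshold {
			continue // no headroom: skip this pod for sheddable traffic
		}
		if !found || p.KVCacheUsagePercent < best.KVCacheUsagePercent {
			best = p
			found = true
		}
	}
	if !found {
		return podMetrics{}, errNoCapacity // the ext-proc server maps this to HTTP 429
	}
	return best, nil
}

func main() {
	pods := []podMetrics{
		{Name: "pod-a", WaitingQueueSize: 0, KVCacheUsagePercent: 0.4},
		{Name: "pod-b", WaitingQueueSize: 2, KVCacheUsagePercent: 0.9},
	}
	pod, err := scheduleSheddable(pods, 0, 0.8)
	fmt.Println(pod.Name, err) // prints: pod-a <nil>
}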


10 files changed: +533, -47 lines

pkg/ext-proc/backend/provider.go

+1-1
@@ -113,7 +113,7 @@ func (p *Provider) refreshPodsOnce() error {
 		new := &PodMetrics{
 			Pod: pod,
 			Metrics: Metrics{
-				CachedModels: make(map[string]int),
+				ActiveModels: make(map[string]int),
 			},
 		}
 		p.podMetrics.Store(pod, new)

pkg/ext-proc/backend/types.go

+8-6
@@ -12,12 +12,14 @@ type Pod struct {
 }
 
 func (p Pod) String() string {
-	return p.Namespace + "." + p.Name
+	return p.Namespace + "/" + p.Name
 }
 
 type Metrics struct {
-	// CachedModels is a set of models(including LoRA adapters) that are currently cached to GPU.
-	CachedModels map[string]int
+	// ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU.
+	ActiveModels map[string]int
+	// MaxActiveModels is the maximum number of models that can be loaded to GPU.
+	MaxActiveModels int
 	RunningQueueSize int
 	WaitingQueueSize int
 	KVCacheUsagePercent float64
@@ -34,14 +36,14 @@ func (pm *PodMetrics) String() string {
 }
 
 func (pm *PodMetrics) Clone() *PodMetrics {
-	cm := make(map[string]int, len(pm.CachedModels))
-	for k, v := range pm.CachedModels {
+	cm := make(map[string]int, len(pm.ActiveModels))
+	for k, v := range pm.ActiveModels {
 		cm[k] = v
 	}
 	clone := &PodMetrics{
 		Pod: pm.Pod,
 		Metrics: Metrics{
-			CachedModels: cm,
+			ActiveModels: cm,
 			RunningQueueSize: pm.RunningQueueSize,
 			WaitingQueueSize: pm.WaitingQueueSize,
 			KVCacheUsagePercent: pm.KVCacheUsagePercent,

pkg/ext-proc/backend/vllm/metrics.go

+5-2
@@ -85,18 +85,21 @@ func promToPodMetrics(metricFamilies map[string]*dto.MetricFamily, existing *bac
 	}
 	*/
 
+	// TODO(https://github.com/kubernetes-sigs/llm-instance-gateway/issues/22): Read from vLLM metrics once the is available.
+	updated.MaxActiveModels = 4
+
 	// Update active loras
 	mf, ok := metricFamilies[ActiveLoRAAdaptersMetricName]
 	if ok {
 		// IMPORTANT: replace the map entries instead of appending to it.
-		updated.CachedModels = make(map[string]int)
+		updated.ActiveModels = make(map[string]int)
 		for _, metric := range mf.GetMetric() {
 			for _, label := range metric.GetLabel() {
 				if label.GetName() == "active_adapters" {
 					if label.GetValue() != "" {
 						adapterList := strings.Split(label.GetValue(), ",")
 						for _, adapter := range adapterList {
-							updated.CachedModels[adapter] = 0
+							updated.ActiveModels[adapter] = 0
 						}
 					}
 				}

pkg/ext-proc/go.mod

+1
@@ -5,6 +5,7 @@ go 1.21
 require (
 	github.com/bojand/ghz v0.120.0
 	github.com/envoyproxy/go-control-plane v0.13.0
+	github.com/google/go-cmp v0.6.0
 	github.com/jhump/protoreflect v1.15.1
 	github.com/prometheus/client_model v0.6.1
 	github.com/prometheus/common v0.55.0

pkg/ext-proc/handlers/request.go

+3-1
@@ -38,6 +38,8 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 		// TODO: Once the API is approved, read the "LLMUseCase" configuration and apply traffic split.
 		TargetModels: map[string]int{model: 100},
 		ResolvedTargetModel: model,
+		// TODO: Read from LLMService CRD.
+		Critical: true,
 	}
 
 	// Update target models in the body.
@@ -51,7 +53,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 
 	targetPod, err := s.scheduler.Schedule(llmReq)
 	if err != nil {
-		return nil, fmt.Errorf("failed to find target pod: %v", err)
+		return nil, fmt.Errorf("failed to find target pod: %w", err)
 	}
 	klog.V(3).Infof("Selected target model %v in target pod: %v\n", llmReq.ResolvedTargetModel, targetPod)
 

pkg/ext-proc/handlers/server.go

+18-2
@@ -4,6 +4,7 @@ import (
 	"io"
 
 	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 	klog "k8s.io/klog/v2"
@@ -83,13 +84,28 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
 
 		if err != nil {
 			klog.Errorf("failed to process request: %v", err)
-			return status.Errorf(codes.Unknown, "failed to handle request: %v", err)
+			switch status.Code(err) {
+			// This code can be returned by scheduler when there is no capacity for sheddable
+			// requests.
+			case codes.ResourceExhausted:
+				resp = &extProcPb.ProcessingResponse{
+					Response: &extProcPb.ProcessingResponse_ImmediateResponse{
+						ImmediateResponse: &extProcPb.ImmediateResponse{
+							Status: &envoyTypePb.HttpStatus{
+								Code: envoyTypePb.StatusCode_TooManyRequests,
+							},
+						},
+					},
+				}
+			default:
+				return status.Errorf(status.Code(err), "failed to handle request: %w", err)
+			}
 		}
 
 		klog.V(3).Infof("response: %v", resp)
 		if err := srv.Send(resp); err != nil {
 			klog.Errorf("send error %v", err)
-			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %w", err)
 		}
 	}
 }
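
For context on the new ResourceExhausted branch above: the diff expects the scheduler to signal "no capacity for sheddable requests" as a gRPC status error, which the server converts into an immediate 429 response to Envoy. The snippet below is a hedged sketch of producing and matching such an error; the dropSheddableRequest helper and its message are assumptions, only the codes.ResourceExhausted to StatusCode_TooManyRequests mapping comes from the diff.

// Sketch: producing and recognizing the ResourceExhausted error that the
// handler above converts into an immediate HTTP 429. The helper name and
// error message are illustrative assumptions.
package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// dropSheddableRequest stands in for the scheduler path that finds no pod
// with capacity for a sheddable (non-critical) request.
func dropSheddableRequest() error {
	return status.Errorf(codes.ResourceExhausted, "no pods available for sheddable request")
}

func main() {
	err := dropSheddableRequest()
	// status.Code extracts the gRPC code; the ext-proc server switches on it
	// and responds with StatusCode_TooManyRequests (HTTP 429) for this case.
	if status.Code(err) == codes.ResourceExhausted {
		fmt.Println("would return HTTP 429 Too Many Requests to Envoy")
	}
}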

pkg/ext-proc/scheduling/filter.go

+44-29
@@ -11,7 +11,7 @@ import (
 
 type Filter interface {
 	Name() string
-	Filter(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
+	Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
 }
 
 // filter applies current filterFunc, and then recursively applies next filters depending success or
@@ -41,42 +41,46 @@ func (f *filter) Name() string {
 	return f.name
 }
 
-func (f *filter) Filter(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
-	if f == nil {
-		klog.V(3).Infof("Running nil filter, returning all input pods by default")
-		return pods, nil
-	}
-	klog.V(3).Infof("Running filter %q on request %v with %v pods", f.name, b, len(pods))
+func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+	klog.V(3).Infof("Running filter %q on request %v with %v pods", f.name, req, len(pods))
 
-	filtered, err := f.filter(b, pods)
+	filtered, err := f.filter(req, pods)
 
 	next := f.nextOnSuccessOrFailure
-	if err == nil {
-		klog.V(3).Infof("onSuccess %v -> %v, filtered: %v", f.name, next.Name(), len(filtered))
+	if err == nil && len(filtered) > 0 {
+		if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
+			// No succeeding filters to run, return.
+			return filtered, err
+		}
 		if f.nextOnSuccess != nil {
 			next = f.nextOnSuccess
 		}
+		klog.V(3).Infof("onSuccess %q -> %q, filtered: %v", f.name, next.Name(), len(filtered))
 		// On success, pass the filtered result to the next filter.
-		return next.Filter(b, filtered)
-	}
-
-	klog.V(3).Infof("onFailure %v -> %v", f.name, next.Name())
-	if f.nextOnFailure != nil {
-		next = f.nextOnFailure
+		return next.Filter(req, filtered)
+	} else {
+		if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
+			// No succeeding filters to run, return.
+			return filtered, err
+		}
+		if f.nextOnFailure != nil {
+			next = f.nextOnFailure
+		}
+		klog.V(3).Infof("onFailure %q -> %q", f.name, next.Name())
+		// On failure, pass the initial set of pods to the next filter.
+		return next.Filter(req, pods)
 	}
-	// On failure, pass the initial set of pods to the next filter.
-	return next.Filter(b, pods)
 }
 
 // filterFunc filters a set of input pods to a subset.
-type filterFunc func(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
+type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
 
 // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
 func toFilterFunc(pp podPredicate) filterFunc {
-	return func(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+	return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 		filtered := []*backend.PodMetrics{}
 		for _, pod := range pods {
-			pass := pp(b, pod)
+			pass := pp(req, pod)
 			if pass {
 				filtered = append(filtered, pod)
 			}
@@ -95,7 +99,7 @@ func toFilterFunc(pp podPredicate) filterFunc {
 // the least one as it gives more choices for the next filter, which on aggregate gave better
 // results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastQueuingFilterFunc(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+func leastQueuingFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 	min := math.MaxInt
 	max := 0
 	filtered := []*backend.PodMetrics{}
@@ -123,9 +127,9 @@ func leastQueuingFilterFunc(b *LLMRequest, pods []*backe
 // should consider them all instead of the absolute minimum one. This worked better than picking the
 // least one as it gives more choices for the next filter, which on aggregate gave better results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastKVCacheFilterFunc(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 	min := math.MaxFloat64
-	max := math.SmallestNonzeroFloat64
+	var max float64 = 0
 	filtered := []*backend.PodMetrics{}
 
 	for _, pod := range pods {
@@ -146,10 +150,21 @@ func leastKVCacheFilterFunc(b *LLMRequest, pods []*backe
 }
 
 // podPredicate is a filter function to check whether a pod is desired.
-type podPredicate func(b *LLMRequest, pod *backend.PodMetrics) bool
+type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool
+
+// We consider serving an adapter low cost it the adapter is active in the model server, or the
+// model server has room to load the adapter
+func lowLoRACostPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
+	_, ok := pod.ActiveModels[req.ResolvedTargetModel]
+	return ok || len(pod.ActiveModels) < pod.MaxActiveModels
+}
 
-// loraAffinityPredicate return true if the pod have the requested LoRA adapter loaded.
-func loraAffinityPredicate(b *LLMRequest, pod *backend.PodMetrics) bool {
-	_, ok := pod.CachedModels[b.ResolvedTargetModel]
-	return ok
+func criticalRequestPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
+	return req.Critical
+}
+
+func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate {
+	return func(req *LLMRequest, pod *backend.PodMetrics) bool {
+		return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold
+	}
 }
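
The filter type changed above forms a small decision tree: each node runs its filterFunc, then passes the filtered pods to nextOnSuccess (or nextOnSuccessOrFailure), or falls back to nextOnFailure with the original pods when nothing passes or an error occurs. Because the filter struct's fields are unexported and its construction is not part of this diff, the snippet below uses a simplified stand-in type to illustrate the same success/failure routing under stated assumptions; it is not the committed scheduler wiring.

// Simplified, self-contained illustration of the success/failure chaining used
// by the scheduler's filter type. The node type here is hypothetical; it only
// mirrors the nextOnSuccess / nextOnFailure routing shown in the diff.
package main

import "fmt"

type pod struct {
	name    string
	queue   int
	kvCache float64
}

type node struct {
	name          string
	predicate     func(pod) bool
	nextOnSuccess *node
	nextOnFailure *node
}

// run keeps pods that pass the predicate; on an empty result it falls back to
// the failure branch with the original pods, mirroring filter.Filter.
func (n *node) run(pods []pod) []pod {
	if n == nil {
		// No further filters: return whatever we were given.
		return pods
	}
	var filtered []pod
	for _, p := range pods {
		if n.predicate(p) {
			filtered = append(filtered, p)
		}
	}
	if len(filtered) > 0 {
		return n.nextOnSuccess.run(filtered)
	}
	return n.nextOnFailure.run(pods)
}

func main() {
	// Prefer pods with an empty queue and low KV cache; otherwise keep all pods.
	chain := &node{
		name:      "noQueueAndLowKVCache",
		predicate: func(p pod) bool { return p.queue == 0 && p.kvCache <= 0.8 },
	}
	pods := []pod{{"pod-a", 3, 0.9}, {"pod-b", 0, 0.5}}
	fmt.Println(chain.run(pods)) // prints: [{pod-b 0 0.5}]
}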
