Commit 62f54a5

liu-cong authored and kfswain committed
Add priority based scheduling (kubernetes-sigs#25)
* Add priority based scheduling

* Use the least kv cache for sheddable requests when there is capacity
1 parent 2280463 commit 62f54a5

File tree

10 files changed: +661 −45 lines changed


pkg/ext-proc/backend/provider.go

+1 −1

@@ -95,7 +95,7 @@ func (p *Provider) refreshPodsOnce() error {
 		new := &PodMetrics{
 			Pod: pod,
 			Metrics: Metrics{
-				CachedModels: make(map[string]int),
+				ActiveModels: make(map[string]int),
 			},
 		}
 		p.podMetrics.Store(pod, new)

pkg/ext-proc/backend/types.go

+8 −6

@@ -13,12 +13,14 @@ type Pod struct {
 }
 
 func (p Pod) String() string {
-	return p.Namespace + "." + p.Name
+	return p.Namespace + "/" + p.Name
 }
 
 type Metrics struct {
-	// CachedModels is a set of models(including LoRA adapters) that are currently cached to GPU.
-	CachedModels map[string]int
+	// ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU.
+	ActiveModels map[string]int
+	// MaxActiveModels is the maximum number of models that can be loaded to GPU.
+	MaxActiveModels int
 	RunningQueueSize int
 	WaitingQueueSize int
 	KVCacheUsagePercent float64
@@ -35,14 +37,14 @@ func (pm *PodMetrics) String() string {
 }
 
 func (pm *PodMetrics) Clone() *PodMetrics {
-	cm := make(map[string]int, len(pm.CachedModels))
-	for k, v := range pm.CachedModels {
+	cm := make(map[string]int, len(pm.ActiveModels))
+	for k, v := range pm.ActiveModels {
 		cm[k] = v
 	}
 	clone := &PodMetrics{
 		Pod: pm.Pod,
 		Metrics: Metrics{
-			CachedModels: cm,
+			ActiveModels: cm,
 			RunningQueueSize: pm.RunningQueueSize,
 			WaitingQueueSize: pm.WaitingQueueSize,
 			KVCacheUsagePercent: pm.KVCacheUsagePercent,
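Aside (not part of the commit): a minimal, self-contained sketch of why Clone deep-copies ActiveModels — the metrics refresher mutates the copy and then swaps the whole *PodMetrics in, so readers of the previous snapshot never observe a half-updated map. The trimmed types below are illustrative stand-ins for backend.PodMetrics, not the repository's definitions.

package main

import "fmt"

// Metrics and PodMetrics are trimmed stand-ins for the backend types above.
type Metrics struct {
	ActiveModels    map[string]int
	MaxActiveModels int
}

type PodMetrics struct {
	Name string
	Metrics
}

// Clone deep-copies ActiveModels so the original snapshot stays untouched.
func (pm *PodMetrics) Clone() *PodMetrics {
	cm := make(map[string]int, len(pm.ActiveModels))
	for k, v := range pm.ActiveModels {
		cm[k] = v
	}
	return &PodMetrics{
		Name:    pm.Name,
		Metrics: Metrics{ActiveModels: cm, MaxActiveModels: pm.MaxActiveModels},
	}
}

func main() {
	orig := &PodMetrics{
		Name:    "pod-a",
		Metrics: Metrics{ActiveModels: map[string]int{"lora-a": 0}, MaxActiveModels: 4},
	}
	updated := orig.Clone()
	updated.ActiveModels["lora-b"] = 0 // mutate only the clone
	fmt.Println(len(orig.ActiveModels), len(updated.ActiveModels)) // prints: 1 2
}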

pkg/ext-proc/backend/vllm/metrics.go

+133
@@ -0,0 +1,133 @@
+// Package vllm provides vllm specific pod metrics implementation.
+package vllm
+
+import (
+	"ext-proc/backend"
+	"fmt"
+	"net/http"
+	"strings"
+	"time"
+
+	dto "github.com/prometheus/client_model/go"
+	"github.com/prometheus/common/expfmt"
+	"go.uber.org/multierr"
+	klog "k8s.io/klog/v2"
+)
+
+const (
+	ActiveLoRAAdaptersMetricName        = "vllm:info_active_adapters_info"
+	LoRAAdapterPendingRequestMetricName = "vllm:active_lora_adapters"
+	// TODO: Replace these with the num_tokens_running/waiting below once we add those to the fork.
+	RunningQueueSizeMetricName = "vllm:num_requests_running"
+	WaitingQueueSizeMetricName = "vllm:num_requests_waiting"
+	/* TODO: Uncomment this once the following are added to the fork.
+	RunningQueueSizeMetricName = "vllm:num_tokens_running"
+	WaitingQueueSizeMetricName = "vllm:num_tokens_waiting"
+	*/
+	KVCacheUsagePercentMetricName     = "vllm:gpu_cache_usage_perc"
+	KvCacheMaxTokenCapacityMetricName = "vllm:gpu_cache_max_token_capacity"
+)
+
+type PodMetricsClientImpl struct {
+}
+
+// FetchMetrics fetches metrics from a given pod.
+func (p *PodMetricsClientImpl) FetchMetrics(pod backend.Pod, existing *backend.PodMetrics) (*backend.PodMetrics, error) {
+	// Currently the metrics endpoint is hard-coded, which works with vLLM.
+	// TODO(https://github.com/kubernetes-sigs/llm-instance-gateway/issues/16): Consume this from LLMServerPool config.
+	url := fmt.Sprintf("http://%s/metrics", pod.Address)
+	resp, err := http.Get(url)
+	if err != nil {
+		klog.Errorf("failed to fetch metrics from %s: %v", pod, err)
+		return nil, fmt.Errorf("failed to fetch metrics from %s: %w", pod, err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		klog.Errorf("unexpected status code from %s: %v", pod, resp.StatusCode)
+		return nil, fmt.Errorf("unexpected status code from %s: %v", pod, resp.StatusCode)
+	}
+
+	parser := expfmt.TextParser{}
+	metricFamilies, err := parser.TextToMetricFamilies(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	return promToPodMetrics(metricFamilies, existing)
+}
+
+// promToPodMetrics updates internal pod metrics with scraped prometheus metrics.
+// A combined error is returned if errors occur in one or more metric processing.
+// it returns a new PodMetrics pointer which can be used to atomically update the pod metrics map.
+func promToPodMetrics(metricFamilies map[string]*dto.MetricFamily, existing *backend.PodMetrics) (*backend.PodMetrics, error) {
+	var errs error
+	updated := existing.Clone()
+	runningQueueSize, _, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
+	multierr.Append(errs, err)
+	if err == nil {
+		updated.RunningQueueSize = int(runningQueueSize.GetGauge().GetValue())
+	}
+	waitingQueueSize, _, err := getLatestMetric(metricFamilies, WaitingQueueSizeMetricName)
+	multierr.Append(errs, err)
+	if err == nil {
+		updated.WaitingQueueSize = int(waitingQueueSize.GetGauge().GetValue())
+	}
+	cachePercent, _, err := getLatestMetric(metricFamilies, KVCacheUsagePercentMetricName)
+	multierr.Append(errs, err)
+	if err == nil {
+		updated.KVCacheUsagePercent = cachePercent.GetGauge().GetValue()
+	}
+	/* TODO: uncomment once this is available in vllm.
+	kvCap, _, err := getGaugeLatestValue(metricFamilies, KvCacheMaxTokenCapacityMetricName)
+	multierr.Append(errs, err)
+	if err != nil {
+		updated.KvCacheMaxTokenCapacity = int(kvCap)
+	}
+	*/
+
+	// Update active loras
+	mf, ok := metricFamilies[ActiveLoRAAdaptersMetricName]
+	if ok {
+		// IMPORTANT: replace the map entries instead of appending to it.
+		updated.ActiveModels = make(map[string]int)
+		for _, metric := range mf.GetMetric() {
+			for _, label := range metric.GetLabel() {
+				if label.GetName() == "active_adapters" {
+					if label.GetValue() != "" {
+						adapterList := strings.Split(label.GetValue(), ",")
+						for _, adapter := range adapterList {
+							updated.ActiveModels[adapter] = 0
+						}
+					}
+				}
+			}
+		}
+	} else {
+		klog.Warningf("metric family %q not found", ActiveLoRAAdaptersMetricName)
+		multierr.Append(errs, fmt.Errorf("metric family %q not found", ActiveLoRAAdaptersMetricName))
+	}
+
+	return updated, errs
+}
+
+// getLatestMetric gets the latest metric of a family. This should be used to get the latest Gauge metric.
+func getLatestMetric(metricFamilies map[string]*dto.MetricFamily, metricName string) (*dto.Metric, time.Time, error) {
+	mf, ok := metricFamilies[metricName]
+	if !ok {
+		klog.Warningf("metric family %q not found", metricName)
+		return nil, time.Time{}, fmt.Errorf("metric family %q not found", metricName)
+	}
+	if len(mf.GetMetric()) == 0 {
+		return nil, time.Time{}, fmt.Errorf("no metrics available for %q", metricName)
+	}
+	var latestTs int64
+	var latest *dto.Metric
+	for _, m := range mf.GetMetric() {
+		if m.GetTimestampMs() >= latestTs {
+			latestTs = m.GetTimestampMs()
+			latest = m
+		}
+	}
+	klog.V(4).Infof("Got metric value %+v for metric %v", latest, metricName)
+	return latest, time.Unix(0, latestTs*1000), nil
+}
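As a usage illustration (not from the commit), the following self-contained sketch shows how the active_adapters label handled in promToPodMetrics above becomes the ActiveModels set. The sample payload is hypothetical but follows the shape of the vllm:info_active_adapters_info family, and the sketch uses the upstream github.com import path for expfmt.

package main

import (
	"fmt"
	"strings"

	"github.com/prometheus/common/expfmt"
)

func main() {
	// Hypothetical scrape output in the shape promToPodMetrics expects.
	payload := `# TYPE vllm:info_active_adapters_info gauge
vllm:info_active_adapters_info{active_adapters="lora-a,lora-b"} 1
`
	parser := expfmt.TextParser{}
	families, err := parser.TextToMetricFamilies(strings.NewReader(payload))
	if err != nil {
		panic(err)
	}

	// Same label walk as promToPodMetrics: split the comma-separated adapter list.
	activeModels := map[string]int{}
	if mf, ok := families["vllm:info_active_adapters_info"]; ok {
		for _, m := range mf.GetMetric() {
			for _, label := range m.GetLabel() {
				if label.GetName() == "active_adapters" && label.GetValue() != "" {
					for _, adapter := range strings.Split(label.GetValue(), ",") {
						activeModels[adapter] = 0
					}
				}
			}
		}
	}
	fmt.Println(activeModels) // map[lora-a:0 lora-b:0]
}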

pkg/ext-proc/go.mod

+1
@@ -5,6 +5,7 @@ go 1.22.0
 require (
 	github.com/bojand/ghz v0.120.0
 	github.com/envoyproxy/go-control-plane v0.13.0
+	github.com/google/go-cmp v0.6.0
 	github.com/jhump/protoreflect v1.15.1
 	github.com/prometheus/client_model v0.6.1
 	github.com/prometheus/common v0.55.0

pkg/ext-proc/handlers/request.go

+3 −1

@@ -38,6 +38,8 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 		// TODO: Once the API is approved, read the "LLMUseCase" configuration and apply traffic split.
 		TargetModels: map[string]int{model: 100},
 		ResolvedTargetModel: model,
+		// TODO: Read from LLMService CRD.
+		Critical: true,
 	}
 
 	// Update target models in the body.
@@ -51,7 +53,7 @@ func (s *Server) HandleRequestBody(reqCtx *RequestContext, req *extProcPb.Proces
 
 	targetPod, err := s.scheduler.Schedule(llmReq)
 	if err != nil {
-		return nil, fmt.Errorf("failed to find target pod: %v", err)
+		return nil, fmt.Errorf("failed to find target pod: %w", err)
 	}
 	klog.V(3).Infof("Selected target model %v in target pod: %v\n", llmReq.ResolvedTargetModel, targetPod)
 
pkg/ext-proc/handlers/server.go

+18 −2

@@ -4,6 +4,7 @@ import (
 	"io"
 
 	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 	klog "k8s.io/klog/v2"
@@ -83,13 +84,28 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
 
 		if err != nil {
 			klog.Errorf("failed to process request: %v", err)
-			return status.Errorf(codes.Unknown, "failed to handle request: %v", err)
+			switch status.Code(err) {
+			// This code can be returned by scheduler when there is no capacity for sheddable
+			// requests.
+			case codes.ResourceExhausted:
+				resp = &extProcPb.ProcessingResponse{
+					Response: &extProcPb.ProcessingResponse_ImmediateResponse{
+						ImmediateResponse: &extProcPb.ImmediateResponse{
+							Status: &envoyTypePb.HttpStatus{
+								Code: envoyTypePb.StatusCode_TooManyRequests,
+							},
+						},
+					},
+				}
+			default:
+				return status.Errorf(status.Code(err), "failed to handle request: %w", err)
+			}
 		}
 
 		klog.V(3).Infof("response: %v", resp)
 		if err := srv.Send(resp); err != nil {
 			klog.Errorf("send error %v", err)
-			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %w", err)
 		}
 	}
 }
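To illustrate the other half of this contract (not part of the diff above): the scheduler is expected to return a gRPC error carrying codes.ResourceExhausted when a sheddable request finds no capacity, and the switch above turns that into an Envoy ImmediateResponse with HTTP 429 instead of failing the stream. A minimal sketch follows; the errNoCapacity helper and its message are illustrative assumptions, only the status code matters.

package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// errNoCapacity is a hypothetical helper a scheduler could use to signal
// "drop this sheddable request"; the codes.ResourceExhausted code is what
// the ext-proc server keys off.
func errNoCapacity() error {
	return status.Errorf(codes.ResourceExhausted, "no capacity for sheddable request")
}

func main() {
	err := errNoCapacity()
	// The ext-proc server switches on status.Code(err); ResourceExhausted is
	// mapped to StatusCode_TooManyRequests rather than an Unknown gRPC error.
	fmt.Println(status.Code(err) == codes.ResourceExhausted) // true
}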

pkg/ext-proc/scheduling/filter.go

+44 −29

@@ -11,7 +11,7 @@ import (
 
 type Filter interface {
 	Name() string
-	Filter(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
+	Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
 }
 
 // filter applies current filterFunc, and then recursively applies next filters depending success or
@@ -41,42 +41,46 @@ func (f *filter) Name() string {
 	return f.name
 }
 
-func (f *filter) Filter(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
-	if f == nil {
-		klog.V(3).Infof("Running nil filter, returning all input pods by default")
-		return pods, nil
-	}
-	klog.V(3).Infof("Running filter %q on request %v with %v pods", f.name, b, len(pods))
+func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+	klog.V(3).Infof("Running filter %q on request %v with %v pods", f.name, req, len(pods))
 
-	filtered, err := f.filter(b, pods)
+	filtered, err := f.filter(req, pods)
 
 	next := f.nextOnSuccessOrFailure
-	if err == nil {
-		klog.V(3).Infof("onSuccess %v -> %v, filtered: %v", f.name, next.Name(), len(filtered))
+	if err == nil && len(filtered) > 0 {
+		if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
+			// No succeeding filters to run, return.
+			return filtered, err
+		}
 		if f.nextOnSuccess != nil {
 			next = f.nextOnSuccess
 		}
+		klog.V(3).Infof("onSuccess %q -> %q, filtered: %v", f.name, next.Name(), len(filtered))
 		// On success, pass the filtered result to the next filter.
-		return next.Filter(b, filtered)
-	}
-
-	klog.V(3).Infof("onFailure %v -> %v", f.name, next.Name())
-	if f.nextOnFailure != nil {
-		next = f.nextOnFailure
+		return next.Filter(req, filtered)
+	} else {
+		if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
+			// No succeeding filters to run, return.
+			return filtered, err
+		}
+		if f.nextOnFailure != nil {
+			next = f.nextOnFailure
+		}
+		klog.V(3).Infof("onFailure %q -> %q", f.name, next.Name())
+		// On failure, pass the initial set of pods to the next filter.
+		return next.Filter(req, pods)
 	}
-	// On failure, pass the initial set of pods to the next filter.
-	return next.Filter(b, pods)
 }
 
 // filterFunc filters a set of input pods to a subset.
-type filterFunc func(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
+type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
 
 // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
 func toFilterFunc(pp podPredicate) filterFunc {
-	return func(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+	return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 		filtered := []*backend.PodMetrics{}
 		for _, pod := range pods {
-			pass := pp(b, pod)
+			pass := pp(req, pod)
 			if pass {
 				filtered = append(filtered, pod)
 			}
@@ -95,7 +99,7 @@ func toFilterFunc(pp podPredicate) filterFunc {
 // the least one as it gives more choices for the next filter, which on aggregate gave better
 // results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastQueuingFilterFunc(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+func leastQueuingFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 	min := math.MaxInt
 	max := 0
 	filtered := []*backend.PodMetrics{}
@@ -123,9 +127,9 @@ func leastQueuingFilterFunc(b *LLMRequest, pods []*backend.PodMetrics) ([]*backe
 // should consider them all instead of the absolute minimum one. This worked better than picking the
 // least one as it gives more choices for the next filter, which on aggregate gave better results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastKVCacheFilterFunc(b *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 	min := math.MaxFloat64
-	max := math.SmallestNonzeroFloat64
+	var max float64 = 0
 	filtered := []*backend.PodMetrics{}
 
 	for _, pod := range pods {
@@ -146,10 +150,21 @@ func leastKVCacheFilterFunc(b *LLMRequest, pods []*backend.PodMetrics) ([]*backe
 }
 
 // podPredicate is a filter function to check whether a pod is desired.
-type podPredicate func(b *LLMRequest, pod *backend.PodMetrics) bool
+type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool
+
+// We consider serving an adapter low cost it the adapter is active in the model server, or the
+// model server has room to load the adapter
+func lowLoRACostPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
+	_, ok := pod.ActiveModels[req.ResolvedTargetModel]
+	return ok || len(pod.ActiveModels) < pod.MaxActiveModels
+}
 
-// loraAffinityPredicate return true if the pod have the requested LoRA adapter loaded.
-func loraAffinityPredicate(b *LLMRequest, pod *backend.PodMetrics) bool {
-	_, ok := pod.CachedModels[b.ResolvedTargetModel]
-	return ok
+func criticalRequestPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
+	return req.Critical
+}
+
+func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate {
+	return func(req *LLMRequest, pod *backend.PodMetrics) bool {
+		return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold
+	}
 }
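The scheduler that wires these predicates into a filter chain is among the changed files not shown in this excerpt. As a rough, self-contained sketch of the intent stated in the commit message: critical requests always proceed to the least-queuing/least-KV-cache filters, while sheddable requests are admitted only when some pod clears the queue and KV-cache thresholds, and are otherwise rejected (surfaced upstream as ResourceExhausted / HTTP 429). The types, thresholds, and admit helper below are illustrative stand-ins, not the commit's code.

package main

import (
	"errors"
	"fmt"
)

// Trimmed stand-ins for scheduling.LLMRequest and backend.PodMetrics.
type LLMRequest struct {
	ResolvedTargetModel string
	Critical            bool
}

type PodMetrics struct {
	Name                string
	WaitingQueueSize    int
	KVCacheUsagePercent float64
}

var errNoCapacity = errors.New("no capacity for sheddable request")

// admit decides whether a request may be scheduled at all, and which pods
// remain candidates for the downstream least-queuing/least-KV-cache filters.
func admit(req *LLMRequest, pods []*PodMetrics, queueThreshold int, kvCacheThreshold float64) ([]*PodMetrics, error) {
	if req.Critical {
		// Critical requests are never dropped; downstream filters pick the pod.
		return pods, nil
	}
	// Sheddable requests: keep only pods with an empty-enough queue and KV cache headroom.
	ok := []*PodMetrics{}
	for _, pod := range pods {
		if pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold {
			ok = append(ok, pod)
		}
	}
	if len(ok) == 0 {
		// Mapped to codes.ResourceExhausted (and then HTTP 429) in the real server.
		return nil, errNoCapacity
	}
	return ok, nil
}

func main() {
	pods := []*PodMetrics{
		{Name: "pod-a", WaitingQueueSize: 3, KVCacheUsagePercent: 0.9},
		{Name: "pod-b", WaitingQueueSize: 0, KVCacheUsagePercent: 0.4},
	}
	sheddable := &LLMRequest{ResolvedTargetModel: "lora-a", Critical: false}
	candidates, err := admit(sheddable, pods, 0, 0.8)
	fmt.Println(len(candidates), err) // 1 <nil> (only pod-b has headroom)
}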
