Skip to content

Commit 12bcc9a

Browse files
Allow bodyless requests to passthrough EPP (#555)
* Adding content length checker * Allow requests with no body to passthrough EPP --------- Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
1 parent 140d5eb commit 12bcc9a

File tree

9 files changed

+436
-324
lines changed

9 files changed

+436
-324
lines changed

pkg/epp/datastore/datastore.go

-31
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@ import (
2020
"context"
2121
"errors"
2222
"fmt"
23-
"math/rand"
2423
"sync"
2524

26-
"github.com/go-logr/logr"
2725
corev1 "k8s.io/api/core/v1"
2826
"k8s.io/apimachinery/pkg/labels"
2927
"k8s.io/apimachinery/pkg/types"
@@ -304,35 +302,6 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha2.LabelKey]v1alpha2.LabelV
304302
return outMap
305303
}
306304

307-
func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string {
308-
source := rand.NewSource(rand.Int63())
309-
if seed > 0 {
310-
source = rand.NewSource(seed)
311-
}
312-
r := rand.New(source)
313-
314-
// all the weight values are nil, then we should return random model name
315-
if model.Spec.TargetModels[0].Weight == nil {
316-
index := r.Int31n(int32(len(model.Spec.TargetModels)))
317-
return model.Spec.TargetModels[index].Name
318-
}
319-
320-
var weights int32
321-
for _, model := range model.Spec.TargetModels {
322-
weights += *model.Weight
323-
}
324-
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
325-
randomVal := r.Int31n(weights)
326-
// TODO: optimize this without using loop
327-
for _, model := range model.Spec.TargetModels {
328-
if randomVal < *model.Weight {
329-
return model.Name
330-
}
331-
randomVal -= *model.Weight
332-
}
333-
return ""
334-
}
335-
336305
func IsCritical(model *v1alpha2.InferenceModel) bool {
337306
if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha2.Critical {
338307
return true

pkg/epp/datastore/datastore_test.go

-108
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import (
3030
"k8s.io/apimachinery/pkg/types"
3131
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
3232
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
33-
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3433
testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
3534
)
3635

@@ -223,113 +222,6 @@ func TestModel(t *testing.T) {
223222
}
224223
}
225224

226-
func TestRandomWeightedDraw(t *testing.T) {
227-
logger := logutil.NewTestLogger()
228-
tests := []struct {
229-
name string
230-
model *v1alpha2.InferenceModel
231-
want string
232-
}{
233-
{
234-
name: "'random' distribution",
235-
model: &v1alpha2.InferenceModel{
236-
Spec: v1alpha2.InferenceModelSpec{
237-
TargetModels: []v1alpha2.TargetModel{
238-
{
239-
Name: "canary",
240-
Weight: pointer(50),
241-
},
242-
{
243-
Name: "v1",
244-
Weight: pointer(50),
245-
},
246-
},
247-
},
248-
},
249-
want: "canary",
250-
},
251-
{
252-
name: "'random' distribution",
253-
model: &v1alpha2.InferenceModel{
254-
Spec: v1alpha2.InferenceModelSpec{
255-
TargetModels: []v1alpha2.TargetModel{
256-
{
257-
Name: "canary",
258-
Weight: pointer(25),
259-
},
260-
{
261-
Name: "v1.1",
262-
Weight: pointer(55),
263-
},
264-
{
265-
Name: "v1",
266-
Weight: pointer(50),
267-
},
268-
},
269-
},
270-
},
271-
want: "v1",
272-
},
273-
{
274-
name: "'random' distribution",
275-
model: &v1alpha2.InferenceModel{
276-
Spec: v1alpha2.InferenceModelSpec{
277-
TargetModels: []v1alpha2.TargetModel{
278-
{
279-
Name: "canary",
280-
Weight: pointer(20),
281-
},
282-
{
283-
Name: "v1.1",
284-
Weight: pointer(20),
285-
},
286-
{
287-
Name: "v1",
288-
Weight: pointer(10),
289-
},
290-
},
291-
},
292-
},
293-
want: "v1.1",
294-
},
295-
{
296-
name: "weighted distribution with weight unset",
297-
model: &v1alpha2.InferenceModel{
298-
Spec: v1alpha2.InferenceModelSpec{
299-
TargetModels: []v1alpha2.TargetModel{
300-
{
301-
Name: "canary",
302-
},
303-
{
304-
Name: "v1.1",
305-
},
306-
{
307-
Name: "v1",
308-
},
309-
},
310-
},
311-
},
312-
want: "canary",
313-
},
314-
}
315-
var seedVal int64 = 420
316-
for _, test := range tests {
317-
t.Run(test.name, func(t *testing.T) {
318-
for range 10000 {
319-
model := RandomWeightedDraw(logger, test.model, seedVal)
320-
if model != test.want {
321-
t.Errorf("Model returned: %v != %v", model, test.want)
322-
break
323-
}
324-
}
325-
})
326-
}
327-
}
328-
329-
func pointer(v int32) *int32 {
330-
return &v
331-
}
332-
333225
var (
334226
pod1 = &corev1.Pod{
335227
ObjectMeta: metav1.ObjectMeta{

pkg/epp/handlers/request.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (s *Server) HandleRequestBody(
6969
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
7070
}
7171
if len(modelObj.Spec.TargetModels) > 0 {
72-
modelName = datastore.RandomWeightedDraw(logger, modelObj, 0)
72+
modelName = RandomWeightedDraw(logger, modelObj, 0)
7373
if modelName == "" {
7474
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
7575
}

pkg/epp/handlers/response.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ func (s *Server) HandleResponseHeaders(
8585
if header.Key == "content-type" {
8686
contentType := header.RawValue
8787
if strings.Contains(string(contentType), "text/event-stream") {
88-
reqCtx.Streaming = true
89-
} else {
90-
reqCtx.Streaming = false
88+
reqCtx.modelServerStreaming = true
9189
}
9290
typeFound = true
9391
}
@@ -155,7 +153,7 @@ func (s *Server) HandleResponseBody(
155153
loggerVerbose := logger.V(logutil.VERBOSE)
156154
body := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
157155

158-
if reqCtx.Streaming {
156+
if reqCtx.modelServerStreaming {
159157
logger.V(logutil.DEBUG).Info("Processing HandleResponseBody")
160158
if err := s.HandleStreaming(ctx, reqCtx, body, loggerVerbose); err != nil {
161159
return nil, err
@@ -189,7 +187,7 @@ func (s *Server) HandleNonStreaming(
189187
if err := json.Unmarshal(body.ResponseBody.Body, &res); err != nil {
190188
return errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("unmarshaling response body: %v", err)}
191189
}
192-
reqCtx.Response = res
190+
reqCtx.Usage = res.Usage
193191
reqCtx.ResponseSize = len(body.ResponseBody.Body)
194192
reqCtx.ResponseComplete = true
195193
loggerVerbose.Info("Response generated", "response", res)
@@ -205,7 +203,7 @@ func (s *Server) HandleStreaming(
205203
responseText := string(body.ResponseBody.Body)
206204
if strings.Contains(responseText, streamingEndMsg) {
207205
parsedResp := ParseRespForUsage(ctx, responseText, loggerVerbose)
208-
reqCtx.Response = parsedResp
206+
reqCtx.Usage = parsedResp.Usage
209207
}
210208

211209
if body.ResponseBody.EndOfStream {

pkg/epp/handlers/response_test.go

+12-16
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func TestHandleResponseBody(t *testing.T) {
6565
name string
6666
req *extProcPb.ProcessingRequest_ResponseBody
6767
reqCtx *RequestContext
68-
want Response
68+
want Usage
6969
wantErr bool
7070
}{
7171
{
@@ -75,12 +75,10 @@ func TestHandleResponseBody(t *testing.T) {
7575
Body: []byte(body),
7676
},
7777
},
78-
want: Response{
79-
Usage: Usage{
80-
PromptTokens: 11,
81-
TotalTokens: 111,
82-
CompletionTokens: 100,
83-
},
78+
want: Usage{
79+
PromptTokens: 11,
80+
TotalTokens: 111,
81+
CompletionTokens: 100,
8482
},
8583
},
8684
{
@@ -100,7 +98,7 @@ func TestHandleResponseBody(t *testing.T) {
10098
},
10199
},
102100
reqCtx: &RequestContext{
103-
Streaming: true,
101+
modelServerStreaming: true,
104102
},
105103
wantErr: false,
106104
// In the middle of streaming response, so request context response is not set yet.
@@ -113,15 +111,13 @@ func TestHandleResponseBody(t *testing.T) {
113111
},
114112
},
115113
reqCtx: &RequestContext{
116-
Streaming: true,
114+
modelServerStreaming: true,
117115
},
118116
wantErr: false,
119-
want: Response{
120-
Usage: Usage{
121-
PromptTokens: 7,
122-
TotalTokens: 17,
123-
CompletionTokens: 10,
124-
},
117+
want: Usage{
118+
PromptTokens: 7,
119+
TotalTokens: 17,
120+
CompletionTokens: 10,
125121
},
126122
},
127123
}
@@ -141,7 +137,7 @@ func TestHandleResponseBody(t *testing.T) {
141137
return
142138
}
143139

144-
if diff := cmp.Diff(test.want, reqCtx.Response); diff != "" {
140+
if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" {
145141
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
146142
}
147143
})

pkg/epp/handlers/server.go

+29-6
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
128128
reqCtx.ResponseCompleteTimestamp = time.Now()
129129
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
130130
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
131-
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.PromptTokens)
132-
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.CompletionTokens)
131+
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
132+
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
133133
}
134-
if reqCtx.Streaming {
134+
if reqCtx.modelServerStreaming {
135135
logger.V(logutil.DEBUG).Info("Request context after HandleResponseBody", "context", reqCtx)
136136
} else {
137137
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
@@ -149,7 +149,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
149149
}
150150
}
151151

152-
if !reqCtx.Streaming {
152+
if !reqCtx.modelServerStreaming {
153153
loggerVerbose.Info("Response generated", "response", resp)
154154
} else {
155155
logger.V(logutil.DEBUG).Info("Response generated", "response", resp)
@@ -224,9 +224,32 @@ type RequestContext struct {
224224
RequestReceivedTimestamp time.Time
225225
ResponseCompleteTimestamp time.Time
226226
RequestSize int
227-
Response Response
227+
Usage Usage
228228
ResponseSize int
229229
ResponseComplete bool
230230
ResponseStatusCode string
231-
Streaming bool
231+
232+
RequestState StreamRequestState
233+
modelServerStreaming bool
234+
235+
reqHeaderResp *extProcPb.ProcessingResponse
236+
reqBodyResp *extProcPb.ProcessingResponse
237+
reqTrailerResp *extProcPb.ProcessingResponse
238+
239+
respHeaderResp *extProcPb.ProcessingResponse
240+
respBodyResp *extProcPb.ProcessingResponse
241+
respTrailerResp *extProcPb.ProcessingResponse
232242
}
243+
244+
type StreamRequestState int
245+
246+
const (
247+
RequestReceived StreamRequestState = 0
248+
HeaderRequestResponseComplete StreamRequestState = 1
249+
BodyRequestResponsesComplete StreamRequestState = 2
250+
TrailerRequestResponsesComplete StreamRequestState = 3
251+
ResponseRecieved StreamRequestState = 4
252+
HeaderResponseResponseComplete StreamRequestState = 5
253+
BodyResponseResponsesComplete StreamRequestState = 6
254+
TrailerResponseResponsesComplete StreamRequestState = 7
255+
)

0 commit comments

Comments
 (0)