Skip to content

Commit 2b2b4a6

Browse files
authored
support extracting prompt from chat completions API (#798)
* support extracting prompt from chat completions API

  Signed-off-by: Hang Yin <[email protected]>

* typo fixes

  Signed-off-by: Hang Yin <[email protected]>

* fix tests

* supply more tests and heading boilerplate

  Signed-off-by: Hang Yin <[email protected]>

---------

Signed-off-by: Hang Yin <[email protected]>
1 parent 3d99aa1 commit 2b2b4a6

File tree

4 files changed

+359
-5
lines changed

4 files changed

+359
-5
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
6262
if !ok {
6363
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
6464
}
65-
prompt, ok := requestBodyMap["prompt"].(string)
66-
if !ok {
67-
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"}
65+
prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
66+
if err != nil {
67+
return reqCtx, err
6868
}
6969

7070
// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.

pkg/epp/requestcontrol/director_test.go

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func TestHandleRequest(t *testing.T) {
8585
wantRespBody map[string]interface{}
8686
}{
8787
{
88-
name: "successful request",
88+
name: "successful completions request",
8989
reqBodyMap: map[string]interface{}{
9090
"model": tsModel,
9191
"prompt": "test prompt",
@@ -102,7 +102,69 @@ func TestHandleRequest(t *testing.T) {
102102
},
103103
},
104104
{
105-
name: "successful request with target model",
105+
name: "successful chat completions request",
106+
reqBodyMap: map[string]interface{}{
107+
"model": tsModel,
108+
"messages": []interface{}{
109+
map[string]interface{}{
110+
"role": "user",
111+
"content": "test prompt",
112+
},
113+
},
114+
},
115+
wantReqCtx: &handlers.RequestContext{
116+
Model: tsModel,
117+
ResolvedTargetModel: tsModel,
118+
TargetPod: "/pod1",
119+
TargetEndpoint: "address-1:8000",
120+
},
121+
wantRespBody: map[string]interface{}{
122+
"model": tsModel,
123+
"messages": []interface{}{
124+
map[string]interface{}{
125+
"role": "user",
126+
"content": "test prompt",
127+
},
128+
},
129+
},
130+
},
131+
{
132+
name: "successful chat completions request with multiple messages",
133+
reqBodyMap: map[string]interface{}{
134+
"model": tsModel,
135+
"messages": []interface{}{
136+
map[string]interface{}{
137+
"role": "developer",
138+
"content": "You are a helpful assistant.",
139+
},
140+
map[string]interface{}{
141+
"role": "user",
142+
"content": "Hello!",
143+
},
144+
},
145+
},
146+
wantReqCtx: &handlers.RequestContext{
147+
Model: tsModel,
148+
ResolvedTargetModel: tsModel,
149+
TargetPod: "/pod1",
150+
TargetEndpoint: "address-1:8000",
151+
},
152+
wantRespBody: map[string]interface{}{
153+
"model": tsModel,
154+
"messages": []interface{}{
155+
map[string]interface{}{
156+
"role": "developer",
157+
"content": "You are a helpful assistant.",
158+
},
159+
map[string]interface{}{
160+
"role": "user",
161+
"content": "Hello!",
162+
},
163+
},
164+
},
165+
},
166+
{
167+
name: "successful completions request with target model",
106168
reqBodyMap: map[string]interface{}{
107169
"model": modelWithTarget,
108170
"prompt": "test prompt",
@@ -122,6 +184,21 @@ func TestHandleRequest(t *testing.T) {
122184
name: "no model defined, expect err",
123185
wantErrCode: errutil.BadRequest,
124186
},
187+
{
188+
name: "prompt or messages not found, expect err",
189+
reqBodyMap: map[string]interface{}{
190+
"model": tsModel,
191+
},
192+
wantErrCode: errutil.BadRequest,
193+
},
194+
{
195+
name: "empty messages, expect err",
196+
reqBodyMap: map[string]interface{}{
197+
"model": tsModel,
198+
"messages": []interface{}{},
199+
},
200+
wantErrCode: errutil.BadRequest,
201+
},
125202
{
126203
name: "invalid model defined, expect err",
127204
reqBodyMap: map[string]interface{}{

pkg/epp/util/request/body.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package request
18+
19+
import (
20+
"fmt"
21+
22+
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
23+
)
24+
25+
func ExtractPromptFromRequestBody(body map[string]interface{}) (string, error) {
26+
if _, ok := body["messages"]; ok {
27+
return extractPromptFromMessagesField(body)
28+
}
29+
return extractPromptField(body)
30+
}
31+
32+
func extractPromptField(body map[string]interface{}) (string, error) {
33+
prompt, ok := body["prompt"]
34+
if !ok {
35+
return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"}
36+
}
37+
promptStr, ok := prompt.(string)
38+
if !ok {
39+
return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt is not a string"}
40+
}
41+
return promptStr, nil
42+
}
43+
44+
func extractPromptFromMessagesField(body map[string]interface{}) (string, error) {
45+
messages, ok := body["messages"]
46+
if !ok {
47+
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages not found in request"}
48+
}
49+
messageList, ok := messages.([]interface{})
50+
if !ok {
51+
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is not a list"}
52+
}
53+
if len(messageList) == 0 {
54+
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is empty"}
55+
}
56+
57+
prompt := ""
58+
for _, msg := range messageList {
59+
msgMap, ok := msg.(map[string]interface{})
60+
if !ok {
61+
continue
62+
}
63+
content, ok := msgMap["content"]
64+
if !ok {
65+
continue
66+
}
67+
contentStr, ok := content.(string)
68+
if !ok {
69+
continue
70+
}
71+
role, ok := msgMap["role"]
72+
if !ok {
73+
continue
74+
}
75+
roleStr, ok := role.(string)
76+
if !ok {
77+
continue
78+
}
79+
prompt += constructChatMessage(roleStr, contentStr)
80+
}
81+
return prompt, nil
82+
}
83+
84+
// constructChatMessage renders one chat message as a single block: an
// "<|im_start|>" marker followed by the role, a newline, the content, an
// "<|im_end|>" marker, and a trailing newline.
func constructChatMessage(role string, content string) string {
	const layout = "<|im_start|>%s\n%s<|im_end|>\n"
	return fmt.Sprintf(layout, role, content)
}

pkg/epp/util/request/body_test.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package request
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestExtractPromptFromRequestBody(t *testing.T) {
8+
tests := []struct {
9+
name string
10+
body map[string]interface{}
11+
want string
12+
wantErr bool
13+
errType error
14+
}{
15+
{
16+
name: "chat completions request body",
17+
body: map[string]interface{}{
18+
"model": "test",
19+
"messages": []interface{}{
20+
map[string]interface{}{
21+
"role": "system", "content": "this is a system message",
22+
},
23+
map[string]interface{}{
24+
"role": "user", "content": "hello",
25+
},
26+
map[string]interface{}{
27+
"role": "assistant", "content": "hi, what can I do for you?",
28+
},
29+
},
30+
},
31+
want: "<|im_start|>system\nthis is a system message<|im_end|>\n" +
32+
"<|im_start|>user\nhello<|im_end|>\n" +
33+
"<|im_start|>assistant\nhi, what can I do for you?<|im_end|>\n",
34+
},
35+
{
36+
name: "completions request body",
37+
body: map[string]interface{}{
38+
"model": "test",
39+
"prompt": "test prompt",
40+
},
41+
want: "test prompt",
42+
},
43+
{
44+
name: "invalid prompt format",
45+
body: map[string]interface{}{
46+
"model": "test",
47+
"prompt": []interface{}{
48+
map[string]interface{}{
49+
"role": "system", "content": "this is a system message",
50+
},
51+
map[string]interface{}{
52+
"role": "user", "content": "hello",
53+
},
54+
map[string]interface{}{
55+
"role": "assistant", "content": "hi, what can I",
56+
},
57+
},
58+
},
59+
wantErr: true,
60+
},
61+
{
62+
name: "invalid messaged format",
63+
body: map[string]interface{}{
64+
"model": "test",
65+
"messages": map[string]interface{}{
66+
"role": "system", "content": "this is a system message",
67+
},
68+
},
69+
wantErr: true,
70+
},
71+
{
72+
name: "prompt does not exist",
73+
body: map[string]interface{}{
74+
"model": "test",
75+
},
76+
wantErr: true,
77+
},
78+
}
79+
80+
for _, tt := range tests {
81+
t.Run(tt.name, func(t *testing.T) {
82+
got, err := ExtractPromptFromRequestBody(tt.body)
83+
if (err != nil) != tt.wantErr {
84+
t.Errorf("ExtractPromptFromRequestBody() error = %v, wantErr %v", err, tt.wantErr)
85+
return
86+
}
87+
if got != tt.want {
88+
t.Errorf("ExtractPromptFromRequestBody() got = %v, want %v", got, tt.want)
89+
}
90+
})
91+
}
92+
}
93+
94+
func TestExtractPromptField(t *testing.T) {
95+
tests := []struct {
96+
name string
97+
body map[string]interface{}
98+
want string
99+
wantErr bool
100+
}{
101+
{
102+
name: "valid prompt",
103+
body: map[string]interface{}{
104+
"prompt": "test prompt",
105+
},
106+
want: "test prompt",
107+
},
108+
{
109+
name: "prompt not found",
110+
body: map[string]interface{}{},
111+
wantErr: true,
112+
},
113+
{
114+
name: "non-string prompt",
115+
body: map[string]interface{}{
116+
"prompt": 123,
117+
},
118+
wantErr: true,
119+
},
120+
}
121+
122+
for _, tt := range tests {
123+
t.Run(tt.name, func(t *testing.T) {
124+
got, err := extractPromptField(tt.body)
125+
if (err != nil) != tt.wantErr {
126+
t.Errorf("extractPromptField() error = %v, wantErr %v", err, tt.wantErr)
127+
return
128+
}
129+
if got != tt.want {
130+
t.Errorf("extractPromptField() got = %v, want %v", got, tt.want)
131+
}
132+
})
133+
}
134+
}
135+
136+
func TestExtractPromptFromMessagesField(t *testing.T) {
137+
tests := []struct {
138+
name string
139+
body map[string]interface{}
140+
want string
141+
wantErr bool
142+
}{
143+
{
144+
name: "valid messages",
145+
body: map[string]interface{}{
146+
"messages": []interface{}{
147+
map[string]interface{}{"role": "user", "content": "test1"},
148+
map[string]interface{}{"role": "assistant", "content": "test2"},
149+
},
150+
},
151+
want: "<|im_start|>user\ntest1<|im_end|>\n<|im_start|>assistant\ntest2<|im_end|>\n",
152+
},
153+
{
154+
name: "invalid messages format",
155+
body: map[string]interface{}{
156+
"messages": "invalid",
157+
},
158+
wantErr: true,
159+
},
160+
}
161+
162+
for _, tt := range tests {
163+
t.Run(tt.name, func(t *testing.T) {
164+
got, err := extractPromptFromMessagesField(tt.body)
165+
if (err != nil) != tt.wantErr {
166+
t.Errorf("extractPromptFromMessagesField() error = %v, wantErr %v", err, tt.wantErr)
167+
return
168+
}
169+
if got != tt.want {
170+
t.Errorf("extractPromptFromMessagesField() got = %v, want %v", got, tt.want)
171+
}
172+
})
173+
}
174+
}
175+
176+
func TestConstructChatMessage(t *testing.T) {
177+
tests := []struct {
178+
role string
179+
content string
180+
want string
181+
}{
182+
{"user", "hello", "<|im_start|>user\nhello<|im_end|>\n"},
183+
{"assistant", "hi", "<|im_start|>assistant\nhi<|im_end|>\n"},
184+
}
185+
186+
for _, tt := range tests {
187+
if got := constructChatMessage(tt.role, tt.content); got != tt.want {
188+
t.Errorf("constructChatMessage() = %v, want %v", got, tt.want)
189+
}
190+
}
191+
}

0 commit comments

Comments
 (0)