Skip to content

Commit a84a382

Browse files
authored
feature: allow ai-proxy to forward standard AI capabilities that are … (#1704)
1 parent 477e44b commit a84a382

32 files changed

+517
-158
lines changed

plugins/wasm-go/extensions/ai-proxy/README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ description: AI 代理插件配置参考
4242
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
4343
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
4444
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
45-
45+
| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
46+
| `passthrough` | bool | 非必填 | - | 只要是不支持的API能力都直接转发, 此配置是capabilities配置的放大版本,允许任意api透传,就像没有ai-proxy插件一样 |
4647
`context`的配置字段说明如下:
4748

4849
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |

plugins/wasm-go/extensions/ai-proxy/main.go

+31-11
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
7878

7979
rawPath := ctx.Path()
8080
path, _ := url.Parse(rawPath)
81-
apiName := getOpenAiApiName(path.Path)
81+
apiName := getApiName(path.Path)
8282
providerConfig := pluginConfig.GetProviderConfig()
8383
if providerConfig.IsOriginal() {
8484
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
@@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
103103
// Set the apiToken for the current request.
104104
providerConfig.SetApiTokenInUse(ctx, log)
105105

106-
hasRequestBody := wrapper.HasRequestBody()
107106
err := handler.OnRequestHeaders(ctx, apiName, log)
108-
if err == nil {
109-
if hasRequestBody {
110-
proxywasm.RemoveHttpRequestHeader("Content-Length")
111-
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
112-
// Delay the header processing to allow changing in OnRequestBody
113-
return types.HeaderStopIteration
107+
if err != nil {
108+
if providerConfig.PassthroughUnsupportedAPI() {
109+
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
110+
ctx.DontReadRequestBody()
111+
return types.ActionContinue
114112
}
115-
ctx.DontReadRequestBody()
113+
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
116114
return types.ActionContinue
117115
}
118116

119-
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
117+
hasRequestBody := wrapper.HasRequestBody()
118+
if hasRequestBody {
119+
proxywasm.RemoveHttpRequestHeader("Content-Length")
120+
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
121+
// Delay the header processing to allow changing in OnRequestBody
122+
return types.HeaderStopIteration
123+
}
124+
ctx.DontReadRequestBody()
120125
return types.ActionContinue
121126
}
122127

@@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
151156
if err == nil {
152157
return action
153158
}
159+
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
160+
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
161+
return types.ActionContinue
162+
}
154163
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
155164
}
156165
return types.ActionContinue
@@ -267,12 +276,23 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
267276
}
268277
}
269278

270-
func getOpenAiApiName(path string) provider.ApiName {
279+
func getApiName(path string) provider.ApiName {
280+
// openai style
271281
if strings.HasSuffix(path, "/v1/chat/completions") {
272282
return provider.ApiNameChatCompletion
273283
}
274284
if strings.HasSuffix(path, "/v1/embeddings") {
275285
return provider.ApiNameEmbeddings
276286
}
287+
if strings.HasSuffix(path, "/v1/audio/speech") {
288+
return provider.ApiNameAudioSpeech
289+
}
290+
if strings.HasSuffix(path, "/v1/images/generations") {
291+
return provider.ApiNameImageGeneration
292+
}
293+
// cohere style
294+
if strings.HasSuffix(path, "/v1/rerank") {
295+
return provider.ApiNameCohereV1Rerank
296+
}
277297
return ""
278298
}

plugins/wasm-go/extensions/ai-proxy/provider/ai360.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ type ai360Provider struct {
2222
contextCache *contextCache
2323
}
2424

25+
func (m *ai360ProviderInitializer) DefaultCapabilities() map[string]string {
26+
return map[string]string{
27+
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
28+
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
29+
}
30+
}
31+
2532
func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error {
2633
if config.apiTokens == nil || len(config.apiTokens) == 0 {
2734
return errors.New("no apiToken found in provider config")
@@ -30,6 +37,7 @@ func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error
3037
}
3138

3239
func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
40+
config.setDefaultCapabilities(m.DefaultCapabilities())
3341
return &ai360Provider{
3442
config: config,
3543
contextCache: createContextCache(&config),
@@ -41,7 +49,7 @@ func (m *ai360Provider) GetProviderType() string {
4149
}
4250

4351
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
44-
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
52+
if !m.config.isSupportedAPI(apiName) {
4553
return errUnsupportedApiName
4654
}
4755
m.config.handleRequestHeaders(m, ctx, apiName, log)
@@ -50,13 +58,14 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
5058
}
5159

5260
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
53-
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
61+
if !m.config.isSupportedAPI(apiName) {
5462
return types.ActionContinue, errUnsupportedApiName
5563
}
5664
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
5765
}
5866

5967
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
6068
util.OverwriteRequestHostHeader(headers, ai360Domain)
69+
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
6170
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
6271
}

plugins/wasm-go/extensions/ai-proxy/provider/azure.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ import (
1515
type azureProviderInitializer struct {
1616
}
1717

18+
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
19+
return map[string]string{
20+
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
21+
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
22+
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
23+
}
24+
}
25+
1826
func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
1927
if config.azureServiceUrl == "" {
2028
return errors.New("missing azureServiceUrl in provider config")
@@ -35,6 +43,7 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
3543
} else {
3644
serviceUrl = u
3745
}
46+
config.setDefaultCapabilities(m.DefaultCapabilities())
3847
return &azureProvider{
3948
config: config,
4049
serviceUrl: serviceUrl,
@@ -54,15 +63,15 @@ func (m *azureProvider) GetProviderType() string {
5463
}
5564

5665
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
57-
if apiName != ApiNameChatCompletion {
66+
if !m.config.isSupportedAPI(apiName) {
5867
return errUnsupportedApiName
5968
}
6069
m.config.handleRequestHeaders(m, ctx, apiName, log)
6170
return nil
6271
}
6372

6473
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
65-
if apiName != ApiNameChatCompletion {
74+
if !m.config.isSupportedAPI(apiName) {
6675
return types.ActionContinue, errUnsupportedApiName
6776
}
6877
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)

plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ import (
1212
// baichuanProvider is the provider for baichuan Ai service.
1313

1414
const (
15-
baichuanDomain = "api.baichuan-ai.com"
16-
baichuanChatCompletionPath = "/v1/chat/completions"
15+
baichuanDomain = "api.baichuan-ai.com"
1716
)
1817

1918
type baichuanProviderInitializer struct {
@@ -26,7 +25,15 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err
2625
return nil
2726
}
2827

28+
func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string {
29+
return map[string]string{
30+
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
31+
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
32+
}
33+
}
34+
2935
func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
36+
config.setDefaultCapabilities(m.DefaultCapabilities())
3037
return &baichuanProvider{
3138
config: config,
3239
contextCache: createContextCache(&config),
@@ -43,22 +50,22 @@ func (m *baichuanProvider) GetProviderType() string {
4350
}
4451

4552
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
46-
if apiName != ApiNameChatCompletion {
53+
if !m.config.isSupportedAPI(apiName) {
4754
return errUnsupportedApiName
4855
}
4956
m.config.handleRequestHeaders(m, ctx, apiName, log)
5057
return nil
5158
}
5259

5360
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
54-
if apiName != ApiNameChatCompletion {
61+
if !m.config.isSupportedAPI(apiName) {
5562
return types.ActionContinue, errUnsupportedApiName
5663
}
5764
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
5865
}
5966

6067
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
61-
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
68+
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
6269
util.OverwriteRequestHostHeader(headers, baichuanDomain)
6370
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
6471
headers.Del("Content-Length")

plugins/wasm-go/extensions/ai-proxy/provider/baidu.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
const (
1515
baiduDomain = "qianfan.baidubce.com"
1616
baiduChatCompletionPath = "/v2/chat/completions"
17+
baiduEmbeddings = "/v2/embeddings"
1718
)
1819

1920
type baiduProviderInitializer struct{}
@@ -25,7 +26,15 @@ func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error
2526
return nil
2627
}
2728

29+
func (g *baiduProviderInitializer) DefaultCapabilities() map[string]string {
30+
return map[string]string{
31+
string(ApiNameChatCompletion): baiduChatCompletionPath,
32+
string(ApiNameEmbeddings): baiduEmbeddings,
33+
}
34+
}
35+
2836
func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
37+
config.setDefaultCapabilities(g.DefaultCapabilities())
2938
return &baiduProvider{
3039
config: config,
3140
contextCache: createContextCache(&config),
@@ -42,22 +51,22 @@ func (g *baiduProvider) GetProviderType() string {
4251
}
4352

4453
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
45-
if apiName != ApiNameChatCompletion {
54+
if !g.config.isSupportedAPI(apiName) {
4655
return errUnsupportedApiName
4756
}
4857
g.config.handleRequestHeaders(g, ctx, apiName, log)
4958
return nil
5059
}
5160

5261
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
53-
if apiName != ApiNameChatCompletion {
62+
if !g.config.isSupportedAPI(apiName) {
5463
return types.ActionContinue, errUnsupportedApiName
5564
}
5665
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
5766
}
5867

5968
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
60-
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
69+
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
6170
util.OverwriteRequestHostHeader(headers, baiduDomain)
6271
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
6372
headers.Del("Content-Length")

plugins/wasm-go/extensions/ai-proxy/provider/claude.go

+22-3
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,16 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
8585
return nil
8686
}
8787

88+
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
89+
return map[string]string{
90+
string(ApiNameChatCompletion): claudeChatCompletionPath,
91+
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
92+
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
93+
}
94+
}
95+
8896
func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
97+
config.setDefaultCapabilities(c.DefaultCapabilities())
8998
return &claudeProvider{
9099
config: config,
91100
contextCache: createContextCache(&config),
@@ -102,15 +111,15 @@ func (c *claudeProvider) GetProviderType() string {
102111
}
103112

104113
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
105-
if apiName != ApiNameChatCompletion {
114+
if !c.config.isSupportedAPI(apiName) {
106115
return errUnsupportedApiName
107116
}
108117
c.config.handleRequestHeaders(c, ctx, apiName, log)
109118
return nil
110119
}
111120

112121
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
113-
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
122+
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
114123
util.OverwriteRequestHostHeader(headers, claudeDomain)
115124

116125
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
@@ -123,13 +132,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
123132
}
124133

125134
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
126-
if apiName != ApiNameChatCompletion {
135+
if !c.config.isSupportedAPI(apiName) {
127136
return types.ActionContinue, errUnsupportedApiName
128137
}
129138
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
130139
}
131140

132141
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
142+
if apiName != ApiNameChatCompletion {
143+
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
144+
}
133145
request := &chatCompletionRequest{}
134146
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
135147
return nil, err
@@ -139,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
139151
}
140152

141153
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
154+
if apiName != ApiNameChatCompletion {
155+
return body, nil
156+
}
142157
claudeResponse := &claudeTextGenResponse{}
143158
if err := json.Unmarshal(body, claudeResponse); err != nil {
144159
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
@@ -154,6 +169,10 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
154169
if isLastChunk || len(chunk) == 0 {
155170
return nil, nil
156171
}
172+
// only process the response from chat completion, skip other responses
173+
if name != ApiNameChatCompletion {
174+
return chunk, nil
175+
}
157176

158177
responseBuilder := &strings.Builder{}
159178
lines := strings.Split(string(chunk), "\n")

plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,14 @@ func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) e
2525
}
2626
return nil
2727
}
28+
func (c *cloudflareProviderInitializer) DefaultCapabilities() map[string]string {
29+
return map[string]string{
30+
string(ApiNameChatCompletion): cloudflareChatCompletionPath,
31+
}
32+
}
2833

2934
func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
35+
config.setDefaultCapabilities(c.DefaultCapabilities())
3036
return &cloudflareProvider{
3137
config: config,
3238
contextCache: createContextCache(&config),
@@ -43,15 +49,15 @@ func (c *cloudflareProvider) GetProviderType() string {
4349
}
4450

4551
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
46-
if apiName != ApiNameChatCompletion {
52+
if !c.config.isSupportedAPI(apiName) {
4753
return errUnsupportedApiName
4854
}
4955
c.config.handleRequestHeaders(c, ctx, apiName, log)
5056
return nil
5157
}
5258

5359
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
54-
if apiName != ApiNameChatCompletion {
60+
if !c.config.isSupportedAPI(apiName) {
5561
return types.ActionContinue, errUnsupportedApiName
5662
}
5763
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)

0 commit comments

Comments
 (0)