Skip to content

Commit b1f625a

Browse files
authored
feat: support baidu api key (#1687)
1 parent fd1eb54 commit b1f625a

File tree

5 files changed

+5
-270
lines changed

5 files changed

+5
-270
lines changed

plugins/wasm-go/extensions/ai-proxy/README.md

+1-9
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,7 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
157157

158158
#### 文心一言(Baidu)
159159

160-
文心一言所对应的 `type``baidu`。它特有的配置字段如下:
161-
162-
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
163-
|--------------------|-----------------|------|-----|-----------------------------------------------------------|
164-
| `baiduAccessKeyAndSecret` | array of string | 必填 | - | Baidu 的 Access Key 和 Secret Key,中间用 `:` 分隔,用于申请 apiToken。 |
165-
| `baiduApiTokenServiceName` | string | 必填 | - | 请求刷新百度 apiToken 服务名称。 |
166-
| `baiduApiTokenServiceHost` | string | 非必填 | - | 请求刷新百度 apiToken 服务域名,默认是 iam.bj.baidubce.com。 |
167-
| `baiduApiTokenServicePort` | int64 | 非必填 | - | 请求刷新百度 apiToken 服务端口,默认是 443。 |
168-
160+
文心一言所对应的 `type``baidu`。它并无特有的配置字段。
169161

170162
#### 360智脑
171163

plugins/wasm-go/extensions/ai-proxy/config/config.go

-5
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,6 @@ func (c *PluginConfig) Complete(log wrapper.Log) error {
8686
providerConfig := c.GetProviderConfig()
8787
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
8888

89-
if handler, ok := c.activeProvider.(provider.TickFuncHandler); ok {
90-
tickPeriod, tickFunc := handler.GetTickFunc(log)
91-
wrapper.RegisteTickFunc(tickPeriod, tickFunc)
92-
}
93-
9489
return err
9590
}
9691

Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
package provider
22

33
import (
4-
"crypto/hmac"
5-
"crypto/sha256"
6-
"encoding/hex"
7-
"encoding/json"
84
"errors"
9-
"fmt"
105
"net/http"
11-
"sort"
126
"strings"
13-
"time"
147

158
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
169
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
@@ -21,28 +14,14 @@ import (
2114
const (
2215
baiduDomain = "qianfan.baidubce.com"
2316
baiduChatCompletionPath = "/v2/chat/completions"
24-
baiduApiTokenDomain = "iam.bj.baidubce.com"
25-
baiduApiTokenPort = 443
26-
baiduApiTokenPath = "/v1/BCE-BEARER/token"
27-
// refresh apiToken every 1 hour
28-
baiduApiTokenRefreshInterval = 3600
29-
// authorizationString expires in 30 minutes, authorizationString is used to generate apiToken
30-
// the default expiration time of apiToken is 24 hours
31-
baiduAuthorizationStringExpirationSeconds = 1800
32-
bce_prefix = "x-bce-"
3317
)
3418

3519
type baiduProviderInitializer struct{}
3620

3721
func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error {
38-
if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 {
39-
return errors.New("no baiduAccessKeyAndSecret found in provider config")
22+
if config.apiTokens == nil || len(config.apiTokens) == 0 {
23+
return errors.New("no apiToken found in provider config")
4024
}
41-
if config.baiduApiTokenServiceName == "" {
42-
return errors.New("no baiduApiTokenServiceName found in provider config")
43-
}
44-
// baidu use access key and access secret to refresh apiToken regularly, the apiToken should be accessed globally (via all Wasm VMs)
45-
config.useGlobalApiToken = true
4625
return nil
4726
}
4827

@@ -90,203 +69,3 @@ func (g *baiduProvider) GetApiName(path string) ApiName {
9069
}
9170
return ""
9271
}
93-
94-
func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string {
95-
c := strings.Split(accessKeyAndSecret, ":")
96-
credentials := BceCredentials{
97-
AccessKeyId: c[0],
98-
SecretAccessKey: c[1],
99-
}
100-
httpMethod := "GET"
101-
path := baiduApiTokenPath
102-
headers := map[string]string{"host": baiduApiTokenDomain}
103-
timestamp := time.Now().Unix()
104-
105-
headersToSign := make([]string, 0, len(headers))
106-
for k := range headers {
107-
headersToSign = append(headersToSign, k)
108-
}
109-
110-
return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign)
111-
}
112-
113-
// BceCredentials holds the access key and secret key
114-
type BceCredentials struct {
115-
AccessKeyId string
116-
SecretAccessKey string
117-
}
118-
119-
// normalizeString performs URI encoding according to RFC 3986
120-
func normalizeString(inStr string, encodingSlash bool) string {
121-
if inStr == "" {
122-
return ""
123-
}
124-
125-
var result strings.Builder
126-
for _, ch := range []byte(inStr) {
127-
if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') ||
128-
(ch >= '0' && ch <= '9') || ch == '.' || ch == '-' ||
129-
ch == '_' || ch == '~' || (!encodingSlash && ch == '/') {
130-
result.WriteByte(ch)
131-
} else {
132-
result.WriteString(fmt.Sprintf("%%%02X", ch))
133-
}
134-
}
135-
return result.String()
136-
}
137-
138-
// getCanonicalTime generates a timestamp in UTC format
139-
func getCanonicalTime(timestamp int64) string {
140-
if timestamp == 0 {
141-
timestamp = time.Now().Unix()
142-
}
143-
t := time.Unix(timestamp, 0).UTC()
144-
return t.Format("2006-01-02T15:04:05Z")
145-
}
146-
147-
// getCanonicalUri generates a canonical URI
148-
func getCanonicalUri(path string) string {
149-
return normalizeString(path, false)
150-
}
151-
152-
// getCanonicalHeaders generates canonical headers
153-
func getCanonicalHeaders(headers map[string]string, headersToSign []string) string {
154-
if len(headers) == 0 {
155-
return ""
156-
}
157-
158-
// If headersToSign is not specified, use default headers
159-
if len(headersToSign) == 0 {
160-
headersToSign = []string{"host", "content-md5", "content-length", "content-type"}
161-
}
162-
163-
// Convert headersToSign to a map for easier lookup
164-
headerMap := make(map[string]bool)
165-
for _, header := range headersToSign {
166-
headerMap[strings.ToLower(strings.TrimSpace(header))] = true
167-
}
168-
169-
// Create a slice to hold the canonical headers
170-
var canonicalHeaders []string
171-
for k, v := range headers {
172-
k = strings.ToLower(strings.TrimSpace(k))
173-
v = strings.TrimSpace(v)
174-
175-
// Add headers that start with x-bce- or are in headersToSign
176-
if strings.HasPrefix(k, bce_prefix) || headerMap[k] {
177-
canonicalHeaders = append(canonicalHeaders,
178-
fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true)))
179-
}
180-
}
181-
182-
// Sort the canonical headers
183-
sort.Strings(canonicalHeaders)
184-
185-
return strings.Join(canonicalHeaders, "\n")
186-
}
187-
188-
// sign generates the authorization string
189-
func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string,
190-
timestamp int64, expirationInSeconds int,
191-
headersToSign []string) string {
192-
193-
// Generate sign key
194-
signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d",
195-
credentials.AccessKeyId,
196-
getCanonicalTime(timestamp),
197-
expirationInSeconds)
198-
199-
// Generate sign key using HMAC-SHA256
200-
h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey))
201-
h.Write([]byte(signKeyInfo))
202-
signKey := hex.EncodeToString(h.Sum(nil))
203-
204-
// Generate canonical URI
205-
canonicalUri := getCanonicalUri(path)
206-
207-
// Generate canonical headers
208-
canonicalHeaders := getCanonicalHeaders(headers, headersToSign)
209-
210-
// Generate string to sign
211-
stringToSign := strings.Join([]string{
212-
httpMethod,
213-
canonicalUri,
214-
"",
215-
canonicalHeaders,
216-
}, "\n")
217-
218-
// Calculate final signature
219-
h = hmac.New(sha256.New, []byte(signKey))
220-
h.Write([]byte(stringToSign))
221-
signature := hex.EncodeToString(h.Sum(nil))
222-
223-
// Generate final authorization string
224-
if len(headersToSign) > 0 {
225-
return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature)
226-
}
227-
return fmt.Sprintf("%s//%s", signKeyInfo, signature)
228-
}
229-
230-
// GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours
231-
func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) {
232-
vmID := generateVMID()
233-
234-
return baiduApiTokenRefreshInterval * 1000, func() {
235-
// Only the Wasm VM that successfully acquires the lease will refresh the apiToken
236-
if g.config.tryAcquireOrRenewLease(vmID, log) {
237-
log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID)
238-
// Get the apiToken that is about to expire, will be removed after the new apiToken is obtained
239-
oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens)
240-
if err != nil {
241-
log.Errorf("Get old apiToken failed: %v", err)
242-
return
243-
}
244-
log.Debugf("Old apiTokens: %v", oldApiTokens)
245-
246-
for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret {
247-
authorizationString := generateAuthorizationString(accessKeyAndSecret, baiduAuthorizationStringExpirationSeconds)
248-
log.Debugf("Generate authorizationString: %v", authorizationString)
249-
g.generateNewApiToken(authorizationString, log)
250-
}
251-
252-
// remove old old apiToken
253-
for _, token := range oldApiTokens {
254-
log.Debugf("Remove old apiToken: %v", token)
255-
removeApiToken(g.config.failover.ctxApiTokens, token, log)
256-
}
257-
}
258-
}
259-
}
260-
261-
func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) {
262-
client := wrapper.NewClusterClient(wrapper.FQDNCluster{
263-
FQDN: g.config.baiduApiTokenServiceName,
264-
Host: g.config.baiduApiTokenServiceHost,
265-
Port: g.config.baiduApiTokenServicePort,
266-
})
267-
268-
headers := [][2]string{
269-
{"content-type", "application/json"},
270-
{"Authorization", authorizationString},
271-
}
272-
273-
var apiToken string
274-
err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
275-
if statusCode == 201 {
276-
var response map[string]interface{}
277-
err := json.Unmarshal(responseBody, &response)
278-
if err != nil {
279-
log.Errorf("Unmarshal response failed: %v", err)
280-
} else {
281-
apiToken = response["token"].(string)
282-
addApiToken(g.config.failover.ctxApiTokens, apiToken, log)
283-
}
284-
} else {
285-
log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody))
286-
}
287-
}, 30000)
288-
289-
if err != nil {
290-
log.Errorf("Get apiToken failed: %v", err)
291-
}
292-
}

plugins/wasm-go/extensions/ai-proxy/provider/failover.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -557,9 +557,8 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
557557

558558
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
559559
var apiToken string
560-
if c.isFailoverEnabled() || c.useGlobalApiToken {
561-
// if enable apiToken failover, only use available apiToken from global apiTokens list
562-
// or the apiToken need to be accessed globally (via all Wasm VMs, e.g. baidu),
560+
// if enable apiToken failover, only use available apiToken from global apiTokens list
561+
if c.isFailoverEnabled() {
563562
apiToken = c.GetGlobalRandomToken(log)
564563
} else {
565564
apiToken = c.GetRandomToken()

plugins/wasm-go/extensions/ai-proxy/provider/provider.go

-30
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,6 @@ type TransformResponseBodyHandler interface {
155155
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
156156
}
157157

158-
// TickFuncHandler allows the provider to execute a function periodically
159-
// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically
160-
type TickFuncHandler interface {
161-
GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func())
162-
}
163-
164158
type ProviderConfig struct {
165159
// @Title zh-CN ID
166160
// @Description zh-CN AI服务提供商标识
@@ -246,17 +240,6 @@ type ProviderConfig struct {
246240
// @Title zh-CN 自定义大模型参数配置
247241
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
248242
customSettings []CustomSetting
249-
// @Title zh-CN Baidu 的 Access Key 和 Secret Key,中间用 : 分隔,用于申请 apiToken
250-
baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"`
251-
// @Title zh-CN 请求刷新百度 apiToken 服务名称
252-
baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"`
253-
// @Title zh-CN 请求刷新百度 apiToken 服务域名
254-
baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"`
255-
// @Title zh-CN 请求刷新百度 apiToken 服务端口
256-
baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"`
257-
// @Title zh-CN 是否使用全局的 apiToken
258-
// @Description zh-CN 如果没有启用 apiToken failover,但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true,例如 Baidu 的 apiToken 需要定时刷新
259-
useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"`
260243
}
261244

262245
func (c *ProviderConfig) GetId() string {
@@ -364,19 +347,6 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
364347
if retryOnFailureJson.Exists() {
365348
c.retryOnFailure.FromJson(retryOnFailureJson)
366349
}
367-
368-
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
369-
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
370-
}
371-
c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String()
372-
c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String()
373-
if c.baiduApiTokenServiceHost == "" {
374-
c.baiduApiTokenServiceHost = baiduApiTokenDomain
375-
}
376-
c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int()
377-
if c.baiduApiTokenServicePort == 0 {
378-
c.baiduApiTokenServicePort = baiduApiTokenPort
379-
}
380350
}
381351

382352
func (c *ProviderConfig) Validate() error {

0 commit comments

Comments
 (0)