Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support baidu api key #1687

Merged
merged 2 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,7 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。

#### 文心一言(Baidu)

文心一言所对应的 `type` 为 `baidu`。它特有的配置字段如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------|-----|-----------------------------------------------------------|
| `baiduAccessKeyAndSecret` | array of string | 必填 | - | Baidu 的 Access Key 和 Secret Key,中间用 `:` 分隔,用于申请 apiToken。 |
| `baiduApiTokenServiceName` | string | 必填 | - | 请求刷新百度 apiToken 服务名称。 |
| `baiduApiTokenServiceHost` | string | 非必填 | - | 请求刷新百度 apiToken 服务域名,默认是 iam.bj.baidubce.com。 |
| `baiduApiTokenServicePort` | int64 | 非必填 | - | 请求刷新百度 apiToken 服务端口,默认是 443。 |

文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。

#### 360智脑

Expand Down
5 changes: 0 additions & 5 deletions plugins/wasm-go/extensions/ai-proxy/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,6 @@ func (c *PluginConfig) Complete(log wrapper.Log) error {
providerConfig := c.GetProviderConfig()
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)

if handler, ok := c.activeProvider.(provider.TickFuncHandler); ok {
tickPeriod, tickFunc := handler.GetTickFunc(log)
wrapper.RegisteTickFunc(tickPeriod, tickFunc)
}

return err
}

Expand Down
225 changes: 2 additions & 223 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
package provider

import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"strings"
"time"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
Expand All @@ -21,28 +14,14 @@ import (
const (
baiduDomain = "qianfan.baidubce.com"
baiduChatCompletionPath = "/v2/chat/completions"
baiduApiTokenDomain = "iam.bj.baidubce.com"
baiduApiTokenPort = 443
baiduApiTokenPath = "/v1/BCE-BEARER/token"
// refresh apiToken every 1 hour
baiduApiTokenRefreshInterval = 3600
// authorizationString expires in 30 minutes, authorizationString is used to generate apiToken
// the default expiration time of apiToken is 24 hours
baiduAuthorizationStringExpirationSeconds = 1800
bce_prefix = "x-bce-"
)

type baiduProviderInitializer struct{}

func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 {
return errors.New("no baiduAccessKeyAndSecret found in provider config")
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
if config.baiduApiTokenServiceName == "" {
return errors.New("no baiduApiTokenServiceName found in provider config")
}
// baidu use access key and access secret to refresh apiToken regularly, the apiToken should be accessed globally (via all Wasm VMs)
config.useGlobalApiToken = true
return nil
}

Expand Down Expand Up @@ -90,203 +69,3 @@ func (g *baiduProvider) GetApiName(path string) ApiName {
}
return ""
}

func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string {
c := strings.Split(accessKeyAndSecret, ":")
credentials := BceCredentials{
AccessKeyId: c[0],
SecretAccessKey: c[1],
}
httpMethod := "GET"
path := baiduApiTokenPath
headers := map[string]string{"host": baiduApiTokenDomain}
timestamp := time.Now().Unix()

headersToSign := make([]string, 0, len(headers))
for k := range headers {
headersToSign = append(headersToSign, k)
}

return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign)
}

// BceCredentials holds the access key and secret key
type BceCredentials struct {
AccessKeyId string
SecretAccessKey string
}

// normalizeString performs URI encoding according to RFC 3986
func normalizeString(inStr string, encodingSlash bool) string {
if inStr == "" {
return ""
}

var result strings.Builder
for _, ch := range []byte(inStr) {
if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') ||
(ch >= '0' && ch <= '9') || ch == '.' || ch == '-' ||
ch == '_' || ch == '~' || (!encodingSlash && ch == '/') {
result.WriteByte(ch)
} else {
result.WriteString(fmt.Sprintf("%%%02X", ch))
}
}
return result.String()
}

// getCanonicalTime generates a timestamp in UTC format
func getCanonicalTime(timestamp int64) string {
if timestamp == 0 {
timestamp = time.Now().Unix()
}
t := time.Unix(timestamp, 0).UTC()
return t.Format("2006-01-02T15:04:05Z")
}

// getCanonicalUri generates a canonical URI
func getCanonicalUri(path string) string {
return normalizeString(path, false)
}

// getCanonicalHeaders generates canonical headers
func getCanonicalHeaders(headers map[string]string, headersToSign []string) string {
if len(headers) == 0 {
return ""
}

// If headersToSign is not specified, use default headers
if len(headersToSign) == 0 {
headersToSign = []string{"host", "content-md5", "content-length", "content-type"}
}

// Convert headersToSign to a map for easier lookup
headerMap := make(map[string]bool)
for _, header := range headersToSign {
headerMap[strings.ToLower(strings.TrimSpace(header))] = true
}

// Create a slice to hold the canonical headers
var canonicalHeaders []string
for k, v := range headers {
k = strings.ToLower(strings.TrimSpace(k))
v = strings.TrimSpace(v)

// Add headers that start with x-bce- or are in headersToSign
if strings.HasPrefix(k, bce_prefix) || headerMap[k] {
canonicalHeaders = append(canonicalHeaders,
fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true)))
}
}

// Sort the canonical headers
sort.Strings(canonicalHeaders)

return strings.Join(canonicalHeaders, "\n")
}

// sign generates the authorization string
func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string,
timestamp int64, expirationInSeconds int,
headersToSign []string) string {

// Generate sign key
signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d",
credentials.AccessKeyId,
getCanonicalTime(timestamp),
expirationInSeconds)

// Generate sign key using HMAC-SHA256
h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey))
h.Write([]byte(signKeyInfo))
signKey := hex.EncodeToString(h.Sum(nil))

// Generate canonical URI
canonicalUri := getCanonicalUri(path)

// Generate canonical headers
canonicalHeaders := getCanonicalHeaders(headers, headersToSign)

// Generate string to sign
stringToSign := strings.Join([]string{
httpMethod,
canonicalUri,
"",
canonicalHeaders,
}, "\n")

// Calculate final signature
h = hmac.New(sha256.New, []byte(signKey))
h.Write([]byte(stringToSign))
signature := hex.EncodeToString(h.Sum(nil))

// Generate final authorization string
if len(headersToSign) > 0 {
return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature)
}
return fmt.Sprintf("%s//%s", signKeyInfo, signature)
}

// GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours
func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) {
vmID := generateVMID()

return baiduApiTokenRefreshInterval * 1000, func() {
// Only the Wasm VM that successfully acquires the lease will refresh the apiToken
if g.config.tryAcquireOrRenewLease(vmID, log) {
log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID)
// Get the apiToken that is about to expire, will be removed after the new apiToken is obtained
oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens)
if err != nil {
log.Errorf("Get old apiToken failed: %v", err)
return
}
log.Debugf("Old apiTokens: %v", oldApiTokens)

for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret {
authorizationString := generateAuthorizationString(accessKeyAndSecret, baiduAuthorizationStringExpirationSeconds)
log.Debugf("Generate authorizationString: %v", authorizationString)
g.generateNewApiToken(authorizationString, log)
}

// remove old old apiToken
for _, token := range oldApiTokens {
log.Debugf("Remove old apiToken: %v", token)
removeApiToken(g.config.failover.ctxApiTokens, token, log)
}
}
}
}

func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) {
client := wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: g.config.baiduApiTokenServiceName,
Host: g.config.baiduApiTokenServiceHost,
Port: g.config.baiduApiTokenServicePort,
})

headers := [][2]string{
{"content-type", "application/json"},
{"Authorization", authorizationString},
}

var apiToken string
err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 201 {
var response map[string]interface{}
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("Unmarshal response failed: %v", err)
} else {
apiToken = response["token"].(string)
addApiToken(g.config.failover.ctxApiTokens, apiToken, log)
}
} else {
log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody))
}
}, 30000)

if err != nil {
log.Errorf("Get apiToken failed: %v", err)
}
}
5 changes: 2 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,8 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {

func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
var apiToken string
if c.isFailoverEnabled() || c.useGlobalApiToken {
// if enable apiToken failover, only use available apiToken from global apiTokens list
// or the apiToken need to be accessed globally (via all Wasm VMs, e.g. baidu),
// if enable apiToken failover, only use available apiToken from global apiTokens list
if c.isFailoverEnabled() {
apiToken = c.GetGlobalRandomToken(log)
} else {
apiToken = c.GetRandomToken()
Expand Down
30 changes: 0 additions & 30 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,6 @@ type TransformResponseBodyHandler interface {
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}

// TickFuncHandler allows the provider to execute a function periodically
// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically
type TickFuncHandler interface {
GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func())
}

type ProviderConfig struct {
// @Title zh-CN ID
// @Description zh-CN AI服务提供商标识
Expand Down Expand Up @@ -246,17 +240,6 @@ type ProviderConfig struct {
// @Title zh-CN 自定义大模型参数配置
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
customSettings []CustomSetting
// @Title zh-CN Baidu 的 Access Key 和 Secret Key,中间用 : 分隔,用于申请 apiToken
baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"`
// @Title zh-CN 请求刷新百度 apiToken 服务名称
baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"`
// @Title zh-CN 请求刷新百度 apiToken 服务域名
baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"`
// @Title zh-CN 请求刷新百度 apiToken 服务端口
baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"`
// @Title zh-CN 是否使用全局的 apiToken
// @Description zh-CN 如果没有启用 apiToken failover,但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true,例如 Baidu 的 apiToken 需要定时刷新
useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"`
}

func (c *ProviderConfig) GetId() string {
Expand Down Expand Up @@ -364,19 +347,6 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if retryOnFailureJson.Exists() {
c.retryOnFailure.FromJson(retryOnFailureJson)
}

for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
}
c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String()
c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String()
if c.baiduApiTokenServiceHost == "" {
c.baiduApiTokenServiceHost = baiduApiTokenDomain
}
c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int()
if c.baiduApiTokenServicePort == 0 {
c.baiduApiTokenServicePort = baiduApiTokenPort
}
}

func (c *ProviderConfig) Validate() error {
Expand Down
Loading