1
1
package provider
2
2
3
3
import (
4
- "crypto/hmac"
5
- "crypto/sha256"
6
- "encoding/hex"
7
- "encoding/json"
8
4
"errors"
9
- "fmt"
10
5
"net/http"
11
- "sort"
12
6
"strings"
13
- "time"
14
7
15
8
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
16
9
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
@@ -21,28 +14,14 @@ import (
21
14
const (
22
15
baiduDomain = "qianfan.baidubce.com"
23
16
baiduChatCompletionPath = "/v2/chat/completions"
24
- baiduApiTokenDomain = "iam.bj.baidubce.com"
25
- baiduApiTokenPort = 443
26
- baiduApiTokenPath = "/v1/BCE-BEARER/token"
27
- // refresh apiToken every 1 hour
28
- baiduApiTokenRefreshInterval = 3600
29
- // authorizationString expires in 30 minutes, authorizationString is used to generate apiToken
30
- // the default expiration time of apiToken is 24 hours
31
- baiduAuthorizationStringExpirationSeconds = 1800
32
- bce_prefix = "x-bce-"
33
17
)
34
18
35
19
type baiduProviderInitializer struct {}
36
20
37
21
func (g * baiduProviderInitializer ) ValidateConfig (config * ProviderConfig ) error {
38
- if config .baiduAccessKeyAndSecret == nil || len (config .baiduAccessKeyAndSecret ) == 0 {
39
- return errors .New ("no baiduAccessKeyAndSecret found in provider config" )
22
+ if config .apiTokens == nil || len (config .apiTokens ) == 0 {
23
+ return errors .New ("no apiToken found in provider config" )
40
24
}
41
- if config .baiduApiTokenServiceName == "" {
42
- return errors .New ("no baiduApiTokenServiceName found in provider config" )
43
- }
44
- // baidu use access key and access secret to refresh apiToken regularly, the apiToken should be accessed globally (via all Wasm VMs)
45
- config .useGlobalApiToken = true
46
25
return nil
47
26
}
48
27
@@ -90,203 +69,3 @@ func (g *baiduProvider) GetApiName(path string) ApiName {
90
69
}
91
70
return ""
92
71
}
93
-
94
- func generateAuthorizationString (accessKeyAndSecret string , expirationInSeconds int ) string {
95
- c := strings .Split (accessKeyAndSecret , ":" )
96
- credentials := BceCredentials {
97
- AccessKeyId : c [0 ],
98
- SecretAccessKey : c [1 ],
99
- }
100
- httpMethod := "GET"
101
- path := baiduApiTokenPath
102
- headers := map [string ]string {"host" : baiduApiTokenDomain }
103
- timestamp := time .Now ().Unix ()
104
-
105
- headersToSign := make ([]string , 0 , len (headers ))
106
- for k := range headers {
107
- headersToSign = append (headersToSign , k )
108
- }
109
-
110
- return sign (credentials , httpMethod , path , headers , timestamp , expirationInSeconds , headersToSign )
111
- }
112
-
113
- // BceCredentials holds the access key and secret key
114
- type BceCredentials struct {
115
- AccessKeyId string
116
- SecretAccessKey string
117
- }
118
-
119
- // normalizeString performs URI encoding according to RFC 3986
120
- func normalizeString (inStr string , encodingSlash bool ) string {
121
- if inStr == "" {
122
- return ""
123
- }
124
-
125
- var result strings.Builder
126
- for _ , ch := range []byte (inStr ) {
127
- if (ch >= 'A' && ch <= 'Z' ) || (ch >= 'a' && ch <= 'z' ) ||
128
- (ch >= '0' && ch <= '9' ) || ch == '.' || ch == '-' ||
129
- ch == '_' || ch == '~' || (! encodingSlash && ch == '/' ) {
130
- result .WriteByte (ch )
131
- } else {
132
- result .WriteString (fmt .Sprintf ("%%%02X" , ch ))
133
- }
134
- }
135
- return result .String ()
136
- }
137
-
138
- // getCanonicalTime generates a timestamp in UTC format
139
- func getCanonicalTime (timestamp int64 ) string {
140
- if timestamp == 0 {
141
- timestamp = time .Now ().Unix ()
142
- }
143
- t := time .Unix (timestamp , 0 ).UTC ()
144
- return t .Format ("2006-01-02T15:04:05Z" )
145
- }
146
-
147
- // getCanonicalUri generates a canonical URI
148
- func getCanonicalUri (path string ) string {
149
- return normalizeString (path , false )
150
- }
151
-
152
- // getCanonicalHeaders generates canonical headers
153
- func getCanonicalHeaders (headers map [string ]string , headersToSign []string ) string {
154
- if len (headers ) == 0 {
155
- return ""
156
- }
157
-
158
- // If headersToSign is not specified, use default headers
159
- if len (headersToSign ) == 0 {
160
- headersToSign = []string {"host" , "content-md5" , "content-length" , "content-type" }
161
- }
162
-
163
- // Convert headersToSign to a map for easier lookup
164
- headerMap := make (map [string ]bool )
165
- for _ , header := range headersToSign {
166
- headerMap [strings .ToLower (strings .TrimSpace (header ))] = true
167
- }
168
-
169
- // Create a slice to hold the canonical headers
170
- var canonicalHeaders []string
171
- for k , v := range headers {
172
- k = strings .ToLower (strings .TrimSpace (k ))
173
- v = strings .TrimSpace (v )
174
-
175
- // Add headers that start with x-bce- or are in headersToSign
176
- if strings .HasPrefix (k , bce_prefix ) || headerMap [k ] {
177
- canonicalHeaders = append (canonicalHeaders ,
178
- fmt .Sprintf ("%s:%s" , normalizeString (k , true ), normalizeString (v , true )))
179
- }
180
- }
181
-
182
- // Sort the canonical headers
183
- sort .Strings (canonicalHeaders )
184
-
185
- return strings .Join (canonicalHeaders , "\n " )
186
- }
187
-
188
- // sign generates the authorization string
189
- func sign (credentials BceCredentials , httpMethod , path string , headers map [string ]string ,
190
- timestamp int64 , expirationInSeconds int ,
191
- headersToSign []string ) string {
192
-
193
- // Generate sign key
194
- signKeyInfo := fmt .Sprintf ("bce-auth-v1/%s/%s/%d" ,
195
- credentials .AccessKeyId ,
196
- getCanonicalTime (timestamp ),
197
- expirationInSeconds )
198
-
199
- // Generate sign key using HMAC-SHA256
200
- h := hmac .New (sha256 .New , []byte (credentials .SecretAccessKey ))
201
- h .Write ([]byte (signKeyInfo ))
202
- signKey := hex .EncodeToString (h .Sum (nil ))
203
-
204
- // Generate canonical URI
205
- canonicalUri := getCanonicalUri (path )
206
-
207
- // Generate canonical headers
208
- canonicalHeaders := getCanonicalHeaders (headers , headersToSign )
209
-
210
- // Generate string to sign
211
- stringToSign := strings .Join ([]string {
212
- httpMethod ,
213
- canonicalUri ,
214
- "" ,
215
- canonicalHeaders ,
216
- }, "\n " )
217
-
218
- // Calculate final signature
219
- h = hmac .New (sha256 .New , []byte (signKey ))
220
- h .Write ([]byte (stringToSign ))
221
- signature := hex .EncodeToString (h .Sum (nil ))
222
-
223
- // Generate final authorization string
224
- if len (headersToSign ) > 0 {
225
- return fmt .Sprintf ("%s/%s/%s" , signKeyInfo , strings .Join (headersToSign , ";" ), signature )
226
- }
227
- return fmt .Sprintf ("%s//%s" , signKeyInfo , signature )
228
- }
229
-
230
- // GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours
231
- func (g * baiduProvider ) GetTickFunc (log wrapper.Log ) (tickPeriod int64 , tickFunc func ()) {
232
- vmID := generateVMID ()
233
-
234
- return baiduApiTokenRefreshInterval * 1000 , func () {
235
- // Only the Wasm VM that successfully acquires the lease will refresh the apiToken
236
- if g .config .tryAcquireOrRenewLease (vmID , log ) {
237
- log .Debugf ("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v" , vmID )
238
- // Get the apiToken that is about to expire, will be removed after the new apiToken is obtained
239
- oldApiTokens , _ , err := getApiTokens (g .config .failover .ctxApiTokens )
240
- if err != nil {
241
- log .Errorf ("Get old apiToken failed: %v" , err )
242
- return
243
- }
244
- log .Debugf ("Old apiTokens: %v" , oldApiTokens )
245
-
246
- for _ , accessKeyAndSecret := range g .config .baiduAccessKeyAndSecret {
247
- authorizationString := generateAuthorizationString (accessKeyAndSecret , baiduAuthorizationStringExpirationSeconds )
248
- log .Debugf ("Generate authorizationString: %v" , authorizationString )
249
- g .generateNewApiToken (authorizationString , log )
250
- }
251
-
252
- // remove old old apiToken
253
- for _ , token := range oldApiTokens {
254
- log .Debugf ("Remove old apiToken: %v" , token )
255
- removeApiToken (g .config .failover .ctxApiTokens , token , log )
256
- }
257
- }
258
- }
259
- }
260
-
261
- func (g * baiduProvider ) generateNewApiToken (authorizationString string , log wrapper.Log ) {
262
- client := wrapper .NewClusterClient (wrapper.FQDNCluster {
263
- FQDN : g .config .baiduApiTokenServiceName ,
264
- Host : g .config .baiduApiTokenServiceHost ,
265
- Port : g .config .baiduApiTokenServicePort ,
266
- })
267
-
268
- headers := [][2 ]string {
269
- {"content-type" , "application/json" },
270
- {"Authorization" , authorizationString },
271
- }
272
-
273
- var apiToken string
274
- err := client .Get (baiduApiTokenPath , headers , func (statusCode int , responseHeaders http.Header , responseBody []byte ) {
275
- if statusCode == 201 {
276
- var response map [string ]interface {}
277
- err := json .Unmarshal (responseBody , & response )
278
- if err != nil {
279
- log .Errorf ("Unmarshal response failed: %v" , err )
280
- } else {
281
- apiToken = response ["token" ].(string )
282
- addApiToken (g .config .failover .ctxApiTokens , apiToken , log )
283
- }
284
- } else {
285
- log .Errorf ("Get apiToken failed, status code: %d, response body: %s" , statusCode , string (responseBody ))
286
- }
287
- }, 30000 )
288
-
289
- if err != nil {
290
- log .Errorf ("Get apiToken failed: %v" , err )
291
- }
292
- }
0 commit comments