Skip to content

Commit 2aafc71

Browse files
rinfx authored and yunmaoQu committed
qwen bailian compatible bug fix (alibaba#1597)
1 parent 50091fd commit 2aafc71

File tree

3 files changed

+90
-49
lines changed

3 files changed

+90
-49
lines changed

plugins/wasm-go/extensions/ai-proxy/provider/qwen.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const (
2727
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
2828
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
2929
qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
30+
qwenBailianPath = "/api/v1/apps"
3031
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
3132

3233
qwenTopPMin = 0.000001
@@ -71,7 +72,8 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
7172
}
7273
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
7374

74-
if m.config.qwenEnableCompatible {
75+
if m.config.IsOriginal() {
76+
} else if m.config.qwenEnableCompatible {
7577
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
7678
} else if apiName == ApiNameChatCompletion {
7779
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
@@ -762,6 +764,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
762764
switch {
763765
case strings.Contains(path, qwenChatCompletionPath),
764766
strings.Contains(path, qwenMultimodalGenerationPath),
767+
strings.Contains(path, qwenBailianPath),
765768
strings.Contains(path, qwenCompatiblePath):
766769
return ApiNameChatCompletion
767770
case strings.Contains(path, qwenTextEmbeddingPath):

plugins/wasm-go/extensions/ai-security-guard/main.go

+4-10
Original file line numberDiff line numberDiff line change
@@ -384,26 +384,20 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
384384
ctx.DontReadResponseBody()
385385
return types.ActionContinue
386386
}
387-
headers, err := proxywasm.GetHttpResponseHeaders()
388-
if err != nil {
389-
log.Warnf("failed to get response headers: %v", err)
390-
return types.ActionContinue
391-
}
392-
hdsMap := convertHeaders(headers)
393-
if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") {
387+
statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
388+
if statusCode != "200" {
394389
log.Debugf("response is not 200, skip response body check")
395390
ctx.DontReadResponseBody()
396391
return types.ActionContinue
397392
}
398-
ctx.SetContext("headers", hdsMap)
399393
return types.HeaderStopIteration
400394
}
401395

402396
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
403397
log.Debugf("checking response body...")
404398
startTime := time.Now().UnixMilli()
405-
hdsMap := ctx.GetContext("headers").(map[string][]string)
406-
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
399+
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
400+
isStreamingResponse := strings.Contains(contentType, "event-stream")
407401
model := ctx.GetStringContext("requestModel", "unknown")
408402
var content string
409403
if isStreamingResponse {

plugins/wasm-go/extensions/ai-statistics/main.go

+82-38
Original file line numberDiff line numberDiff line change
@@ -303,83 +303,79 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
303303
// fetches the tracing span value from the specified source.
304304
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
305305
for _, attribute := range config.attributes {
306-
var key, value string
307-
var err error
306+
var key string
307+
var value interface{}
308308
if source == attribute.ValueSource {
309309
key = attribute.Key
310310
switch source {
311311
case FixedValue:
312-
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value)
313312
value = attribute.Value
314313
case RequestHeader:
315-
if value, err = proxywasm.GetHttpRequestHeader(attribute.Value); err == nil {
316-
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
317-
}
314+
value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
318315
case RequestBody:
319-
raw := gjson.GetBytes(body, attribute.Value).Raw
320-
if len(raw) > 2 {
321-
value = raw[1 : len(raw)-1]
322-
}
323-
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
316+
value = gjson.GetBytes(body, attribute.Value).Value()
324317
case ResponseHeader:
325-
if value, err = proxywasm.GetHttpResponseHeader(attribute.Value); err == nil {
326-
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
327-
}
318+
value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
328319
case ResponseStreamingBody:
329320
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
330-
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
331321
case ResponseBody:
332-
value = gjson.GetBytes(body, attribute.Value).String()
333-
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
322+
value = gjson.GetBytes(body, attribute.Value).Value()
334323
default:
335324
}
325+
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
336326
if attribute.ApplyToLog {
337327
ctx.SetUserAttribute(key, value)
338328
}
329+
// for metrics
330+
if key == Model || key == InputToken || key == OutputToken {
331+
ctx.SetContext(key, value)
332+
}
339333
if attribute.ApplyToSpan {
340334
setSpanAttribute(key, value, log)
341335
}
342336
}
343337
}
344338
}
345339

346-
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string {
340+
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
347341
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
348-
var value string
342+
var value interface{}
349343
if rule == RuleFirst {
350344
for _, chunk := range chunks {
351345
jsonObj := gjson.GetBytes(chunk, jsonPath)
352346
if jsonObj.Exists() {
353-
value = jsonObj.String()
347+
value = jsonObj.Value()
354348
break
355349
}
356350
}
357351
} else if rule == RuleReplace {
358352
for _, chunk := range chunks {
359353
jsonObj := gjson.GetBytes(chunk, jsonPath)
360354
if jsonObj.Exists() {
361-
value = jsonObj.String()
355+
value = jsonObj.Value()
362356
}
363357
}
364358
} else if rule == RuleAppend {
365359
// extract llm response
360+
var strValue string
366361
for _, chunk := range chunks {
367362
jsonObj := gjson.GetBytes(chunk, jsonPath)
368363
if jsonObj.Exists() {
369-
value += jsonObj.String()
364+
strValue += jsonObj.String()
370365
}
371366
}
367+
value = strValue
372368
} else {
373369
log.Errorf("unsupported rule type: %s", rule)
374370
}
375371
return value
376372
}
377373

378374
// Set the tracing span with value.
379-
func setSpanAttribute(key, value string, log wrapper.Log) {
375+
func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
380376
if value != "" {
381377
traceSpanTag := wrapper.TraceSpanTagPrefix + key
382-
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil {
378+
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
383379
log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
384380
}
385381
} else {
@@ -388,36 +384,84 @@ func setSpanAttribute(key, value string, log wrapper.Log) {
388384
}
389385

390386
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
391-
route := ctx.GetContext(RouteName).(string)
392-
cluster := ctx.GetContext(ClusterName).(string)
393387
// Generate usage metrics
394-
var model string
395-
var inputToken, outputToken int64
388+
var ok bool
389+
var route, cluster, model string
390+
var inputToken, outputToken uint64
391+
route, ok = ctx.GetContext(RouteName).(string)
392+
if !ok {
393+
log.Warnf("RouteName typd assert failed, skip metric record")
394+
return
395+
}
396+
cluster, ok = ctx.GetContext(ClusterName).(string)
397+
if !ok {
398+
log.Warnf("ClusterName typd assert failed, skip metric record")
399+
return
400+
}
396401
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
397402
log.Warnf("get usage information failed, skip metric record")
398403
return
399404
}
400-
model = ctx.GetUserAttribute(Model).(string)
401-
inputToken = ctx.GetUserAttribute(InputToken).(int64)
402-
outputToken = ctx.GetUserAttribute(OutputToken).(int64)
405+
model, ok = ctx.GetUserAttribute(Model).(string)
406+
if !ok {
407+
log.Warnf("Model typd assert failed, skip metric record")
408+
return
409+
}
410+
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
411+
if !ok {
412+
log.Warnf("InputToken typd assert failed, skip metric record")
413+
return
414+
}
415+
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
416+
if !ok {
417+
log.Warnf("OutputToken typd assert failed, skip metric record")
418+
return
419+
}
403420
if inputToken == 0 || outputToken == 0 {
404421
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
405422
return
406423
}
407-
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), uint64(inputToken))
408-
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), uint64(outputToken))
424+
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
425+
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
409426

410427
// Generate duration metrics
411-
var llmFirstTokenDuration, llmServiceDuration int64
428+
var llmFirstTokenDuration, llmServiceDuration uint64
412429
// Is stream response
413430
if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
414-
llmFirstTokenDuration = ctx.GetUserAttribute(LLMFirstTokenDuration).(int64)
415-
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), uint64(llmFirstTokenDuration))
431+
llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
432+
if !ok {
433+
log.Warnf("LLMFirstTokenDuration typd assert failed")
434+
return
435+
}
436+
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
416437
config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
417438
}
418439
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
419-
llmServiceDuration = ctx.GetUserAttribute(LLMServiceDuration).(int64)
420-
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), uint64(llmServiceDuration))
440+
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
441+
if !ok {
442+
log.Warnf("LLMServiceDuration typd assert failed")
443+
return
444+
}
445+
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
421446
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
422447
}
423448
}
449+
450+
// convertToUInt converts a dynamically typed numeric value to uint64.
//
// It accepts float32, float64, int, int32, int64, uint, uint32 and uint64.
// Floats are truncated toward zero. The second result is false, and the first
// is 0, when val is not one of those types or cannot be represented as a
// uint64: negative numbers, NaN, infinities, and floats >= 2^64. Rejecting
// those matters because Go leaves out-of-range float-to-integer conversions
// implementation-defined and silently wraps negative integers.
func convertToUInt(val interface{}) (uint64, bool) {
	switch v := val.(type) {
	case float32:
		return convertToUInt(float64(v))
	case float64:
		// v != v is true only for NaN; 1<<64 is the first value above the
		// uint64 range (it also rejects +Inf).
		if v != v || v < 0 || v >= 1<<64 {
			return 0, false
		}
		return uint64(v), true
	case int:
		return convertToUInt(int64(v))
	case int32:
		return convertToUInt(int64(v))
	case int64:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case uint:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case uint64:
		return v, true
	default:
		return 0, false
	}
}

0 commit comments

Comments
 (0)