@@ -303,83 +303,79 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
303
303
// fetches the tracing span value from the specified source.
304
304
func setAttributeBySource (ctx wrapper.HttpContext , config AIStatisticsConfig , source string , body []byte , log wrapper.Log ) {
305
305
for _ , attribute := range config .attributes {
306
- var key , value string
307
- var err error
306
+ var key string
307
+ var value interface {}
308
308
if source == attribute .ValueSource {
309
309
key = attribute .Key
310
310
switch source {
311
311
case FixedValue :
312
- log .Debugf ("[attribute] source type: %s, key: %s, value: %s" , source , attribute .Key , attribute .Value )
313
312
value = attribute .Value
314
313
case RequestHeader :
315
- if value , err = proxywasm .GetHttpRequestHeader (attribute .Value ); err == nil {
316
- log .Debugf ("[attribute] source type: %s, key: %s, value: %s" , source , attribute .Key , value )
317
- }
314
+ value , _ = proxywasm .GetHttpRequestHeader (attribute .Value )
318
315
case RequestBody :
319
- raw := gjson .GetBytes (body , attribute .Value ).Raw
320
- if len (raw ) > 2 {
321
- value = raw [1 : len (raw )- 1 ]
322
- }
323
- log .Debugf ("[attribute] source type: %s, key: %s, value: %s" , source , attribute .Key , value )
316
+ value = gjson .GetBytes (body , attribute .Value ).Value ()
324
317
case ResponseHeader :
325
- if value , err = proxywasm .GetHttpResponseHeader (attribute .Value ); err == nil {
326
- log .Debugf ("[log attribute] source type: %s, key: %s, value: %s" , source , attribute .Key , value )
327
- }
318
+ value , _ = proxywasm .GetHttpResponseHeader (attribute .Value )
328
319
case ResponseStreamingBody :
329
320
value = extractStreamingBodyByJsonPath (body , attribute .Value , attribute .Rule , log )
330
- log .Debugf ("[log attribute] source type: %s, key: %s, value: %s" , source , attribute .Key , value )
331
321
case ResponseBody :
332
- value = gjson .GetBytes (body , attribute .Value ).String ()
333
- log .Debugf ("[log attribute] source type: %s, key: %s, value: %s" , source , attribute .Key , value )
322
+ value = gjson .GetBytes (body , attribute .Value ).Value ()
334
323
default :
335
324
}
325
+ log .Debugf ("[attribute] source type: %s, key: %s, value: %+v" , source , key , value )
336
326
if attribute .ApplyToLog {
337
327
ctx .SetUserAttribute (key , value )
338
328
}
329
+ // for metrics
330
+ if key == Model || key == InputToken || key == OutputToken {
331
+ ctx .SetContext (key , value )
332
+ }
339
333
if attribute .ApplyToSpan {
340
334
setSpanAttribute (key , value , log )
341
335
}
342
336
}
343
337
}
344
338
}
345
339
346
- func extractStreamingBodyByJsonPath (data []byte , jsonPath string , rule string , log wrapper.Log ) string {
340
+ func extractStreamingBodyByJsonPath (data []byte , jsonPath string , rule string , log wrapper.Log ) interface {} {
347
341
chunks := bytes .Split (bytes .TrimSpace (data ), []byte ("\n \n " ))
348
- var value string
342
+ var value interface {}
349
343
if rule == RuleFirst {
350
344
for _ , chunk := range chunks {
351
345
jsonObj := gjson .GetBytes (chunk , jsonPath )
352
346
if jsonObj .Exists () {
353
- value = jsonObj .String ()
347
+ value = jsonObj .Value ()
354
348
break
355
349
}
356
350
}
357
351
} else if rule == RuleReplace {
358
352
for _ , chunk := range chunks {
359
353
jsonObj := gjson .GetBytes (chunk , jsonPath )
360
354
if jsonObj .Exists () {
361
- value = jsonObj .String ()
355
+ value = jsonObj .Value ()
362
356
}
363
357
}
364
358
} else if rule == RuleAppend {
365
359
// extract llm response
360
+ var strValue string
366
361
for _ , chunk := range chunks {
367
362
jsonObj := gjson .GetBytes (chunk , jsonPath )
368
363
if jsonObj .Exists () {
369
- value += jsonObj .String ()
364
+ strValue += jsonObj .String ()
370
365
}
371
366
}
367
+ value = strValue
372
368
} else {
373
369
log .Errorf ("unsupported rule type: %s" , rule )
374
370
}
375
371
return value
376
372
}
377
373
378
374
// Set the tracing span with value.
379
- func setSpanAttribute (key , value string , log wrapper.Log ) {
375
+ func setSpanAttribute (key string , value interface {} , log wrapper.Log ) {
380
376
if value != "" {
381
377
traceSpanTag := wrapper .TraceSpanTagPrefix + key
382
- if e := proxywasm .SetProperty ([]string {traceSpanTag }, []byte (value )); e != nil {
378
+ if e := proxywasm .SetProperty ([]string {traceSpanTag }, []byte (fmt . Sprint ( value ) )); e != nil {
383
379
log .Warnf ("failed to set %s in filter state: %v" , traceSpanTag , e )
384
380
}
385
381
} else {
@@ -388,36 +384,84 @@ func setSpanAttribute(key, value string, log wrapper.Log) {
388
384
}
389
385
390
386
func writeMetric (ctx wrapper.HttpContext , config AIStatisticsConfig , log wrapper.Log ) {
391
- route := ctx .GetContext (RouteName ).(string )
392
- cluster := ctx .GetContext (ClusterName ).(string )
393
387
// Generate usage metrics
394
- var model string
395
- var inputToken , outputToken int64
388
+ var ok bool
389
+ var route , cluster , model string
390
+ var inputToken , outputToken uint64
391
+ route , ok = ctx .GetContext (RouteName ).(string )
392
+ if ! ok {
393
+ log .Warnf ("RouteName typd assert failed, skip metric record" )
394
+ return
395
+ }
396
+ cluster , ok = ctx .GetContext (ClusterName ).(string )
397
+ if ! ok {
398
+ log .Warnf ("ClusterName typd assert failed, skip metric record" )
399
+ return
400
+ }
396
401
if ctx .GetUserAttribute (Model ) == nil || ctx .GetUserAttribute (InputToken ) == nil || ctx .GetUserAttribute (OutputToken ) == nil {
397
402
log .Warnf ("get usage information failed, skip metric record" )
398
403
return
399
404
}
400
- model = ctx .GetUserAttribute (Model ).(string )
401
- inputToken = ctx .GetUserAttribute (InputToken ).(int64 )
402
- outputToken = ctx .GetUserAttribute (OutputToken ).(int64 )
405
+ model , ok = ctx .GetUserAttribute (Model ).(string )
406
+ if ! ok {
407
+ log .Warnf ("Model typd assert failed, skip metric record" )
408
+ return
409
+ }
410
+ inputToken , ok = convertToUInt (ctx .GetUserAttribute (InputToken ))
411
+ if ! ok {
412
+ log .Warnf ("InputToken typd assert failed, skip metric record" )
413
+ return
414
+ }
415
+ outputToken , ok = convertToUInt (ctx .GetUserAttribute (OutputToken ))
416
+ if ! ok {
417
+ log .Warnf ("OutputToken typd assert failed, skip metric record" )
418
+ return
419
+ }
403
420
if inputToken == 0 || outputToken == 0 {
404
421
log .Warnf ("inputToken and outputToken cannot equal to 0, skip metric record" )
405
422
return
406
423
}
407
- config .incrementCounter (generateMetricName (route , cluster , model , InputToken ), uint64 ( inputToken ) )
408
- config .incrementCounter (generateMetricName (route , cluster , model , OutputToken ), uint64 ( outputToken ) )
424
+ config .incrementCounter (generateMetricName (route , cluster , model , InputToken ), inputToken )
425
+ config .incrementCounter (generateMetricName (route , cluster , model , OutputToken ), outputToken )
409
426
410
427
// Generate duration metrics
411
- var llmFirstTokenDuration , llmServiceDuration int64
428
+ var llmFirstTokenDuration , llmServiceDuration uint64
412
429
// Is stream response
413
430
if ctx .GetUserAttribute (LLMFirstTokenDuration ) != nil {
414
- llmFirstTokenDuration = ctx .GetUserAttribute (LLMFirstTokenDuration ).(int64 )
415
- config .incrementCounter (generateMetricName (route , cluster , model , LLMFirstTokenDuration ), uint64 (llmFirstTokenDuration ))
431
+ llmFirstTokenDuration , ok = convertToUInt (ctx .GetUserAttribute (LLMFirstTokenDuration ))
432
+ if ! ok {
433
+ log .Warnf ("LLMFirstTokenDuration typd assert failed" )
434
+ return
435
+ }
436
+ config .incrementCounter (generateMetricName (route , cluster , model , LLMFirstTokenDuration ), llmFirstTokenDuration )
416
437
config .incrementCounter (generateMetricName (route , cluster , model , LLMStreamDurationCount ), 1 )
417
438
}
418
439
if ctx .GetUserAttribute (LLMServiceDuration ) != nil {
419
- llmServiceDuration = ctx .GetUserAttribute (LLMServiceDuration ).(int64 )
420
- config .incrementCounter (generateMetricName (route , cluster , model , LLMServiceDuration ), uint64 (llmServiceDuration ))
440
+ llmServiceDuration , ok = convertToUInt (ctx .GetUserAttribute (LLMServiceDuration ))
441
+ if ! ok {
442
+ log .Warnf ("LLMServiceDuration typd assert failed" )
443
+ return
444
+ }
445
+ config .incrementCounter (generateMetricName (route , cluster , model , LLMServiceDuration ), llmServiceDuration )
421
446
config .incrementCounter (generateMetricName (route , cluster , model , LLMDurationCount ), 1 )
422
447
}
423
448
}
449
+
450
+ func convertToUInt (val interface {}) (uint64 , bool ) {
451
+ switch v := val .(type ) {
452
+ case float32 :
453
+ return uint64 (v ), true
454
+ case float64 :
455
+ return uint64 (v ), true
456
+ case int32 :
457
+ return uint64 (v ), true
458
+ case int64 :
459
+ return uint64 (v ), true
460
+ case uint32 :
461
+ return uint64 (v ), true
462
+ case uint64 :
463
+ return v , true
464
+ default :
465
+ return 0 , false
466
+ }
467
+ }
0 commit comments