15
15
package main
16
16
17
17
import (
18
+ "bytes"
18
19
_ "embed"
19
20
"errors"
20
21
"fmt"
21
22
"net/http"
22
23
"strings"
23
24
"time"
25
+ "unicode"
24
26
25
27
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
26
28
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -492,8 +494,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
492
494
if references == "" {
493
495
return types .ActionContinue
494
496
}
495
- content := gjson .GetBytes (body , "choices.0.message.content" )
496
- modifiedContent := fmt .Sprintf ("%s\n \n %s" , fmt .Sprintf (config .referenceFormat , references ), content )
497
+ content := gjson .GetBytes (body , "choices.0.message.content" ).String ()
498
+ var modifiedContent string
499
+ if strings .HasPrefix (strings .TrimLeftFunc (content , unicode .IsSpace ), "<think>" ) {
500
+ thinkEnd := strings .Index (content , "</think>" )
501
+ if thinkEnd != - 1 {
502
+ modifiedContent = content [:thinkEnd + 8 ] +
503
+ fmt .Sprintf ("\n %s\n \n %s" , fmt .Sprintf (config .referenceFormat , references ), content [thinkEnd + 8 :])
504
+ }
505
+ }
506
+ if modifiedContent == "" {
507
+ modifiedContent = fmt .Sprintf ("%s\n \n %s" , fmt .Sprintf (config .referenceFormat , references ), content )
508
+ }
497
509
body , err := sjson .SetBytes (body , "choices.0.message.content" , modifiedContent )
498
510
if err != nil {
499
511
log .Errorf ("modify response message content failed, err:%v, body:%s" , err , body )
@@ -503,6 +515,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
503
515
return types .ActionContinue
504
516
}
505
517
518
// unifySSEChunk normalizes every line terminator in an SSE chunk to "\n".
//
// CRLF pairs are rewritten first so that the following pass, which converts
// any remaining lone CR, does not turn a single CRLF into two newlines.
// The streaming handler in this file runs incoming chunks through this
// function before splitting them into events on "\n\n".
//
// NOTE(review): normalization is per chunk. If an upstream splits a CRLF pair
// across two chunks, the trailing "\r" and the leading "\n" are each turned
// into "\n" independently — confirm whether that can occur in practice.
func unifySSEChunk(data []byte) []byte {
	data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
	data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
	return data
}
523
+
524
// Per-request context keys and tuning values used by the streaming response
// handlers. The identifiers keep their existing spelling because other code
// in this file refers to them by name.
const (
	// PARTIAL_MESSAGE_CONTEXT_KEY is the context key under which the trailing,
	// not yet terminated SSE message of a chunk is kept (as []byte) until the
	// next chunk arrives.
	PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
	// BUFFER_CONTENT_CONTEXT_KEY is the context key under which delta content
	// that has been withheld from the client is accumulated (as string).
	BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
	// BUFFER_SIZE is the minimum length, in bytes, the accumulated delta
	// content must reach before it is inspected and forwarded.
	BUFFER_SIZE = 30
)
529
+
506
530
func onStreamingResponseBody (ctx wrapper.HttpContext , config Config , chunk []byte , isLastChunk bool , log wrapper.Log ) []byte {
507
531
if ctx .GetBoolContext ("ReferenceAppended" , false ) {
508
532
return chunk
@@ -511,58 +535,110 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
511
535
if references == "" {
512
536
return chunk
513
537
}
514
- modifiedChunk , responseReady := setReferencesToFirstMessage (ctx , chunk , fmt .Sprintf (config .referenceFormat , references ), log )
515
- if responseReady {
516
- ctx .SetContext ("ReferenceAppended" , true )
517
- return modifiedChunk
518
- } else {
519
- return []byte ("" )
520
- }
521
- }
522
-
523
- const PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
524
-
525
- func setReferencesToFirstMessage (ctx wrapper.HttpContext , chunk []byte , references string , log wrapper.Log ) ([]byte , bool ) {
526
- if len (chunk ) == 0 {
527
- log .Debugf ("chunk is empty" )
528
- return nil , false
529
- }
530
-
538
+ chunk = unifySSEChunk (chunk )
531
539
var partialMessage []byte
532
540
partialMessageI := ctx .GetContext (PARTIAL_MESSAGE_CONTEXT_KEY )
541
+ log .Debugf ("[handleStreamChunk] buffer content: %v" , ctx .GetContext (BUFFER_CONTENT_CONTEXT_KEY ))
533
542
if partialMessageI != nil {
534
- if pMsg , ok := partialMessageI .([]byte ); ok {
535
- partialMessage = append (pMsg , chunk ... )
536
- } else {
537
- log .Warnf ("invalid partial message type: %T" , partialMessageI )
538
- partialMessage = chunk
539
- }
543
+ partialMessage = append (partialMessageI .([]byte ), chunk ... )
540
544
} else {
541
545
partialMessage = chunk
542
546
}
547
+ messages := strings .Split (string (partialMessage ), "\n \n " )
548
+ var newMessages []string
549
+ for i , msg := range messages {
550
+ if i < len (messages )- 1 {
551
+ newMsg := processSSEMessage (ctx , msg , fmt .Sprintf (config .referenceFormat , references ), log )
552
+ if newMsg != "" {
553
+ newMessages = append (newMessages , newMsg )
554
+ }
555
+ }
556
+ }
557
+ if ! strings .HasSuffix (string (partialMessage ), "\n \n " ) {
558
+ ctx .SetContext (PARTIAL_MESSAGE_CONTEXT_KEY , []byte (messages [len (messages )- 1 ]))
559
+ } else {
560
+ ctx .SetContext (PARTIAL_MESSAGE_CONTEXT_KEY , nil )
561
+ }
562
+ if len (newMessages ) == 1 {
563
+ return []byte (fmt .Sprintf ("%s\n \n " , newMessages [0 ]))
564
+ } else if len (newMessages ) > 1 {
565
+ return []byte (strings .Join (newMessages , "\n \n " ))
566
+ } else {
567
+ return []byte ("" )
568
+ }
569
+ }
543
570
544
- if len (partialMessage ) == 0 {
545
- log .Debugf ("partial message is empty" )
546
- return nil , false
571
+ func processSSEMessage (ctx wrapper.HttpContext , sseMessage string , references string , log wrapper.Log ) string {
572
+ log .Debugf ("single sse message: %s" , sseMessage )
573
+ subMessages := strings .Split (sseMessage , "\n " )
574
+ var message string
575
+ for _ , msg := range subMessages {
576
+ if strings .HasPrefix (msg , "data:" ) {
577
+ message = msg
578
+ break
579
+ }
547
580
}
548
- messages := strings .Split (string (partialMessage ), "\n \n " )
549
- if len (messages ) > 1 {
550
- firstMessage := messages [0 ]
551
- log .Debugf ("first message: %s" , firstMessage )
552
- firstMessage = strings .TrimPrefix (firstMessage , "data:" )
553
- firstMessage = strings .TrimPrefix (firstMessage , " " )
554
- firstMessage = strings .TrimSuffix (firstMessage , "\n " )
555
- deltaContent := gjson .Get (firstMessage , "choices.0.delta.content" )
556
- modifiedMessage , err := sjson .Set (firstMessage , "choices.0.delta.content" , fmt .Sprintf ("%s\n \n %s" , references , deltaContent ))
581
+ if len (message ) < 6 {
582
+ log .Errorf ("[processSSEMessage] invalid message: %s" , message )
583
+ return sseMessage
584
+ }
585
+ // Skip the prefix "data:"
586
+ bodyJson := message [5 :]
587
+ if strings .TrimSpace (bodyJson ) == "[DONE]" {
588
+ return sseMessage
589
+ }
590
+ bodyJson = strings .TrimPrefix (bodyJson , " " )
591
+ bodyJson = strings .TrimSuffix (bodyJson , "\n " )
592
+ deltaContent := gjson .Get (bodyJson , "choices.0.delta.content" ).String ()
593
+ // Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
594
+ if deltaContent == "" {
595
+ return sseMessage
596
+ }
597
+ bufferContent := ctx .GetStringContext (BUFFER_CONTENT_CONTEXT_KEY , "" ) + deltaContent
598
+ if len (bufferContent ) < BUFFER_SIZE {
599
+ ctx .SetContext (BUFFER_CONTENT_CONTEXT_KEY , bufferContent )
600
+ return ""
601
+ }
602
+ if ! ctx .GetBoolContext ("FirstMessageChecked" , false ) {
603
+ ctx .SetContext ("FirstMessageChecked" , true )
604
+ if ! strings .Contains (strings .TrimLeftFunc (bufferContent , unicode .IsSpace ), "<think>" ) {
605
+ modifiedMessage , err := sjson .Set (bodyJson , "choices.0.delta.content" , fmt .Sprintf ("%s\n \n %s" , references , bufferContent ))
606
+ if err != nil {
607
+ log .Errorf ("update messsage failed:%s" , err )
608
+ }
609
+ ctx .SetContext ("ReferenceAppended" , true )
610
+ return fmt .Sprintf ("data: %s" , modifiedMessage )
611
+ }
612
+ }
613
+ // Content has <think> prefix
614
+ // Check for complete </think> tag
615
+ thinkEnd := strings .Index (bufferContent , "</think>" )
616
+ if thinkEnd != - 1 {
617
+ modifiedContent := bufferContent [:thinkEnd + 8 ] +
618
+ fmt .Sprintf ("\n %s\n \n %s" , references , bufferContent [thinkEnd + 8 :])
619
+ modifiedMessage , err := sjson .Set (bodyJson , "choices.0.delta.content" , modifiedContent )
557
620
if err != nil {
558
- log .Errorf ("modify response delta content failed, err:%v" , err )
559
- return partialMessage , true
621
+ log .Errorf ("update messsage failed:%s" , err )
560
622
}
561
- modifiedMessage = fmt .Sprintf ("data: %s" , modifiedMessage )
562
- log .Debugf ("modified message: %s" , firstMessage )
563
- messages [0 ] = string (modifiedMessage )
564
- return []byte (strings .Join (messages , "\n \n " )), true
623
+ ctx .SetContext ("ReferenceAppended" , true )
624
+ return fmt .Sprintf ("data: %s" , modifiedMessage )
565
625
}
566
- ctx .SetContext (PARTIAL_MESSAGE_CONTEXT_KEY , partialMessage )
567
- return nil , false
626
+
627
+ // Check for partial </think> tag at end of buffer
628
+ // Look for any partial match that could be completed in next message
629
+ for i := 1 ; i < len ("</think>" ); i ++ {
630
+ if strings .HasSuffix (bufferContent , "</think>" [:i ]) {
631
+ // Store only the partial match for the next message
632
+ ctx .SetContext (BUFFER_CONTENT_CONTEXT_KEY , bufferContent [len (bufferContent )- i :])
633
+ // Return the content before the partial match
634
+ modifiedMessage , err := sjson .Set (bodyJson , "choices.0.delta.content" , bufferContent [:len (bufferContent )- i ])
635
+ if err != nil {
636
+ log .Errorf ("update messsage failed:%s" , err )
637
+ }
638
+ return fmt .Sprintf ("data: %s" , modifiedMessage )
639
+ }
640
+ }
641
+
642
+ ctx .SetContext (BUFFER_CONTENT_CONTEXT_KEY , "" )
643
+ return sseMessage
568
644
}
0 commit comments