Skip to content

Commit bbc4468

Browse files
committed
Make functions more compatible with OpenAI specs
1 parent 4de7f55 commit bbc4468

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

api/openai.go

+15-10
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ type Message struct {
7777
// The message role
7878
Role string `json:"role,omitempty" yaml:"role"`
7979
// The message content
80-
Content string `json:"content,omitempty" yaml:"content"`
80+
Content *string `json:"content" yaml:"content"`
8181
// A result of a function call
8282
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
8383
}
@@ -392,7 +392,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
392392
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
393393
resp := OpenAIResponse{
394394
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
395-
Choices: []Choice{{Delta: &Message{Content: s}, Index: 0}},
395+
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
396396
Object: "chat.completion.chunk",
397397
}
398398
log.Debug().Msgf("Sending goroutine: %s", s)
@@ -460,24 +460,29 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
460460
}
461461
}
462462
r := config.Roles[role]
463+
contentExists := i.Content != nil && *i.Content != ""
463464
if r != "" {
464-
content = fmt.Sprint(r, " ", i.Content)
465+
if contentExists {
466+
content = fmt.Sprint(r, " ", *i.Content)
467+
}
465468
if i.FunctionCall != nil {
466469
j, err := json.Marshal(i.FunctionCall)
467470
if err == nil {
468-
if i.Content != "" {
471+
if contentExists {
469472
content += "\n" + fmt.Sprint(r, " ", string(j))
470473
} else {
471474
content = fmt.Sprint(r, " ", string(j))
472475
}
473476
}
474477
}
475478
} else {
476-
content = i.Content
479+
if contentExists {
480+
content = fmt.Sprint(*i.Content)
481+
}
477482
if i.FunctionCall != nil {
478483
j, err := json.Marshal(i.FunctionCall)
479484
if err == nil {
480-
if i.Content != "" {
485+
if contentExists {
481486
content += "\n" + string(j)
482487
} else {
483488
content = string(j)
@@ -600,7 +605,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
600605
message = Finetune(*config, predInput, message)
601606
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
602607

603-
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: message}})
608+
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
604609
return
605610
}
606611
}
@@ -623,18 +628,18 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
623628
}
624629

625630
prediction = Finetune(*config, predInput, prediction)
626-
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: prediction}})
631+
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
627632
} else {
628633
// otherwise reply with the function call
629634
*c = append(*c, Choice{
630635
FinishReason: "function_call",
631-
Message: &Message{Role: "function", FunctionCall: ss},
636+
Message: &Message{Role: "assistant", FunctionCall: ss},
632637
})
633638
}
634639

635640
return
636641
}
637-
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
642+
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}})
638643
}, nil)
639644
if err != nil {
640645
return err

0 commit comments

Comments
 (0)