enhance: avoid context limit #832

Changes from 3 commits
```diff
@@ -2,8 +2,10 @@ package openai

 import (
 	"context"
+	"errors"
 	"io"
 	"log/slog"
+	"math"
 	"os"
 	"slices"
 	"sort"
```
```diff
@@ -24,6 +26,7 @@ import (
 const (
 	DefaultModel    = openai.GPT4o
 	BuiltinCredName = "sys.openai"
+	TooLongMessage  = "Error: tool call output is too long"
 )

 var (
```
```diff
@@ -317,6 +320,14 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	}

 	if messageRequest.Chat {
+		// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
+		lastMessage := msgs[len(msgs)-1]
+		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(math.Round(float64(getBudget(messageRequest.MaxTokens))*0.8)) {
+			// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
+			msgs[len(msgs)-1].Content = TooLongMessage
+			messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
+		}
+
 		msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
 	}
```
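For a sense of scale (an illustration, not part of the diff): with `MaxTokens` unset, `getBudget` falls back to `DefaultMaxTokens` of 128,000 (defined in the second file below), so the 80% cutoff for a single tool message works out to 102,400 tokens:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// With MaxTokens unset, the budget defaults to 128,000 tokens, and a
	// lone tool-call result may occupy at most 80% of it before its
	// content is replaced with TooLongMessage.
	budget := 128_000
	threshold := int(math.Round(float64(budget) * 0.8))
	fmt.Println(threshold) // 102400
}
```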
```diff
@@ -383,6 +394,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		return nil, err
 	} else if !ok {
 		response, err = c.call(ctx, request, id, status)

+		// If we got back a context length exceeded error, keep retrying and
+		// shrinking the message history until we pass.
+		var apiError *openai.APIError
+		if err != nil && errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
+			// Decrease maxTokens by 10% to make garbage collection more aggressive.
+			// The retry loop will further decrease maxTokens if needed.
+			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
+			response, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
+		}
+
 		if err != nil {
			return nil, err
		}
```

**Review comment:** nit: probably don't need to `math.Round` here. Not much of a difference between 102,399 and 102,400. (A suggested change was attached.)

**Reply:** Fixed
```diff
@@ -421,6 +442,32 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }

+func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) ([]openai.ChatCompletionStreamResponse, error) {
+	var (
+		response []openai.ChatCompletionStreamResponse
+		err      error
+	)
+
+	for range 10 { // maximum 10 tries
+		// Try to drop older messages again, with a decreased max tokens.
+		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
+		response, err = c.call(ctx, request, id, status)
+		if err == nil {
+			return response, nil
+		}
+
+		var apiError *openai.APIError
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" {
+			// Decrease maxTokens and try again
+			maxTokens = decreaseTenPercent(maxTokens)
+			continue
+		}
+		return nil, err
+	}
+
+	return nil, err
+}
+
 func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionStreamResponse) types.CompletionMessage {
 	msg.Usage.CompletionTokens = types.FirstSet(msg.Usage.CompletionTokens, response.Usage.CompletionTokens)
 	msg.Usage.PromptTokens = types.FirstSet(msg.Usage.PromptTokens, response.Usage.PromptTokens)
```

**Review comment** (on the `for range 10` line): Is this the first use in our code base?

**Reply:** I think so!
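Two notes on this hunk. `for range 10` is Go 1.22's range-over-int form and simply runs the body ten times, which is what the review question above refers to. And since each failed attempt multiplies the budget by 0.9, ten retries shrink it to roughly 35% of its starting value. A standalone sketch, copying `decreaseTenPercent` from this diff (minus its `getBudget` fallback) so it compiles on its own:

```go
package main

import (
	"fmt"
	"math"
)

// Local copy of decreaseTenPercent from this diff, without the
// getBudget fallback, so the snippet is self-contained.
func decreaseTenPercent(maxTokens int) int {
	return int(math.Round(float64(maxTokens) * 0.9))
}

func main() {
	budget := 128_000
	// Go 1.22+ range-over-int: runs the body exactly 10 times,
	// mirroring the retry loop's upper bound.
	for range 10 {
		budget = decreaseTenPercent(budget)
	}
	fmt.Println(budget) // 44632, roughly 35% of the starting budget
}
```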
```diff
@@ -1,20 +1,32 @@
 package openai

-import openai "github.com/gptscript-ai/chat-completion-client"
+import (
+	"math"
+
+	openai "github.com/gptscript-ai/chat-completion-client"
+)
+
+const DefaultMaxTokens = 128_000
+
+func decreaseTenPercent(maxTokens int) int {
+	maxTokens = getBudget(maxTokens)
+	return int(math.Round(float64(maxTokens) * 0.9))
+}
+
+func getBudget(maxTokens int) int {
+	if maxTokens == 0 {
+		return DefaultMaxTokens
+	}
+	return maxTokens
+}

 func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
 	var (
 		lastSystem   int
 		withinBudget int
-		budget       = maxTokens
+		budget       = getBudget(maxTokens)
 	)

-	if maxTokens == 0 {
-		budget = 300_000
-	} else {
-		budget *= 3
-	}
-
 	for i, msg := range msgs {
 		if msg.Role == openai.ChatMessageRoleSystem {
 			budget -= countMessage(msg)
```

**Review comment** (on `decreaseTenPercent`): same nit about `math.Round`.

**Reply:** Fixed
```diff
@@ -33,6 +45,14 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 		}
 	}

+	// OpenAI gets upset if there is a tool message without a tool call preceding it.
+	// Check the oldest message within budget, and if it is a tool message, just drop it.
+	// We do this in a loop because it is possible for multiple tool messages to be in a row,
+	// due to parallel tool calls.
+	for withinBudget < len(msgs) && msgs[withinBudget].Role == openai.ChatMessageRoleTool {
+		withinBudget++
+	}
+
 	if withinBudget == len(msgs)-1 {
 		// We are going to drop all non system messages, which seems useless, so just return them
 		// all and let it fail
```
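The forward scan above matters because the API rejects a tool message whose originating tool call has been trimmed away. A minimal illustration of the same scan over a hypothetical slice of message roles (not code from this PR):

```go
package main

import "fmt"

func main() {
	// Hypothetical roles after budget trimming: the two leading "tool"
	// results lost their preceding assistant tool call, so the scan
	// advances past them. Consecutive tool messages can occur because
	// of parallel tool calls.
	roles := []string{"tool", "tool", "assistant", "tool", "user"}
	withinBudget := 0
	for withinBudget < len(roles) && roles[withinBudget] == "tool" {
		withinBudget++
	}
	fmt.Println(roles[withinBudget:]) // [assistant tool user]
}
```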