@@ -2,6 +2,7 @@ package openai
 
 import (
 	"context"
+	"errors"
 	"io"
 	"log/slog"
 	"os"
@@ -24,6 +25,7 @@ import (
 const (
 	DefaultModel    = openai.GPT4o
 	BuiltinCredName = "sys.openai"
+	TooLongMessage  = "Error: tool call output is too long"
 )
 
 var (
@@ -317,6 +319,14 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	}
 
 	if messageRequest.Chat {
+		// Check the last message. If it is from a tool call, and it takes up more than 80% of the budget on its own, replace its content with a short error message.
+		lastMessage := msgs[len(msgs)-1]
+		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
+			// We need to update it in the msgs slice for the current call and in the messageRequest for future calls.
+			msgs[len(msgs)-1].Content = TooLongMessage
+			messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
+		}
+
 		msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
 	}
 
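The budget check above leans on getBudget and countMessage, which this diff does not touch. A minimal sketch of what they plausibly look like, assuming a package-level default budget and a rough 3-characters-per-token heuristic in place of a real tokenizer (the names and details here are illustrative, not the package's actual code):

// getBudget normalizes the caller's MaxTokens, falling back to an
// assumed package-level default when the caller left it unset.
func getBudget(maxTokens int) int {
	if maxTokens == 0 {
		return DefaultMaxTokens // assumed default budget
	}
	return maxTokens
}

// countMessage estimates a message's token count with a crude
// 3-characters-per-token heuristic rather than a real tokenizer.
func countMessage(msg openai.ChatCompletionMessage) int {
	count := len(msg.Role) + len(msg.Content)
	for _, part := range msg.MultiContent {
		count += len(part.Text)
	}
	return count / 3
}

Under that reading, a tool result consuming more than 80% of the budget would crowd out nearly all other history, so swapping it for the short TooLongMessage sentinel keeps the rest of the conversation intact.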
@@ -383,6 +393,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		return nil, err
 	} else if !ok {
 		response, err = c.call(ctx, request, id, status)
+
+		// If we got back a context-length-exceeded error, keep retrying and shrinking the message history until the request fits.
+		var apiError *openai.APIError
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
+			// Decrease maxTokens by 10% to make garbage collection more aggressive.
+			// The retry loop will further decrease maxTokens if needed.
+			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
+			response, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
+		}
+
 		if err != nil {
 			return nil, err
 		}
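decreaseTenPercent is also defined outside this diff. Given how it is used here, both on a possibly-unset messageRequest.MaxTokens and on an already-normalized budget inside the retry loop, a plausible sketch is (illustrative, assuming it funnels through the getBudget helper sketched earlier):

// decreaseTenPercent shrinks the token budget by 10%, normalizing an
// unset (zero) budget first so the retry loop always has a concrete
// number to work down from.
func decreaseTenPercent(maxTokens int) int {
	return int(float64(getBudget(maxTokens)) * 0.9)
}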
@@ -421,6 +441,32 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
+func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) ([]openai.ChatCompletionStreamResponse, error) {
+	var (
+		response []openai.ChatCompletionStreamResponse
+		err      error
+	)
+
+	for range 10 { // maximum 10 tries
+		// Try to drop older messages again, with the decreased token budget.
+		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
+		response, err = c.call(ctx, request, id, status)
+		if err == nil {
+			return response, nil
+		}
+
+		var apiError *openai.APIError
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" {
+			// Decrease maxTokens and try again.
+			maxTokens = decreaseTenPercent(maxTokens)
+			continue
+		}
+		return nil, err
+	}
+
+	return nil, err
+}
+
 func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionStreamResponse) types.CompletionMessage {
 	msg.Usage.CompletionTokens = types.FirstSet(msg.Usage.CompletionTokens, response.Usage.CompletionTokens)
 	msg.Usage.PromptTokens = types.FirstSet(msg.Usage.PromptTokens, response.Usage.PromptTokens)
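For context, the retry loop's actual shrinking happens in dropMessagesOverCount, which predates this change. A simplified sketch of the trimming strategy the loop relies on: keep a leading system prompt, then keep the newest messages that still fit the budget (illustrative only; the package's real version may differ in detail):

// dropMessagesOverCountSketch keeps a leading system prompt plus the
// newest messages whose estimated token counts fit within the budget.
func dropMessagesOverCountSketch(maxTokens int, msgs []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
	budget := getBudget(maxTokens)

	var head []openai.ChatCompletionMessage
	rest := msgs
	if len(msgs) > 0 && msgs[0].Role == "system" {
		head, rest = msgs[:1], msgs[1:]
		budget -= countMessage(msgs[0])
	}

	// Walk from newest to oldest, keeping messages until the budget runs out.
	keepFrom := len(rest)
	for i := len(rest) - 1; i >= 0; i-- {
		budget -= countMessage(rest[i])
		if budget < 0 {
			break
		}
		keepFrom = i
	}

	result := make([]openai.ChatCompletionMessage, 0, len(head)+len(rest)-keepFrom)
	result = append(result, head...)
	return append(result, rest[keepFrom:]...)
}

Since each retry multiplies the budget by 0.9, ten failed attempts bottom out at roughly 0.9^10 ≈ 35% of the original budget before the loop gives up and returns the last error.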