Skip to content

Commit e948150

Browse files
authored
fix: chat stream returns an error response with a 'data: ' prefix (#396)
* fix: chat stream resp has 'data: ' prefix * fix: lint error * fix: lint error * fix: lint error
1 parent 7203770 commit e948150

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

chat_stream_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,45 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
178178
t.Logf("%+v\n", apiErr)
179179
}
180180

181+
func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
182+
client, server, teardown := setupOpenAITestServer()
183+
defer teardown()
184+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
185+
w.Header().Set("Content-Type", "text/event-stream")
186+
187+
// Send test responses
188+
//nolint:lll
189+
dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
190+
dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
191+
192+
_, err := w.Write(dataBytes)
193+
checks.NoError(t, err, "Write error")
194+
})
195+
196+
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
197+
MaxTokens: 5,
198+
Model: GPT3Dot5Turbo,
199+
Messages: []ChatCompletionMessage{
200+
{
201+
Role: ChatMessageRoleUser,
202+
Content: "Hello!",
203+
},
204+
},
205+
Stream: true,
206+
})
207+
checks.NoError(t, err, "CreateCompletionStream returned error")
208+
defer stream.Close()
209+
210+
_, streamErr := stream.Recv()
211+
checks.HasError(t, streamErr, "stream.Recv() did not return error")
212+
213+
var apiErr *APIError
214+
if !errors.As(streamErr, &apiErr) {
215+
t.Errorf("stream.Recv() did not return APIError")
216+
}
217+
t.Logf("%+v\n", apiErr)
218+
}
219+
181220
func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
182221
client, server, teardown := setupOpenAITestServer()
183222
defer teardown()

stream_reader.go

+18-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ import (
1010
utils "github.com/sashabaranov/go-openai/internal"
1111
)
1212

13+
var (
14+
headerData = []byte("data: ")
15+
errorPrefix = []byte(`data: {"error":`)
16+
)
17+
1318
type streamable interface {
1419
ChatCompletionStreamResponse | CompletionResponse
1520
}
@@ -34,22 +39,31 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
3439
return
3540
}
3641

42+
//nolint:gocognit
3743
func (stream *streamReader[T]) processLines() (T, error) {
38-
var emptyMessagesCount uint
44+
var (
45+
emptyMessagesCount uint
46+
hasErrorPrefix bool
47+
)
3948

4049
for {
4150
rawLine, readErr := stream.reader.ReadBytes('\n')
42-
if readErr != nil {
51+
if readErr != nil || hasErrorPrefix {
4352
respErr := stream.unmarshalError()
4453
if respErr != nil {
4554
return *new(T), fmt.Errorf("error, %w", respErr.Error)
4655
}
4756
return *new(T), readErr
4857
}
4958

50-
var headerData = []byte("data: ")
5159
noSpaceLine := bytes.TrimSpace(rawLine)
52-
if !bytes.HasPrefix(noSpaceLine, headerData) {
60+
if bytes.HasPrefix(noSpaceLine, errorPrefix) {
61+
hasErrorPrefix = true
62+
}
63+
if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
64+
if hasErrorPrefix {
65+
noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
66+
}
5367
writeErr := stream.errAccumulator.Write(noSpaceLine)
5468
if writeErr != nil {
5569
return *new(T), writeErr

0 commit comments

Comments
 (0)