Skip to content

Commit d571a1b

Browse files
fix: fix early cancel when RequestTimeout is provided for streaming requests (#9)
1 parent 14a21dd commit d571a1b

File tree

1 file changed

+52
-7
lines changed

1 file changed

+52
-7
lines changed

internal/requestconfig/requestconfig.go

+52-7
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,41 @@ func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) {
292292
return 0, false
293293
}
294294

295+
// isBeforeContextDeadline reports whether the non-zero Time t is
296+
// before ctx's deadline. If ctx does not have a deadline, it
297+
// always reports true (the deadline is considered infinite).
298+
func isBeforeContextDeadline(t time.Time, ctx context.Context) bool {
299+
d, ok := ctx.Deadline()
300+
if !ok {
301+
return true
302+
}
303+
return t.Before(d)
304+
}
305+
306+
// bodyWithTimeout is an io.ReadCloser which can observe a context's cancel func
307+
// to handle timeouts etc. It wraps an existing io.ReadCloser.
308+
type bodyWithTimeout struct {
309+
stop func() // stops the time.Timer waiting to cancel the request
310+
rc io.ReadCloser
311+
}
312+
313+
func (b *bodyWithTimeout) Read(p []byte) (n int, err error) {
314+
n, err = b.rc.Read(p)
315+
if err == nil {
316+
return n, nil
317+
}
318+
if err == io.EOF {
319+
return n, err
320+
}
321+
return n, err
322+
}
323+
324+
func (b *bodyWithTimeout) Close() error {
325+
err := b.rc.Close()
326+
b.stop()
327+
return err
328+
}
329+
295330
func retryDelay(res *http.Response, retryCount int) time.Duration {
296331
// If the API asks us to wait a certain amount of time (and it's a reasonable amount),
297332
// just do what it says.
@@ -353,12 +388,17 @@ func (cfg *RequestConfig) Execute() (err error) {
353388
shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0"
354389

355390
var res *http.Response
391+
var cancel context.CancelFunc
356392
for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 {
357393
ctx := cfg.Request.Context()
358-
if cfg.RequestTimeout != time.Duration(0) {
359-
var cancel context.CancelFunc
394+
if cfg.RequestTimeout != time.Duration(0) && isBeforeContextDeadline(time.Now().Add(cfg.RequestTimeout), ctx) {
360395
ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout)
361-
defer cancel()
396+
defer func() {
397+
// The cancel function is nil if it was handed off to be handled in a different scope.
398+
if cancel != nil {
399+
cancel()
400+
}
401+
}()
362402
}
363403

364404
req := cfg.Request.Clone(ctx)
@@ -426,10 +466,15 @@ func (cfg *RequestConfig) Execute() (err error) {
426466
return &aerr
427467
}
428468

429-
if cfg.ResponseBodyInto == nil {
430-
return nil
431-
}
432-
if _, ok := cfg.ResponseBodyInto.(**http.Response); ok {
469+
_, intoCustomResponseBody := cfg.ResponseBodyInto.(**http.Response)
470+
if cfg.ResponseBodyInto == nil || intoCustomResponseBody {
471+
// We aren't reading the response body in this scope, but whoever is will need the
472+
// cancel func from the context to observe request timeouts.
473+
// Put the cancel function in the response body so it can be handled elsewhere.
474+
if cancel != nil {
475+
res.Body = &bodyWithTimeout{rc: res.Body, stop: cancel}
476+
cancel = nil
477+
}
433478
return nil
434479
}
435480

0 commit comments

Comments
 (0)