Skip to content

Commit d6127fe

Browse files
authored
Rework timeout middleware to use http.TimeoutHandler implementation (fix #1761) (#1801)
1 parent 5622ecc commit d6127fe

File tree

2 files changed

+172
-97
lines changed

2 files changed

+172
-97
lines changed

middleware/timeout.go

+56-38
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ package middleware
44

55
import (
66
"context"
7-
"fmt"
87
"github.com/labstack/echo/v4"
8+
"net/http"
99
"time"
1010
)
1111

@@ -14,24 +14,31 @@ type (
1414
TimeoutConfig struct {
1515
// Skipper defines a function to skip middleware.
1616
Skipper Skipper
17-
// ErrorHandler defines a function which is executed for a timeout
18-
// It can be used to define a custom timeout error
19-
ErrorHandler TimeoutErrorHandlerWithContext
17+
18+
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
19+
// It can be used to define a custom timeout error message
20+
ErrorMessage string
21+
22+
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
23+
// request timeouted and we already had sent the error code (503) and message response to the client.
24+
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
25+
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
26+
OnTimeoutRouteErrorHandler func(err error, c echo.Context)
27+
2028
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
29+
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
30+
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
31+
// difference over 500microseconds (0.5millisecond) response seems to be reliable
2132
Timeout time.Duration
2233
}
23-
24-
// TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can
25-
// handle the error as we see fit
26-
TimeoutErrorHandlerWithContext func(error, echo.Context) error
2734
)
2835

2936
var (
3037
// DefaultTimeoutConfig is the default Timeout middleware config.
3138
DefaultTimeoutConfig = TimeoutConfig{
3239
Skipper: DefaultSkipper,
3340
Timeout: 0,
34-
ErrorHandler: nil,
41+
ErrorMessage: "",
3542
}
3643
)
3744

@@ -55,39 +62,50 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
5562
return next(c)
5663
}
5764

58-
ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
59-
defer cancel()
60-
61-
// this does a deep clone of the context, wondering if there is a better way to do this?
62-
c.SetRequest(c.Request().Clone(ctx))
63-
64-
done := make(chan error, 1)
65-
go func() {
66-
defer func() {
67-
if r := recover(); r != nil {
68-
err, ok := r.(error)
69-
if !ok {
70-
err = fmt.Errorf("panic recovered in timeout middleware: %v", r)
71-
}
72-
c.Logger().Error(err)
73-
done <- err
74-
}
75-
}()
76-
77-
// This goroutine will keep running even if this middleware times out and
78-
// will be stopped when ctx.Done() is called down the next(c) call chain
79-
done <- next(c)
80-
}()
65+
handlerWrapper := echoHandlerFuncWrapper{
66+
ctx: c,
67+
handler: next,
68+
errChan: make(chan error, 1),
69+
errHandler: config.OnTimeoutRouteErrorHandler,
70+
}
71+
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
72+
handler.ServeHTTP(c.Response().Writer, c.Request())
8173

8274
select {
83-
case <-ctx.Done():
84-
if config.ErrorHandler != nil {
85-
return config.ErrorHandler(ctx.Err(), c)
86-
}
87-
return ctx.Err()
88-
case err := <-done:
75+
case err := <-handlerWrapper.errChan:
8976
return err
77+
default:
78+
return nil
9079
}
9180
}
9281
}
9382
}
83+
84+
type echoHandlerFuncWrapper struct {
85+
ctx echo.Context
86+
handler echo.HandlerFunc
87+
errHandler func(err error, c echo.Context)
88+
errChan chan error
89+
}
90+
91+
func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
92+
// replace writer with TimeoutHandler custom one. This will guarantee that
93+
// `writes by h to its ResponseWriter will return ErrHandlerTimeout.`
94+
originalWriter := t.ctx.Response().Writer
95+
t.ctx.Response().Writer = rw
96+
97+
err := t.handler(t.ctx)
98+
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
99+
if err != nil && t.errHandler != nil {
100+
t.errHandler(err, t.ctx)
101+
}
102+
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
103+
}
104+
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
105+
// and should not anymore send additional headers/data
106+
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
107+
t.ctx.Response().Writer = originalWriter
108+
if err != nil {
109+
t.errChan <- err
110+
}
111+
}

middleware/timeout_test.go

+116-59
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
package middleware
44

55
import (
6-
"context"
76
"errors"
87
"github.com/labstack/echo/v4"
98
"github.com/stretchr/testify/assert"
@@ -22,6 +21,7 @@ func TestTimeoutSkipper(t *testing.T) {
2221
Skipper: func(context echo.Context) bool {
2322
return true
2423
},
24+
Timeout: 1 * time.Nanosecond,
2525
})
2626

2727
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -31,18 +31,17 @@ func TestTimeoutSkipper(t *testing.T) {
3131
c := e.NewContext(req, rec)
3232

3333
err := m(func(c echo.Context) error {
34-
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
35-
return nil
34+
time.Sleep(25 * time.Microsecond)
35+
return errors.New("response from handler")
3636
})(c)
3737

38-
assert.NoError(t, err)
38+
// if not skipped we would have not returned error due context timeout logic
39+
assert.EqualError(t, err, "response from handler")
3940
}
4041

4142
func TestTimeoutWithTimeout0(t *testing.T) {
4243
t.Parallel()
43-
m := TimeoutWithConfig(TimeoutConfig{
44-
Timeout: 0,
45-
})
44+
m := Timeout()
4645

4746
req := httptest.NewRequest(http.MethodGet, "/", nil)
4847
rec := httptest.NewRecorder()
@@ -58,10 +57,11 @@ func TestTimeoutWithTimeout0(t *testing.T) {
5857
assert.NoError(t, err)
5958
}
6059

61-
func TestTimeoutIsCancelable(t *testing.T) {
60+
func TestTimeoutErrorOutInHandler(t *testing.T) {
6261
t.Parallel()
6362
m := TimeoutWithConfig(TimeoutConfig{
64-
Timeout: time.Minute,
63+
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
64+
Timeout: 50 * time.Millisecond,
6565
})
6666

6767
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -70,59 +70,22 @@ func TestTimeoutIsCancelable(t *testing.T) {
7070
e := echo.New()
7171
c := e.NewContext(req, rec)
7272

73-
err := m(func(c echo.Context) error {
74-
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
75-
return nil
76-
})(c)
77-
78-
assert.NoError(t, err)
79-
}
80-
81-
func TestTimeoutErrorOutInHandler(t *testing.T) {
82-
t.Parallel()
83-
m := Timeout()
84-
85-
req := httptest.NewRequest(http.MethodGet, "/", nil)
86-
rec := httptest.NewRecorder()
87-
88-
e := echo.New()
89-
c := e.NewContext(req, rec)
90-
9173
err := m(func(c echo.Context) error {
9274
return errors.New("err")
9375
})(c)
9476

9577
assert.Error(t, err)
9678
}
9779

98-
func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) {
80+
func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) {
9981
t.Parallel()
100-
m := TimeoutWithConfig(TimeoutConfig{
101-
Timeout: time.Second,
102-
ErrorHandler: func(err error, e echo.Context) error {
103-
assert.EqualError(t, err, context.DeadlineExceeded.Error())
104-
return errors.New("err")
105-
},
106-
})
107-
108-
req := httptest.NewRequest(http.MethodGet, "/", nil)
109-
rec := httptest.NewRecorder()
110-
111-
e := echo.New()
112-
c := e.NewContext(req, rec)
11382

114-
err := m(func(c echo.Context) error {
115-
time.Sleep(time.Minute)
116-
return nil
117-
})(c)
118-
119-
assert.EqualError(t, err, errors.New("err").Error())
120-
}
121-
122-
func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
123-
t.Parallel()
83+
actualErrChan := make(chan error, 1)
12484
m := TimeoutWithConfig(TimeoutConfig{
125-
Timeout: time.Second,
85+
Timeout: 1 * time.Millisecond,
86+
OnTimeoutRouteErrorHandler: func(err error, c echo.Context) {
87+
actualErrChan <- err
88+
},
12689
})
12790

12891
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -131,12 +94,16 @@ func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
13194
e := echo.New()
13295
c := e.NewContext(req, rec)
13396

97+
stopChan := make(chan struct{}, 0)
13498
err := m(func(c echo.Context) error {
135-
time.Sleep(time.Minute)
136-
return nil
99+
<-stopChan
100+
return errors.New("error in route after timeout")
137101
})(c)
102+
stopChan <- struct{}{}
103+
assert.NoError(t, err)
138104

139-
assert.EqualError(t, err, context.DeadlineExceeded.Error())
105+
actualErr := <-actualErrChan
106+
assert.EqualError(t, actualErr, "error in route after timeout")
140107
}
141108

142109
func TestTimeoutTestRequestClone(t *testing.T) {
@@ -148,7 +115,7 @@ func TestTimeoutTestRequestClone(t *testing.T) {
148115

149116
m := TimeoutWithConfig(TimeoutConfig{
150117
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
151-
Timeout: time.Second,
118+
Timeout: 1 * time.Second,
152119
})
153120

154121
e := echo.New()
@@ -178,8 +145,63 @@ func TestTimeoutTestRequestClone(t *testing.T) {
178145

179146
func TestTimeoutRecoversPanic(t *testing.T) {
180147
t.Parallel()
148+
e := echo.New()
149+
e.Use(Recover()) // recover middleware will handler our panic
150+
e.Use(TimeoutWithConfig(TimeoutConfig{
151+
Timeout: 50 * time.Millisecond,
152+
}))
153+
154+
e.GET("/", func(c echo.Context) error {
155+
panic("panic!!!")
156+
})
157+
158+
req := httptest.NewRequest(http.MethodGet, "/", nil)
159+
rec := httptest.NewRecorder()
160+
161+
assert.NotPanics(t, func() {
162+
e.ServeHTTP(rec, req)
163+
})
164+
}
165+
166+
func TestTimeoutDataRace(t *testing.T) {
167+
t.Parallel()
168+
169+
timeout := 1 * time.Millisecond
170+
m := TimeoutWithConfig(TimeoutConfig{
171+
Timeout: timeout,
172+
ErrorMessage: "Timeout! change me",
173+
})
174+
175+
req := httptest.NewRequest(http.MethodGet, "/", nil)
176+
rec := httptest.NewRecorder()
177+
178+
e := echo.New()
179+
c := e.NewContext(req, rec)
180+
181+
err := m(func(c echo.Context) error {
182+
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
183+
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
184+
// difference over 500microseconds (0.5millisecond) response seems to be reliable
185+
time.Sleep(timeout) // timeout and handler execution time difference is close to zero
186+
return c.String(http.StatusOK, "Hello, World!")
187+
})(c)
188+
189+
assert.NoError(t, err)
190+
191+
if rec.Code == http.StatusServiceUnavailable {
192+
assert.Equal(t, "Timeout! change me", rec.Body.String())
193+
} else {
194+
assert.Equal(t, "Hello, World!", rec.Body.String())
195+
}
196+
}
197+
198+
func TestTimeoutWithErrorMessage(t *testing.T) {
199+
t.Parallel()
200+
201+
timeout := 1 * time.Millisecond
181202
m := TimeoutWithConfig(TimeoutConfig{
182-
Timeout: 25 * time.Millisecond,
203+
Timeout: timeout,
204+
ErrorMessage: "Timeout! change me",
183205
})
184206

185207
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -188,9 +210,44 @@ func TestTimeoutRecoversPanic(t *testing.T) {
188210
e := echo.New()
189211
c := e.NewContext(req, rec)
190212

213+
stopChan := make(chan struct{}, 0)
191214
err := m(func(c echo.Context) error {
192-
panic("panic in handler")
215+
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
216+
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
217+
// difference over 500microseconds (0.5millisecond) response seems to be reliable
218+
<-stopChan
219+
return c.String(http.StatusOK, "Hello, World!")
193220
})(c)
221+
stopChan <- struct{}{}
222+
223+
assert.NoError(t, err)
224+
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
225+
assert.Equal(t, "Timeout! change me", rec.Body.String())
226+
}
227+
228+
func TestTimeoutWithDefaultErrorMessage(t *testing.T) {
229+
t.Parallel()
194230

195-
assert.Error(t, err, "panic recovered in timeout middleware: panic in handler")
231+
timeout := 1 * time.Millisecond
232+
m := TimeoutWithConfig(TimeoutConfig{
233+
Timeout: timeout,
234+
ErrorMessage: "",
235+
})
236+
237+
req := httptest.NewRequest(http.MethodGet, "/", nil)
238+
rec := httptest.NewRecorder()
239+
240+
e := echo.New()
241+
c := e.NewContext(req, rec)
242+
243+
stopChan := make(chan struct{}, 0)
244+
err := m(func(c echo.Context) error {
245+
<-stopChan
246+
return c.String(http.StatusOK, "Hello, World!")
247+
})(c)
248+
stopChan <- struct{}{}
249+
250+
assert.NoError(t, err)
251+
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
252+
assert.Equal(t, `<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>`, rec.Body.String())
196253
}

0 commit comments

Comments
 (0)