Skip to content

Commit 0ae7464

Browse files
authored
Support retries of failed proxy requests (#2414)
Support retries of failed proxy requests
1 parent deb17d2 commit 0ae7464

File tree

2 files changed

+437
-41
lines changed

2 files changed

+437
-41
lines changed

middleware/proxy.go

+121-41
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,33 @@ type (
2929
// Required.
3030
Balancer ProxyBalancer
3131

32+
// RetryCount defines the number of times a failed proxied request should be retried
33+
// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
34+
RetryCount int
35+
36+
// RetryFilter defines a function used to determine if a failed request to a
37+
// ProxyTarget should be retried. The RetryFilter will only be called when the number
38+
// of previous retries is less than RetryCount. If the function returns true, the
39+
// request will be retried. The provided error indicates the reason for the request
40+
// failure. When the ProxyTarget is unavailable, the error will be an instance of
41+
// echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
42+
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
43+
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
44+
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
45+
// only called when the request to the target fails, or an internal error in the Proxy
46+
// middleware has occurred. Successful requests that return a non-200 response code cannot
47+
// be retried.
48+
RetryFilter func(c echo.Context, e error) bool
49+
50+
// ErrorHandler defines a function which can be used to return custom errors from
51+
// the Proxy middleware. ErrorHandler is only invoked when there has been
52+
// either an internal error in the Proxy middleware or the ProxyTarget is
53+
// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
54+
// when a ProxyTarget returns a non-200 response. In these cases, the response
55+
// is already written so errors cannot be modified. ErrorHandler is only
56+
// invoked after all retry attempts have been exhausted.
57+
ErrorHandler func(c echo.Context, err error) error
58+
3259
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
3360
// retrieved by index e.g. $1, $2 and so on.
3461
// Examples:
@@ -71,7 +98,8 @@ type (
7198
Next(echo.Context) *ProxyTarget
7299
}
73100

74-
// TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target.
101+
// TargetProvider defines an interface that gives the opportunity for balancer
102+
// to return custom errors when selecting target.
75103
TargetProvider interface {
76104
NextTarget(echo.Context) (*ProxyTarget, error)
77105
}
@@ -107,22 +135,22 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
107135
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
108136
in, _, err := c.Response().Hijack()
109137
if err != nil {
110-
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
138+
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
111139
return
112140
}
113141
defer in.Close()
114142

115143
out, err := net.Dial("tcp", t.URL.Host)
116144
if err != nil {
117-
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
145+
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
118146
return
119147
}
120148
defer out.Close()
121149

122150
// Write header
123151
err = r.Write(out)
124152
if err != nil {
125-
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
153+
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
126154
return
127155
}
128156

@@ -136,7 +164,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
136164
go cp(in, out)
137165
err = <-errCh
138166
if err != nil && err != io.EOF {
139-
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
167+
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL))
140168
}
141169
})
142170
}
@@ -200,7 +228,12 @@ func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
200228
return b.targets[b.random.Intn(len(b.targets))]
201229
}
202230

203-
// Next returns an upstream target using round-robin technique.
231+
// Next returns an upstream target using round-robin technique. In the case
232+
// where a previously failed request is being retried, the round-robin
233+
// balancer will attempt to use the next target relative to the original
234+
// request. If the list of targets held by the balancer is modified while a
235+
// failed request is being retried, it is possible that the balancer will
236+
// return the original failed target.
204237
//
205238
// Note: `nil` is returned in case upstream target list is empty.
206239
func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
@@ -211,13 +244,29 @@ func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
211244
} else if len(b.targets) == 1 {
212245
return b.targets[0]
213246
}
214-
// reset the index if out of bounds
215-
if b.i >= len(b.targets) {
216-
b.i = 0
247+
248+
var i int
249+
const lastIdxKey = "_round_robin_last_index"
250+
// This request is a retry, start from the index of the previous
251+
// target to ensure we don't attempt to retry the request with
252+
// the same failed target
253+
if c.Get(lastIdxKey) != nil {
254+
i = c.Get(lastIdxKey).(int)
255+
i++
256+
if i >= len(b.targets) {
257+
i = 0
258+
}
259+
} else {
260+
// This is a first time request, use the global index
261+
if b.i >= len(b.targets) {
262+
b.i = 0
263+
}
264+
i = b.i
265+
b.i++
217266
}
218-
t := b.targets[b.i]
219-
b.i++
220-
return t
267+
268+
c.Set(lastIdxKey, i)
269+
return b.targets[i]
221270
}
222271

223272
// Proxy returns a Proxy middleware.
@@ -232,14 +281,26 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
232281
// ProxyWithConfig returns a Proxy middleware with config.
233282
// See: `Proxy()`
234283
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
284+
if config.Balancer == nil {
285+
panic("echo: proxy middleware requires balancer")
286+
}
235287
// Defaults
236288
if config.Skipper == nil {
237289
config.Skipper = DefaultProxyConfig.Skipper
238290
}
239-
if config.Balancer == nil {
240-
panic("echo: proxy middleware requires balancer")
291+
if config.RetryFilter == nil {
292+
config.RetryFilter = func(c echo.Context, e error) bool {
293+
if httpErr, ok := e.(*echo.HTTPError); ok {
294+
return httpErr.Code == http.StatusBadGateway
295+
}
296+
return false
297+
}
298+
}
299+
if config.ErrorHandler == nil {
300+
config.ErrorHandler = func(c echo.Context, err error) error {
301+
return err
302+
}
241303
}
242-
243304
if config.Rewrite != nil {
244305
if config.RegexRewrite == nil {
245306
config.RegexRewrite = make(map[*regexp.Regexp]string)
@@ -250,28 +311,17 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
250311
}
251312

252313
provider, isTargetProvider := config.Balancer.(TargetProvider)
314+
253315
return func(next echo.HandlerFunc) echo.HandlerFunc {
254-
return func(c echo.Context) (err error) {
316+
return func(c echo.Context) error {
255317
if config.Skipper(c) {
256318
return next(c)
257319
}
258320

259321
req := c.Request()
260322
res := c.Response()
261-
262-
var tgt *ProxyTarget
263-
if isTargetProvider {
264-
tgt, err = provider.NextTarget(c)
265-
if err != nil {
266-
return err
267-
}
268-
} else {
269-
tgt = config.Balancer.Next(c)
270-
}
271-
c.Set(config.ContextKey, tgt)
272-
273323
if err := rewriteURL(config.RegexRewrite, req); err != nil {
274-
return err
324+
return config.ErrorHandler(c, err)
275325
}
276326

277327
// Fix header
@@ -287,19 +337,49 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
287337
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
288338
}
289339

290-
// Proxy
291-
switch {
292-
case c.IsWebSocket():
293-
proxyRaw(tgt, c).ServeHTTP(res, req)
294-
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
295-
default:
296-
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
297-
}
298-
if e, ok := c.Get("_error").(error); ok {
299-
err = e
300-
}
340+
retries := config.RetryCount
341+
for {
342+
var tgt *ProxyTarget
343+
var err error
344+
if isTargetProvider {
345+
tgt, err = provider.NextTarget(c)
346+
if err != nil {
347+
return config.ErrorHandler(c, err)
348+
}
349+
} else {
350+
tgt = config.Balancer.Next(c)
351+
}
301352

302-
return
353+
c.Set(config.ContextKey, tgt)
354+
355+
//If retrying a failed request, clear any previous errors from
356+
//context here so that balancers have the option to check for
357+
//errors that occurred using previous target
358+
if retries < config.RetryCount {
359+
c.Set("_error", nil)
360+
}
361+
362+
// Proxy
363+
switch {
364+
case c.IsWebSocket():
365+
proxyRaw(tgt, c).ServeHTTP(res, req)
366+
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
367+
default:
368+
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
369+
}
370+
371+
err, hasError := c.Get("_error").(error)
372+
if !hasError {
373+
return nil
374+
}
375+
376+
retry := retries > 0 && config.RetryFilter(c, err)
377+
if !retry {
378+
return config.ErrorHandler(c, err)
379+
}
380+
381+
retries--
382+
}
303383
}
304384
}
305385
}

0 commit comments

Comments
 (0)