Skip to content

Commit 1907cbd

Browse files
jamietannapebomikeschinkelMattiasMartens
committed
rew! feat: add error handler with more configuration
When handling errros **??**, historically we have only been given the error message and a suggested HTTP status code to return, which provides **??** To improve this experience, and provide much more control over **??**, we will introduce the `ErrorHandlerWithOpts` function and its corresponding `ErrorHandlerOpts` struct to handle the **??** This will provide: - Direct access to the `error` that occurred **??** - As a means to **??**o This has been a long-standing issue **??** With thanks to Per, Mike and MattiasMartens who have **??**, as well as many others in the past who have **??** Closes #11, #27. Co-authored-by: Per Bockman <[email protected]> Co-authored-by: Mike Schinkel <[email protected]> Co-authored-by: MattiasMartens <[email protected]>
1 parent b73ed97 commit 1907cbd

File tree

3 files changed

+730
-14
lines changed

3 files changed

+730
-14
lines changed

oapi_validate.go

+177-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package nethttpmiddleware
99

1010
import (
11+
"context"
1112
"errors"
1213
"fmt"
1314
"log"
@@ -21,8 +22,58 @@ import (
2122
)
2223

2324
// ErrorHandler is called when there is an error in validation
25+
//
26+
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
27+
//
28+
// Deprecated: it's recommended you migrate to the ErrorHandlerWithOpts, as it provides more control over how to handle an error that occurs, including giving direct access to the `error` itself. There are no plans to remove this method.
2429
type ErrorHandler func(w http.ResponseWriter, message string, statusCode int)
2530

31+
// ErrorHandlerWithOpts is called when there is an error in validation, with more information about the `error` that occurred and which request is currently being processed.
32+
//
33+
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
34+
//
35+
// NOTE that this should ideally be used instead of ErrorHandler
36+
type ErrorHandlerWithOpts func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts ErrorHandlerOpts)
37+
38+
// ErrorHandlerOpts contains additional options that are passed to the `ErrorHandlerWithOpts` function in the case of an error being returned by the middleware
39+
type ErrorHandlerOpts struct {
40+
// Error is the underlying error that triggered this error handler to be executed.
41+
//
42+
// Known error types:
43+
//
44+
// - `*openapi3filter.SecurityRequirementsError` - if the `AuthenticationFunc` has failed to authenticate the request
45+
// - `*openapi3filter.RequestError` - if a bad request has been made
46+
//
47+
// Additionally, if you have set `openapi3filter.Options#MultiError`:
48+
//
49+
// - `openapi3.MultiError` (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
50+
Error error
51+
52+
// StatusCode indicates the HTTP Status Code that the OpenAPI validation middleware _suggests_ is returned to the user.
53+
//
54+
// NOTE that this is very much a suggestion, and can be overridden if you believe you have a better approach.
55+
StatusCode int
56+
57+
// MatchedRoute is the underlying path that this request is being matched against.
58+
//
59+
// This is the route according to the OpenAPI validation middleware, and can be used in addition to/instead of the `http.Request`
60+
//
61+
// NOTE that this will be nil if there is no matched route (i.e. a request has been sent to an endpoint not in the OpenAPI spec)
62+
MatchedRoute *ErrorHandlerOptsMatchedRoute
63+
}
64+
65+
type ErrorHandlerOptsMatchedRoute struct {
66+
// Route indicates the Route that this error is received by.
67+
//
68+
// This can be used in addition to/instead of the `http.Request`.
69+
Route *routers.Route
70+
71+
// PathParams are any path parameters that are determined from the request.
72+
//
73+
// This can be used in addition to/instead of the `http.Request`.
74+
PathParams map[string]string
75+
}
76+
2677
// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
2778
type MultiErrorHandler func(openapi3.MultiError) (int, error)
2879

@@ -32,11 +83,21 @@ type Options struct {
3283
Options openapi3filter.Options
3384
// ErrorHandler is called when a validation error occurs.
3485
//
86+
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
87+
//
3588
// If not provided, `http.Error` will be called
3689
ErrorHandler ErrorHandler
90+
91+
// ErrorHandlerWithOpts is called when there is an error in validation.
92+
//
93+
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
94+
ErrorHandlerWithOpts ErrorHandlerWithOpts
95+
3796
// MultiErrorHandler is called when there is an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) returned by the `openapi3filter`.
3897
//
3998
// If not provided `defaultMultiErrorHandler` will be used.
99+
//
100+
// Does not get called when using `ErrorHandlerWithOpts`
40101
MultiErrorHandler MultiErrorHandler
41102
// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`
42103
SilenceServersWarning bool
@@ -62,27 +123,97 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne
62123

63124
return func(next http.Handler) http.Handler {
64125
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65-
// validate request
66-
statusCode, err := validateRequest(r, router, options)
67-
if err == nil {
68-
// serve
69-
next.ServeHTTP(w, r)
70-
return
71-
}
72-
73126
if options == nil {
74-
http.Error(w, err.Error(), statusCode)
75-
return
76-
}
77-
78-
if options.ErrorHandler != nil {
79-
options.ErrorHandler(w, err.Error(), statusCode)
127+
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
128+
} else if options.ErrorHandlerWithOpts != nil {
129+
performRequestValidationForErrorHandlerWithOpts(next, w, r, router, options)
130+
} else if options.ErrorHandler != nil {
131+
performRequestValidationForErrorHandler(next, w, r, router, options, options.ErrorHandler)
132+
} else {
133+
// NOTE that this shouldn't happen, but let's be sure that we always end up calling the default error handler if no other handler is defined
134+
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
80135
}
81136
})
82137
}
83138

84139
}
85140

141+
func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options, errorHandler ErrorHandler) {
142+
// validate request
143+
statusCode, err := validateRequest(r, router, options)
144+
if err == nil {
145+
// serve
146+
next.ServeHTTP(w, r)
147+
return
148+
}
149+
150+
errorHandler(w, err.Error(), statusCode)
151+
}
152+
153+
// **??**
154+
// Note that this is an inline-and-modified version of `validateRequest` that **??**.
155+
func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options) {
156+
// Find route
157+
route, pathParams, err := router.FindRoute(r)
158+
if err != nil {
159+
errOpts := ErrorHandlerOpts{
160+
// MatchedRoute will be nil, as we've not matched a route we know about
161+
Error: err,
162+
StatusCode: http.StatusNotFound,
163+
}
164+
165+
options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
166+
return
167+
}
168+
169+
errOpts := ErrorHandlerOpts{
170+
MatchedRoute: &ErrorHandlerOptsMatchedRoute{
171+
Route: route,
172+
PathParams: pathParams,
173+
},
174+
// other options will be added before executing
175+
}
176+
177+
// Validate request
178+
requestValidationInput := &openapi3filter.RequestValidationInput{
179+
Request: r,
180+
PathParams: pathParams,
181+
Route: route,
182+
}
183+
184+
if options != nil {
185+
requestValidationInput.Options = &options.Options
186+
}
187+
188+
err = openapi3filter.ValidateRequest(r.Context(), requestValidationInput)
189+
if err == nil {
190+
// it's a valid request, so serve it
191+
next.ServeHTTP(w, r)
192+
return
193+
}
194+
195+
switch e := err.(type) {
196+
case openapi3.MultiError:
197+
errOpts.Error = e
198+
errOpts.StatusCode = determineStatusCodeForMultiError(e)
199+
case *openapi3filter.RequestError:
200+
// We've got a bad request
201+
errOpts.Error = e
202+
errOpts.StatusCode = http.StatusBadRequest
203+
case *openapi3filter.SecurityRequirementsError:
204+
errOpts.Error = e
205+
errOpts.StatusCode = http.StatusUnauthorized
206+
default:
207+
// This should never happen today, but if our upstream code changes,
208+
// we don't want to crash the server, so handle the unexpected error.
209+
// return http.StatusInternalServerError,
210+
errOpts.Error = fmt.Errorf("error validating route: %w", e)
211+
errOpts.StatusCode = http.StatusUnauthorized
212+
}
213+
214+
options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
215+
}
216+
86217
// validateRequest is called from the middleware above and actually does the work
87218
// of validating a request.
88219
func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) {
@@ -150,3 +281,35 @@ func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {
150281
func defaultMultiErrorHandler(me openapi3.MultiError) (int, error) {
151282
return http.StatusBadRequest, me
152283
}
284+
285+
func determineStatusCodeForMultiError(errs openapi3.MultiError) int {
286+
numRequestErrors := 0
287+
numSecurityRequirementsErrors := 0
288+
289+
for _, err := range errs {
290+
switch err.(type) {
291+
case *openapi3filter.RequestError:
292+
numRequestErrors++
293+
case *openapi3filter.SecurityRequirementsError:
294+
numSecurityRequirementsErrors++
295+
default:
296+
// if we have /any/ unknown error types, we should suggest returning an HTTP 500 Internal Server Error
297+
return http.StatusInternalServerError
298+
}
299+
}
300+
301+
if numRequestErrors > 0 && numSecurityRequirementsErrors > 0 {
302+
return http.StatusInternalServerError
303+
}
304+
305+
if numRequestErrors > 0 {
306+
return http.StatusBadRequest
307+
}
308+
309+
if numSecurityRequirementsErrors > 0 {
310+
return http.StatusUnauthorized
311+
}
312+
313+
// we shouldn't hit this, but to be safe, return an HTTP 500 Internal Server Error if we don't have any cases above
314+
return http.StatusInternalServerError
315+
}

0 commit comments

Comments
 (0)