From ff1f0eb61a714d937550820b6ded221586f37ea1 Mon Sep 17 00:00:00 2001 From: Jamie Tanna Date: Sun, 8 Feb 2026 18:51:34 +0000 Subject: [PATCH] feat: add ability to skip validation checks As noted in #63, an equivalent to Echo's `Skipper` would allow for middleware users to opt-out of validation in a more straightforward way. In a slightly different implementation to our `echo-middleware`, this does not allow the `Skipper` to consume the body of the original request, and instead duplicates it for the `Skipper`, and the other uses of it. Closes #63. --- oapi_validate.go | 35 +++++++++ oapi_validate_example_test.go | 142 ++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) diff --git a/oapi_validate.go b/oapi_validate.go index 78cf122..6654675 100644 --- a/oapi_validate.go +++ b/oapi_validate.go @@ -8,9 +8,11 @@ package nethttpmiddleware import ( + "bytes" "context" "errors" "fmt" + "io" "log" "net/http" "strings" @@ -74,6 +76,11 @@ type ErrorHandlerOptsMatchedRoute struct { // MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) type MultiErrorHandler func(openapi3.MultiError) (int, error) +// Skipper is a function that runs before any validation middleware, and determines whether the given request should skip any validation middleware +// +// Return `true` if the request should be skipped +type Skipper func(r *http.Request) bool + // Options allows configuring the OapiRequestValidator. type Options struct { // Options contains any configuration for the underlying `openapi3filter` @@ -100,6 +107,9 @@ type Options struct { SilenceServersWarning bool // DoNotValidateServers ensures that there is no Host validation performed (see `SilenceServersWarning` and https://github.com/deepmap/oapi-codegen/issues/882 for more details) DoNotValidateServers bool + + // Skipper allows writing a function that runs before any middleware and determines whether the given request should skip any validation middleware + Skipper Skipper } // OapiRequestValidator Creates the middleware to validate that incoming requests match the given OpenAPI 3.x spec, with a default set of configuration. @@ -126,6 +136,15 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if options != nil && options.Skipper != nil { + r2, err := copyHTTPRequest(r) + if err == nil && options.Skipper(r2) { + // serve with the original request + next.ServeHTTP(w, r) + return + } + } + if options == nil { performRequestValidationForErrorHandler(next, w, r, router, options, http.Error) } else if options.ErrorHandlerWithOpts != nil { @@ -141,6 +160,22 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne } +func copyHTTPRequest(r *http.Request) (*http.Request, error) { + r2 := r.Clone(r.Context()) + + if r.Body != nil { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + // keep the original request body available + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // and have it available for the copy + r2.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + return r2, nil +} + func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options, errorHandler ErrorHandler) { // validate request statusCode, err := validateRequest(r, router, options) diff --git a/oapi_validate_example_test.go b/oapi_validate_example_test.go index be3dea1..cdfa305 100644 --- a/oapi_validate_example_test.go +++ b/oapi_validate_example_test.go @@ -839,3 +839,145 @@ paths: // Received an HTTP 400 response. Expected HTTP 400 // Response body: There was a bad request } + +func ExampleOapiRequestValidatorWithOptions_withSkipper() { + rawSpec := ` +openapi: "3.0.0" +info: + version: 1.0.0 + title: TestServer +servers: + - url: http://example.com/ +paths: + # we also have a /healthz, but it's not externally documented, so the middleware CANNOT run against it, or it'll block requests + /resource: + post: + operationId: createResource + responses: + '204': + description: No content + requestBody: + required: true + content: + text/plain: {} +` + + must := func(err error) { + if err != nil { + panic(err) + } + } + + use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler { + var s http.Handler + s = r + + for _, mw := range middlewares { + s = mw(s) + } + + return s + } + + logResponseBody := func(rr *httptest.ResponseRecorder) { + if rr.Result().Body != nil { + data, _ := io.ReadAll(rr.Result().Body) + if len(data) > 0 { + fmt.Printf("Response body: %s", data) + } + } + } + + spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec)) + must(err) + + // NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec + // See also: Options#SilenceServersWarning + spec.Servers = nil + + router := http.NewServeMux() + + router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("%s /resource was called\n", r.Method) + + if r.Method == http.MethodPost { + data, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + fmt.Printf("Request body: %s\n", data) + w.WriteHeader(http.StatusNoContent) + return + } + + w.WriteHeader(http.StatusMethodNotAllowed) + }) + + router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + authenticationFunc := func(ctx context.Context, ai *openapi3filter.AuthenticationInput) error { + fmt.Printf("`AuthenticationFunc` was called for securitySchemeName=%s\n", ai.SecuritySchemeName) + return fmt.Errorf("this check always fails - don't let anyone in!") + } + + skipperFunc := func(r *http.Request) bool { + // always consume the request body, because we're not following best practices + _, _ = io.ReadAll(r.Body) + + // skip the undocumented healthcheck endpoint + if r.URL.Path == "/healthz" { + return true + } + + return false + } + + // create middleware + mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{ + Options: openapi3filter.Options{ + AuthenticationFunc: authenticationFunc, + }, + Skipper: skipperFunc, + }) + + // then wire it in + server := use(router, mw) + + // ================================================================================ + fmt.Println("# A request that is made to the undocumented healthcheck endpoint does not get validated") + + req, err := http.NewRequest(http.MethodGet, "/healthz", http.NoBody) + must(err) + + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 200\n", rr.Code) + logResponseBody(rr) + + // ================================================================================ + fmt.Println("# The skipper cannot consume the request body") + + req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader([]byte("Hello there"))) + must(err) + req.Header.Set("Content-Type", "text/plain") + + rr = httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 204\n", rr.Code) + logResponseBody(rr) + + // Output: + // # A request that is made to the undocumented healthcheck endpoint does not get validated + // Received an HTTP 200 response. Expected HTTP 200 + // # The skipper cannot consume the request body + // POST /resource was called + // Request body: Hello there + // Received an HTTP 204 response. Expected HTTP 204 +}