Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions oapi_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
package nethttpmiddleware

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
Expand Down Expand Up @@ -74,6 +76,11 @@ type ErrorHandlerOptsMatchedRoute struct {
// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
type MultiErrorHandler func(openapi3.MultiError) (int, error)

// Skipper is a function that runs before any validation middleware, and determines whether the given request should skip any validation middleware
//
// Return `true` if the request should be skipped
type Skipper func(r *http.Request) bool

// Options allows configuring the OapiRequestValidator.
type Options struct {
// Options contains any configuration for the underlying `openapi3filter`
Expand All @@ -100,6 +107,9 @@ type Options struct {
SilenceServersWarning bool
// DoNotValidateServers ensures that there is no Host validation performed (see `SilenceServersWarning` and https://github.com/deepmap/oapi-codegen/issues/882 for more details)
DoNotValidateServers bool

// Skipper allows writing a function that runs before any middleware and determines whether the given request should skip any validation middleware
Skipper Skipper
}

// OapiRequestValidator Creates the middleware to validate that incoming requests match the given OpenAPI 3.x spec, with a default set of configuration.
Expand All @@ -126,6 +136,15 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if options != nil && options.Skipper != nil {
r2, err := copyHTTPRequest(r)
if err == nil && options.Skipper(r2) {
// serve with the original request
next.ServeHTTP(w, r)
return
}
}

if options == nil {
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
} else if options.ErrorHandlerWithOpts != nil {
Expand All @@ -141,6 +160,22 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne

}

func copyHTTPRequest(r *http.Request) (*http.Request, error) {
r2 := r.Clone(r.Context())

if r.Body != nil {
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
// keep the original request body available
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// and have it available for the copy
r2.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
return r2, nil
}

func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options, errorHandler ErrorHandler) {
// validate request
statusCode, err := validateRequest(r, router, options)
Expand Down
142 changes: 142 additions & 0 deletions oapi_validate_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,3 +839,145 @@ paths:
// Received an HTTP 400 response. Expected HTTP 400
// Response body: There was a bad request
}

func ExampleOapiRequestValidatorWithOptions_withSkipper() {
rawSpec := `
openapi: "3.0.0"
info:
version: 1.0.0
title: TestServer
servers:
- url: http://example.com/
paths:
# we also have a /healthz, but it's not externally documented, so the middleware CANNOT run against it, or it'll block requests
/resource:
post:
operationId: createResource
responses:
'204':
description: No content
requestBody:
required: true
content:
text/plain: {}
`

must := func(err error) {
if err != nil {
panic(err)
}
}

use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler {
var s http.Handler
s = r

for _, mw := range middlewares {
s = mw(s)
}

return s
}

logResponseBody := func(rr *httptest.ResponseRecorder) {
if rr.Result().Body != nil {
data, _ := io.ReadAll(rr.Result().Body)
if len(data) > 0 {
fmt.Printf("Response body: %s", data)
}
}
}

spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec))
must(err)

// NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec
// See also: Options#SilenceServersWarning
spec.Servers = nil

router := http.NewServeMux()

router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("%s /resource was called\n", r.Method)

if r.Method == http.MethodPost {
data, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
fmt.Printf("Request body: %s\n", data)
w.WriteHeader(http.StatusNoContent)
return
}

w.WriteHeader(http.StatusMethodNotAllowed)
})

router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

authenticationFunc := func(ctx context.Context, ai *openapi3filter.AuthenticationInput) error {
fmt.Printf("`AuthenticationFunc` was called for securitySchemeName=%s\n", ai.SecuritySchemeName)
return fmt.Errorf("this check always fails - don't let anyone in!")
}

skipperFunc := func(r *http.Request) bool {
// always consume the request body, because we're not following best practices
_, _ = io.ReadAll(r.Body)

// skip the undocumented healthcheck endpoint
if r.URL.Path == "/healthz" {
return true
}

return false
}

// create middleware
mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Options: openapi3filter.Options{
AuthenticationFunc: authenticationFunc,
},
Skipper: skipperFunc,
})

// then wire it in
server := use(router, mw)

// ================================================================================
fmt.Println("# A request that is made to the undocumented healthcheck endpoint does not get validated")

req, err := http.NewRequest(http.MethodGet, "/healthz", http.NoBody)
must(err)

rr := httptest.NewRecorder()

server.ServeHTTP(rr, req)

fmt.Printf("Received an HTTP %d response. Expected HTTP 200\n", rr.Code)
logResponseBody(rr)

// ================================================================================
fmt.Println("# The skipper cannot consume the request body")

req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader([]byte("Hello there")))
must(err)
req.Header.Set("Content-Type", "text/plain")

rr = httptest.NewRecorder()

server.ServeHTTP(rr, req)

fmt.Printf("Received an HTTP %d response. Expected HTTP 204\n", rr.Code)
logResponseBody(rr)

// Output:
// # A request that is made to the undocumented healthcheck endpoint does not get validated
// Received an HTTP 200 response. Expected HTTP 200
// # The skipper cannot consume the request body
// POST /resource was called
// Request body: Hello there
// Received an HTTP 204 response. Expected HTTP 204
}