Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions apiendpoint/api_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net/http"
"time"

"github.com/go-playground/validator/v10"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"

Expand Down Expand Up @@ -94,6 +95,9 @@ type MountOpts struct {
// MiddlewareStack is a stack of middleware that will be mounted in front of
// the API endpoint handler. If not specified, no middleware will be used.
MiddlewareStack *apimiddleware.MiddlewareStack
// Validator is the validator to use for this endpoint. If not specified,
// the default validator will be used.
Validator *validator.Validate
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing is that it seems like a bit of a smell to have this configured on a per-endpoint basis. Very likely you'd want one custom validator that you're going to be using application-wide, and having to inject that into every single endpoint seems like a large-ish chore and would add quite a lot of visual noise to the definitions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the issue is there’s no global struct to keep it on. I will say in practice it turned out to not be bad because I’m reusing the same mount opts struct for all endpoints in each group (authenticated and not).

}

// Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log
Expand All @@ -108,14 +112,19 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteI
logger = slog.Default()
}

validator := opts.Validator
if validator == nil {
validator = validate.Default
}

apiEndpoint.SetLogger(logger)

meta := apiEndpoint.Meta()
meta.validate() // panic on problem
apiEndpoint.SetMeta(meta)

innerHandler := func(w http.ResponseWriter, r *http.Request) {
executeAPIEndpoint(w, r, opts.Logger, meta, apiEndpoint.Execute)
executeAPIEndpoint(w, r, opts.Logger, meta, validator, apiEndpoint.Execute)
}

if opts.MiddlewareStack != nil {
Expand All @@ -127,13 +136,10 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteI
return apiEndpoint
}

func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) {
func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, validator *validator.Validate, execute func(ctx context.Context, req *TReq) (*TResp, error)) {
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()

// Run as much code as we can in a sub-function that can return an error.
// This is more convenient to write, but is also safer because unlike when
// writing errors to ResponseWriter, there's no danger of a missing return.
err := func() error {
var req TReq
if r.Method != http.MethodGet {
Expand Down Expand Up @@ -161,8 +167,8 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ
}
}

if err := validate.StructCtx(ctx, &req); err != nil {
return apierror.NewBadRequest(validate.PublicFacingMessage(err))
if err := validator.StructCtx(ctx, &req); err != nil {
return apierror.NewBadRequest(validate.PublicFacingMessage(validator, err))
}

resp, err := execute(ctx, &req)
Expand Down
20 changes: 15 additions & 5 deletions apitest/apitest.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"github.com/riverqueue/apiframe/apiendpoint"
"github.com/riverqueue/apiframe/apierror"
"github.com/riverqueue/apiframe/internal/validate"
)
Expand All @@ -21,19 +22,28 @@ import (
// Sample invocation:
//
// endpoint := &testEndpoint{}
// resp, err := apitest.InvokeHandler(ctx, endpoint.Execute, &testRequest{ReqField: "string"})
// resp, err := apitest.InvokeHandler(ctx, endpoint.Execute, nil, &testRequest{ReqField: "string"})
// require.NoError(t, err)
func InvokeHandler[TReq any, TResp any](ctx context.Context, handler func(context.Context, *TReq) (*TResp, error), req *TReq) (*TResp, error) {
if err := validate.StructCtx(ctx, req); err != nil {
return nil, apierror.NewBadRequest(validate.PublicFacingMessage(err))
func InvokeHandler[TReq any, TResp any](ctx context.Context, handler func(context.Context, *TReq) (*TResp, error), opts *apiendpoint.MountOpts, req *TReq) (*TResp, error) {
if opts == nil {
opts = &apiendpoint.MountOpts{}
}

validator := opts.Validator
if validator == nil {
validator = validate.Default
}

if err := validator.StructCtx(ctx, req); err != nil {
return nil, apierror.NewBadRequest(validate.PublicFacingMessage(validator, err))
}

resp, err := handler(ctx, req)
if err != nil {
return nil, err
}

if err := validate.StructCtx(ctx, resp); err != nil {
if err := validator.StructCtx(ctx, resp); err != nil {
return nil, fmt.Errorf("apitest: error validating response API resource: %w", err)
}

Expand Down
19 changes: 16 additions & 3 deletions apitest/apitest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"testing"

"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/require"

"github.com/riverqueue/apiframe/apiendpoint"
"github.com/riverqueue/apiframe/apierror"
)

Expand All @@ -28,15 +30,15 @@ func TestInvokeHandler(t *testing.T) {
t.Run("Success", func(t *testing.T) {
t.Parallel()

resp, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: "string"})
resp, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: "string"})
require.NoError(t, err)
require.Equal(t, &testResponse{RequiredRespField: "response value"}, resp)
})

t.Run("ValidatesRequest", func(t *testing.T) {
t.Parallel()

_, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: ""})
_, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: ""})
require.Equal(t, apierror.NewBadRequestf("Field `req_field` is required."), err)
})

Expand All @@ -47,7 +49,18 @@ func TestInvokeHandler(t *testing.T) {
return &testResponse{RequiredRespField: ""}, nil
}

_, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: "string"})
_, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: "string"})
require.EqualError(t, err, "apitest: error validating response API resource: Key: 'testResponse.resp_field' Error:Field validation for 'resp_field' failed on the 'required' tag")
})

t.Run("CustomValidator", func(t *testing.T) {
t.Parallel()

customValidator := validator.New()
opts := &apiendpoint.MountOpts{Validator: customValidator}

resp, err := InvokeHandler(ctx, handler, opts, &testRequest{RequiredReqField: "string"})
require.NoError(t, err)
require.Equal(t, &testResponse{RequiredRespField: "response value"}, resp)
})
}
17 changes: 5 additions & 12 deletions internal/validate/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@
package validate

import (
"context"
"fmt"
"reflect"
"strings"

"github.com/go-playground/validator/v10"
)

// WithRequiredStructEnabled can be removed once validator/v11 is released.
var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals
// Default is the package's default validator instance. WithRequiredStructEnabled
// can be removed once validator/v11 is released.
var Default = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals

func init() { //nolint:gochecknoinits
validate.RegisterTagNameFunc(preferPublicName)
Default.RegisterTagNameFunc(preferPublicName)
}

// PublicFacingMessage builds a complete error message from a validator error
// that's suitable for public-facing consumption.
//
// I only added a few possible validations to start. We'll probably need to add
// more as we go and expand our usage.
func PublicFacingMessage(validatorErr error) string {
func PublicFacingMessage(v *validator.Validate, validatorErr error) string {
var message string

//nolint:errorlint
Expand Down Expand Up @@ -97,13 +97,6 @@ func PublicFacingMessage(validatorErr error) string {
return strings.TrimSpace(message)
}

// StructCtx validates a structs exposed fields, and automatically validates
// nested structs, unless otherwise specified and also allows passing of
// context.Context for contextual validation information.
func StructCtx(ctx context.Context, s any) error {
return validate.StructCtx(ctx, s)
}

// preferPublicName is a validator tag naming function that uses public names
// like a field's JSON tag instead of actual field names in structs.
// This is important because we sent these back as user-facing errors (and the
Expand Down
24 changes: 14 additions & 10 deletions internal/validate/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import (
"reflect"
"testing"

"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/require"
)

func TestFromValidator(t *testing.T) {
t.Parallel()

validator := validator.New(validator.WithRequiredStructEnabled())
validator.RegisterTagNameFunc(preferPublicName)

// Fields have JSON tags so we can verify those are used over the
// property name.
type TestStruct struct {
Expand Down Expand Up @@ -43,71 +47,71 @@ func TestFromValidator(t *testing.T) {

testStruct := validTestStruct()
testStruct.MaxInt = 1
require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("MaxSlice", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.MaxSlice = []string{"1"}
require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("MaxString", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.MaxString = "value"
require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("MinInt", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.MinInt = 0
require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("MinSlice", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.MinSlice = nil
require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("MinString", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.MinString = ""
require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("OneOf", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.OneOf = "red"
require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("Required", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.Required = ""
require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `required` is required.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("Unsupported", func(t *testing.T) {
t.Parallel()

testStruct := validTestStruct()
testStruct.Unsupported = "abc"
require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})

t.Run("MultipleErrors", func(t *testing.T) {
Expand All @@ -116,7 +120,7 @@ func TestFromValidator(t *testing.T) {
testStruct := validTestStruct()
testStruct.MinInt = 0
testStruct.Required = ""
require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct)))
require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validator, validator.Struct(testStruct)))
})
}

Expand Down
Loading