From e2cc0c1582df13bd7160c6918be860152276af99 Mon Sep 17 00:00:00 2001 From: Blake Gentry Date: Mon, 17 Mar 2025 11:08:47 -0500 Subject: [PATCH] allow passing a custom validator instance Rework the internal validation logic so that a customized validator can be provided. This allows for custom types and validations to be registered for use within the framework. Custom validators are passed as part of MountOpts. --- apiendpoint/api_endpoint.go | 20 +++++++++++++------- apitest/apitest.go | 20 +++++++++++++++----- apitest/apitest_test.go | 19 ++++++++++++++++--- internal/validate/validate.go | 17 +++++------------ internal/validate/validate_test.go | 24 ++++++++++++++---------- 5 files changed, 63 insertions(+), 37 deletions(-) diff --git a/apiendpoint/api_endpoint.go b/apiendpoint/api_endpoint.go index 206e33f..148a94a 100644 --- a/apiendpoint/api_endpoint.go +++ b/apiendpoint/api_endpoint.go @@ -14,6 +14,7 @@ import ( "net/http" "time" + "github.com/go-playground/validator/v10" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5/pgconn" @@ -94,6 +95,9 @@ type MountOpts struct { // MiddlewareStack is a stack of middleware that will be mounted in front of // the API endpoint handler. If not specified, no middleware will be used. MiddlewareStack *apimiddleware.MiddlewareStack + // Validator is the validator to use for this endpoint. If not specified, + // the default validator will be used. + Validator *validator.Validate } // Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log @@ -108,6 +112,11 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteI logger = slog.Default() } + validator := opts.Validator + if validator == nil { + validator = validate.Default + } + apiEndpoint.SetLogger(logger) meta := apiEndpoint.Meta() @@ -115,7 +124,7 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteI apiEndpoint.SetMeta(meta) innerHandler := func(w http.ResponseWriter, r *http.Request) { - executeAPIEndpoint(w, r, opts.Logger, meta, apiEndpoint.Execute) + executeAPIEndpoint(w, r, opts.Logger, meta, validator, apiEndpoint.Execute) } if opts.MiddlewareStack != nil { @@ -127,13 +136,10 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteI return apiEndpoint } -func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) { +func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, validator *validator.Validate, execute func(ctx context.Context, req *TReq) (*TResp, error)) { ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) defer cancel() - // Run as much code as we can in a sub-function that can return an error. - // This is more convenient to write, but is also safer because unlike when - // writing errors to ResponseWriter, there's no danger of a missing return. err := func() error { var req TReq if r.Method != http.MethodGet { @@ -161,8 +167,8 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ } } - if err := validate.StructCtx(ctx, &req); err != nil { - return apierror.NewBadRequest(validate.PublicFacingMessage(err)) + if err := validator.StructCtx(ctx, &req); err != nil { + return apierror.NewBadRequest(validate.PublicFacingMessage(validator, err)) } resp, err := execute(ctx, &req) diff --git a/apitest/apitest.go b/apitest/apitest.go index 45f52a0..341970f 100644 --- a/apitest/apitest.go +++ b/apitest/apitest.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/riverqueue/apiframe/apiendpoint" "github.com/riverqueue/apiframe/apierror" "github.com/riverqueue/apiframe/internal/validate" ) @@ -21,11 +22,20 @@ import ( // Sample invocation: // // endpoint := &testEndpoint{} -// resp, err := apitest.InvokeHandler(ctx, endpoint.Execute, &testRequest{ReqField: "string"}) +// resp, err := apitest.InvokeHandler(ctx, endpoint.Execute, nil, &testRequest{ReqField: "string"}) // require.NoError(t, err) -func InvokeHandler[TReq any, TResp any](ctx context.Context, handler func(context.Context, *TReq) (*TResp, error), req *TReq) (*TResp, error) { - if err := validate.StructCtx(ctx, req); err != nil { - return nil, apierror.NewBadRequest(validate.PublicFacingMessage(err)) +func InvokeHandler[TReq any, TResp any](ctx context.Context, handler func(context.Context, *TReq) (*TResp, error), opts *apiendpoint.MountOpts, req *TReq) (*TResp, error) { + if opts == nil { + opts = &apiendpoint.MountOpts{} + } + + validator := opts.Validator + if validator == nil { + validator = validate.Default + } + + if err := validator.StructCtx(ctx, req); err != nil { + return nil, apierror.NewBadRequest(validate.PublicFacingMessage(validator, err)) } resp, err := handler(ctx, req) @@ -33,7 +43,7 @@ func InvokeHandler[TReq any, TResp any](ctx context.Context, handler func(contex return nil, err } - if err := validate.StructCtx(ctx, resp); err != nil { + if err := validator.StructCtx(ctx, resp); err != nil { return nil, fmt.Errorf("apitest: error validating response API resource: %w", err) } diff --git a/apitest/apitest_test.go b/apitest/apitest_test.go index 58a7714..dd5a204 100644 --- a/apitest/apitest_test.go +++ b/apitest/apitest_test.go @@ -4,8 +4,10 @@ import ( "context" "testing" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/require" + "github.com/riverqueue/apiframe/apiendpoint" "github.com/riverqueue/apiframe/apierror" ) @@ -28,7 +30,7 @@ func TestInvokeHandler(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - resp, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: "string"}) + resp, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: "string"}) require.NoError(t, err) require.Equal(t, &testResponse{RequiredRespField: "response value"}, resp) }) @@ -36,7 +38,7 @@ func TestInvokeHandler(t *testing.T) { t.Run("ValidatesRequest", func(t *testing.T) { t.Parallel() - _, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: ""}) + _, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: ""}) require.Equal(t, apierror.NewBadRequestf("Field `req_field` is required."), err) }) @@ -47,7 +49,18 @@ func TestInvokeHandler(t *testing.T) { return &testResponse{RequiredRespField: ""}, nil } - _, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: "string"}) + _, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: "string"}) require.EqualError(t, err, "apitest: error validating response API resource: Key: 'testResponse.resp_field' Error:Field validation for 'resp_field' failed on the 'required' tag") }) + + t.Run("CustomValidator", func(t *testing.T) { + t.Parallel() + + customValidator := validator.New() + opts := &apiendpoint.MountOpts{Validator: customValidator} + + resp, err := InvokeHandler(ctx, handler, opts, &testRequest{RequiredReqField: "string"}) + require.NoError(t, err) + require.Equal(t, &testResponse{RequiredRespField: "response value"}, resp) + }) } diff --git a/internal/validate/validate.go b/internal/validate/validate.go index fd8cee5..ab44821 100644 --- a/internal/validate/validate.go +++ b/internal/validate/validate.go @@ -4,7 +4,6 @@ package validate import ( - "context" "fmt" "reflect" "strings" @@ -12,11 +11,12 @@ import ( "github.com/go-playground/validator/v10" ) -// WithRequiredStructEnabled can be removed once validator/v11 is released. -var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals +// Default is the package's default validator instance. WithRequiredStructEnabled +// can be removed once validator/v11 is released. +var Default = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals func init() { //nolint:gochecknoinits - validate.RegisterTagNameFunc(preferPublicName) + Default.RegisterTagNameFunc(preferPublicName) } // PublicFacingMessage builds a complete error message from a validator error @@ -24,7 +24,7 @@ func init() { //nolint:gochecknoinits // // I only added a few possible validations to start. We'll probably need to add // more as we go and expand our usage. -func PublicFacingMessage(validatorErr error) string { +func PublicFacingMessage(v *validator.Validate, validatorErr error) string { var message string //nolint:errorlint @@ -97,13 +97,6 @@ func PublicFacingMessage(validatorErr error) string { return strings.TrimSpace(message) } -// StructCtx validates a structs exposed fields, and automatically validates -// nested structs, unless otherwise specified and also allows passing of -// context.Context for contextual validation information. -func StructCtx(ctx context.Context, s any) error { - return validate.StructCtx(ctx, s) -} - // preferPublicName is a validator tag naming function that uses public names // like a field's JSON tag instead of actual field names in structs. // This is important because we sent these back as user-facing errors (and the diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go index 81e2d69..6d13d0d 100644 --- a/internal/validate/validate_test.go +++ b/internal/validate/validate_test.go @@ -4,12 +4,16 @@ import ( "reflect" "testing" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/require" ) func TestFromValidator(t *testing.T) { t.Parallel() + validator := validator.New(validator.WithRequiredStructEnabled()) + validator.RegisterTagNameFunc(preferPublicName) + // Fields have JSON tags so we can verify those are used over the // property name. type TestStruct struct { @@ -43,7 +47,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.MaxInt = 1 - require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("MaxSlice", func(t *testing.T) { @@ -51,7 +55,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.MaxSlice = []string{"1"} - require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("MaxString", func(t *testing.T) { @@ -59,7 +63,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.MaxString = "value" - require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("MinInt", func(t *testing.T) { @@ -67,7 +71,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.MinInt = 0 - require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("MinSlice", func(t *testing.T) { @@ -75,7 +79,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.MinSlice = nil - require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("MinString", func(t *testing.T) { @@ -83,7 +87,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.MinString = "" - require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("OneOf", func(t *testing.T) { @@ -91,7 +95,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.OneOf = "red" - require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("Required", func(t *testing.T) { @@ -99,7 +103,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.Required = "" - require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `required` is required.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("Unsupported", func(t *testing.T) { @@ -107,7 +111,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.Unsupported = "abc" - require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) t.Run("MultipleErrors", func(t *testing.T) { @@ -116,7 +120,7 @@ func TestFromValidator(t *testing.T) { testStruct := validTestStruct() testStruct.MinInt = 0 testStruct.Required = "" - require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validator, validator.Struct(testStruct))) }) }