From 245f9dd61e0e69b710af8a230ce686c3fa8df478 Mon Sep 17 00:00:00 2001 From: Blake Gentry Date: Fri, 14 Mar 2025 17:45:01 -0500 Subject: [PATCH] cleanly handle http.MaxBytesError for enforcing req payload size If a middleware applies `http.MaxBytesReader` and we encounter an `*http.MaxBytesError` when reading the request body, return an HTTP 413 with a clean error. --- apiendpoint/api_endpoint.go | 4 ++++ apiendpoint/api_endpoint_test.go | 16 +++++++++++++++- apierror/api_error.go | 17 +++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/apiendpoint/api_endpoint.go b/apiendpoint/api_endpoint.go index 3787c32..206e33f 100644 --- a/apiendpoint/api_endpoint.go +++ b/apiendpoint/api_endpoint.go @@ -139,6 +139,10 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ if r.Method != http.MethodGet { reqData, err := io.ReadAll(r.Body) if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return apierror.NewRequestEntityTooLarge("Request entity too large.") + } return fmt.Errorf("error reading request body: %w", err) } diff --git a/apiendpoint/api_endpoint_test.go b/apiendpoint/api_endpoint_test.go index 036896a..5691f7d 100644 --- a/apiendpoint/api_endpoint_test.go +++ b/apiendpoint/api_endpoint_test.go @@ -72,6 +72,19 @@ func TestMountAndServe(t *testing.T) { requireStatusAndJSONResponse(t, http.StatusOK, &getResponse{Message: "Hello."}, bundle.recorder) }) + t.Run("MaxBytesErrorHandling", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + payload := mustMarshalJSON(t, &postRequest{Message: "Hello."}) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(payload)) + req.Body = http.MaxBytesReader(bundle.recorder, io.NopCloser(bytes.NewReader(payload)), int64(len(payload)-1)) + mux.ServeHTTP(bundle.recorder, req) + requireStatusAndJSONResponse(t, http.StatusRequestEntityTooLarge, &apierror.APIError{Message: "Request entity too large."}, bundle.recorder) + }) + t.Run("MethodNotAllowed", func(t *testing.T) { t.Parallel() @@ -302,9 +315,10 @@ func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, type postEndpoint struct { Endpoint[postRequest, postResponse] + MaxBodyBytes int64 } -func (*postEndpoint) Meta() *EndpointMeta { +func (a *postEndpoint) Meta() *EndpointMeta { return &EndpointMeta{ Pattern: "POST /api/post-endpoint/{id}", StatusCode: http.StatusCreated, diff --git a/apierror/api_error.go b/apierror/api_error.go index 4923dde..3510419 100644 --- a/apierror/api_error.go +++ b/apierror/api_error.go @@ -131,6 +131,23 @@ func NewNotFoundf(format string, a ...any) *NotFound { return NewNotFound(fmt.Sprintf(format, a...)) } +// +// RequestEntityTooLarge +// + +type RequestEntityTooLarge struct { //nolint:errname + APIError +} + +func NewRequestEntityTooLarge(message string) *RequestEntityTooLarge { + return &RequestEntityTooLarge{ + APIError: APIError{ + Message: message, + StatusCode: http.StatusRequestEntityTooLarge, + }, + } +} + // // ServiceUnavailable //