diff --git a/httperror.go b/httperror.go index 682cce2a0..6e14da3d9 100644 --- a/httperror.go +++ b/httperror.go @@ -14,6 +14,16 @@ type HTTPStatusCoder interface { StatusCode() int } +// StatusCode returns status code from error if it implements HTTPStatusCoder interface. +// If error does not implement the interface it returns 0. +func StatusCode(err error) int { + var sc HTTPStatusCoder + if errors.As(err, &sc) { + return sc.StatusCode() + } + return 0 +} + // Following errors can produce HTTP status code by implementing HTTPStatusCoder interface var ( ErrBadRequest = &httpError{http.StatusBadRequest} // 400 diff --git a/httperror_test.go b/httperror_test.go index 9ae88abcb..0a91bbc9c 100644 --- a/httperror_test.go +++ b/httperror_test.go @@ -5,9 +5,11 @@ package echo import ( "errors" - "github.com/stretchr/testify/assert" + "fmt" "net/http" "testing" + + "github.com/stretchr/testify/assert" ) func TestHTTPError_StatusCode(t *testing.T) { @@ -65,3 +67,43 @@ func TestNewHTTPError(t *testing.T) { assert.Equal(t, err2, err) } + +func TestStatusCode(t *testing.T) { + var testCases = []struct { + name string + err error + expect int + }{ + { + name: "ok, HTTPError", + err: &HTTPError{Code: http.StatusNotFound}, + expect: http.StatusNotFound, + }, + { + name: "ok, sentinel error", + err: ErrNotFound, + expect: http.StatusNotFound, + }, + { + name: "ok, wrapped HTTPError", + err: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}), + expect: http.StatusTeapot, + }, + { + name: "nok, normal error", + err: errors.New("error"), + expect: 0, + }, + { + name: "nok, nil", + err: nil, + expect: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expect, StatusCode(tc.err)) + }) + } +}