From c03013c263983b71c40e32d3e210ae66c4c9be1a Mon Sep 17 00:00:00 2001 From: majiayu000 <1835304752@qq.com> Date: Sun, 28 Dec 2025 22:29:51 +0800 Subject: [PATCH] fix(api): reject requests with NUL bytes in URL Added NulByteValidationMiddleware to validate incoming requests and return 400 Bad Request when NUL bytes are detected in the URL path or query parameters. This prevents PostgreSQL encoding errors and properly rejects malformed input. Fixes #862 Signed-off-by: majiayu000 <1835304752@qq.com> --- internal/api/server.go | 24 +++++++++++- internal/api/server_test.go | 73 +++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/internal/api/server.go b/internal/api/server.go index 496282b67..4b4f9bdee 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -17,6 +17,26 @@ import ( "github.com/modelcontextprotocol/registry/internal/telemetry" ) +// NulByteValidationMiddleware rejects requests containing NUL bytes in URL path or query parameters +// This prevents PostgreSQL encoding errors and returns a proper 400 Bad Request +func NulByteValidationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check URL path for NUL bytes + if strings.ContainsRune(r.URL.Path, '\x00') { + http.Error(w, "Invalid request: URL path contains null bytes", http.StatusBadRequest) + return + } + + // Check raw query string for NUL bytes + if strings.ContainsRune(r.URL.RawQuery, '\x00') { + http.Error(w, "Invalid request: query parameters contain null bytes", http.StatusBadRequest) + return + } + + next.ServeHTTP(w, r) + }) +} + // TrailingSlashMiddleware redirects requests with trailing slashes to their canonical form func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -67,8 +87,8 @@ func NewServer(cfg *config.Config, registryService service.RegistryService, metr }) // Wrap the mux with middleware stack - // Order: TrailingSlash -> CORS -> Mux - handler := TrailingSlashMiddleware(corsHandler.Handler(mux)) + // Order: NulByteValidation -> TrailingSlash -> CORS -> Mux + handler := NulByteValidationMiddleware(TrailingSlashMiddleware(corsHandler.Handler(mux))) server := &Server{ config: cfg, diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 6309b471e..c7d189ea7 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -3,11 +3,84 @@ package api_test import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/modelcontextprotocol/registry/internal/api" ) +func TestNulByteValidationMiddleware(t *testing.T) { + // Create a simple handler that returns "OK" + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + // Wrap with our middleware + middleware := api.NulByteValidationMiddleware(handler) + + t.Run("normal path should pass through", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil) + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, w.Code) + } + }) + + t.Run("path with query params should pass through", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers?cursor=abc123", nil) + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, w.Code) + } + }) + + t.Run("path with NUL byte should return 400", func(t *testing.T) { + // Create request with NUL byte in path by manually setting URL + req := httptest.NewRequest(http.MethodGet, "/v0/servers/test", nil) + req.URL.Path = "/v0/servers/\x00" + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + if !strings.Contains(w.Body.String(), "URL path contains null bytes") { + t.Errorf("expected body to contain error message, got %q", w.Body.String()) + } + }) + + t.Run("query with NUL byte should return 400", func(t *testing.T) { + // Create request with NUL byte in query by manually setting RawQuery + req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil) + req.URL.RawQuery = "cursor=\x00" + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + if !strings.Contains(w.Body.String(), "query parameters contain null bytes") { + t.Errorf("expected body to contain error message, got %q", w.Body.String()) + } + }) + + t.Run("path with embedded NUL byte should return 400", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers/test", nil) + req.URL.Path = "/v0/servers/test\x00name" + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + }) +} + func TestTrailingSlashMiddleware(t *testing.T) { // Create a simple handler that returns "OK" handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {