diff --git a/internal/api/server.go b/internal/api/server.go index 496282b67..4b4f9bdee 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -17,6 +17,26 @@ import ( "github.com/modelcontextprotocol/registry/internal/telemetry" ) +// NulByteValidationMiddleware rejects requests containing NUL bytes in URL path or query parameters +// This prevents PostgreSQL encoding errors and returns a proper 400 Bad Request +func NulByteValidationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check URL path for NUL bytes + if strings.ContainsRune(r.URL.Path, '\x00') { + http.Error(w, "Invalid request: URL path contains null bytes", http.StatusBadRequest) + return + } + + // Check raw query string for NUL bytes + if strings.ContainsRune(r.URL.RawQuery, '\x00') { + http.Error(w, "Invalid request: query parameters contain null bytes", http.StatusBadRequest) + return + } + + next.ServeHTTP(w, r) + }) +} + // TrailingSlashMiddleware redirects requests with trailing slashes to their canonical form func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -67,8 +87,8 @@ func NewServer(cfg *config.Config, registryService service.RegistryService, metr }) // Wrap the mux with middleware stack - // Order: TrailingSlash -> CORS -> Mux - handler := TrailingSlashMiddleware(corsHandler.Handler(mux)) + // Order: NulByteValidation -> TrailingSlash -> CORS -> Mux + handler := NulByteValidationMiddleware(TrailingSlashMiddleware(corsHandler.Handler(mux))) server := &Server{ config: cfg, diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 6309b471e..c7d189ea7 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -3,11 +3,84 @@ package api_test import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/modelcontextprotocol/registry/internal/api" ) +func TestNulByteValidationMiddleware(t *testing.T) { + // Create a simple handler that returns "OK" + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + // Wrap with our middleware + middleware := api.NulByteValidationMiddleware(handler) + + t.Run("normal path should pass through", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil) + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, w.Code) + } + }) + + t.Run("path with query params should pass through", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers?cursor=abc123", nil) + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, w.Code) + } + }) + + t.Run("path with NUL byte should return 400", func(t *testing.T) { + // Create request with NUL byte in path by manually setting URL + req := httptest.NewRequest(http.MethodGet, "/v0/servers/test", nil) + req.URL.Path = "/v0/servers/\x00" + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + if !strings.Contains(w.Body.String(), "URL path contains null bytes") { + t.Errorf("expected body to contain error message, got %q", w.Body.String()) + } + }) + + t.Run("query with NUL byte should return 400", func(t *testing.T) { + // Create request with NUL byte in query by manually setting RawQuery + req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil) + req.URL.RawQuery = "cursor=\x00" + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + if !strings.Contains(w.Body.String(), "query parameters contain null bytes") { + t.Errorf("expected body to contain error message, got %q", w.Body.String()) + } + }) + + t.Run("path with embedded NUL byte should return 400", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers/test", nil) + req.URL.Path = "/v0/servers/test\x00name" + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + }) +} + func TestTrailingSlashMiddleware(t *testing.T) { // Create a simple handler that returns "OK" handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {