Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,26 @@ import (
"github.com/modelcontextprotocol/registry/internal/telemetry"
)

// NulByteValidationMiddleware rejects requests containing NUL bytes in URL path or query parameters
// This prevents PostgreSQL encoding errors and returns a proper 400 Bad Request
func NulByteValidationMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check URL path for NUL bytes
if strings.ContainsRune(r.URL.Path, '\x00') {
http.Error(w, "Invalid request: URL path contains null bytes", http.StatusBadRequest)
return
}

// Check raw query string for NUL bytes
if strings.ContainsRune(r.URL.RawQuery, '\x00') {
http.Error(w, "Invalid request: query parameters contain null bytes", http.StatusBadRequest)
return
}

next.ServeHTTP(w, r)
})
}

// TrailingSlashMiddleware redirects requests with trailing slashes to their canonical form
func TrailingSlashMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -67,8 +87,8 @@ func NewServer(cfg *config.Config, registryService service.RegistryService, metr
})

// Wrap the mux with middleware stack
// Order: TrailingSlash -> CORS -> Mux
handler := TrailingSlashMiddleware(corsHandler.Handler(mux))
// Order: NulByteValidation -> TrailingSlash -> CORS -> Mux
handler := NulByteValidationMiddleware(TrailingSlashMiddleware(corsHandler.Handler(mux)))

server := &Server{
config: cfg,
Expand Down
73 changes: 73 additions & 0 deletions internal/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,84 @@ package api_test
import (
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/modelcontextprotocol/registry/internal/api"
)

func TestNulByteValidationMiddleware(t *testing.T) {
// Create a simple handler that returns "OK"
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})

// Wrap with our middleware
middleware := api.NulByteValidationMiddleware(handler)

t.Run("normal path should pass through", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
})

t.Run("path with query params should pass through", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v0/servers?cursor=abc123", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
})

t.Run("path with NUL byte should return 400", func(t *testing.T) {
// Create request with NUL byte in path by manually setting URL
req := httptest.NewRequest(http.MethodGet, "/v0/servers/test", nil)
req.URL.Path = "/v0/servers/\x00"
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
if !strings.Contains(w.Body.String(), "URL path contains null bytes") {
t.Errorf("expected body to contain error message, got %q", w.Body.String())
}
})

t.Run("query with NUL byte should return 400", func(t *testing.T) {
// Create request with NUL byte in query by manually setting RawQuery
req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil)
req.URL.RawQuery = "cursor=\x00"
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
if !strings.Contains(w.Body.String(), "query parameters contain null bytes") {
t.Errorf("expected body to contain error message, got %q", w.Body.String())
}
})

t.Run("path with embedded NUL byte should return 400", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v0/servers/test", nil)
req.URL.Path = "/v0/servers/test\x00name"
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)

if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
}

func TestTrailingSlashMiddleware(t *testing.T) {
// Create a simple handler that returns "OK"
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
Expand Down
Loading