From 9903b56845a108bf6415e2b9fd9f90c00c086bc2 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Wed, 4 Feb 2026 18:56:23 +0100 Subject: [PATCH 1/2] Add Origin header validation - Fixes #695 - Does not implement Host header validation yet Signed-off-by: Daniel Garnier-Moiroux --- .../server/ConformanceServlet.java | 29 +- ...faultServerTransportSecurityValidator.java | 132 +++++++ ...HttpServletSseServerTransportProvider.java | 67 +++- .../HttpServletStatelessServerTransport.java | 57 ++- ...vletStreamableServerTransportProvider.java | 76 +++- .../ServerTransportSecurityException.java | 48 +++ .../ServerTransportSecurityValidator.java | 36 ++ ...ServerTransportSecurityValidatorTests.java | 197 +++++++++++ ...rverTransportSecurityIntegrationTests.java | 292 +++++++++++++++ .../WebFluxSseServerTransportProvider.java | 46 ++- .../WebFluxStatelessServerTransport.java | 37 +- ...FluxStreamableServerTransportProvider.java | 52 ++- ...rverTransportSecurityIntegrationTests.java | 291 +++++++++++++++ .../WebMvcSseServerTransportProvider.java | 46 ++- .../WebMvcStatelessServerTransport.java | 37 +- ...bMvcStreamableServerTransportProvider.java | 57 ++- ...rverTransportSecurityIntegrationTests.java | 334 ++++++++++++++++++ 17 files changed, 1804 insertions(+), 30 deletions(-) create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java diff --git a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java index ca09e55e4..ff127cd3d 100644 --- a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java +++ b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java @@ -8,7 +8,34 @@ import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.McpSchema.*; +import io.modelcontextprotocol.spec.McpSchema.AudioContent; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ImageContent; +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java new file mode 100644 index 000000000..5321aada7 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java @@ -0,0 +1,132 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation of {@link ServerTransportSecurityValidator} that validates the + * Origin header against a list of allowed origins. + * + *

+ * Supports exact matches and wildcard port patterns (e.g., "http://example.com:*"). + * + * @author Daniel Garnier-Moiroux + * @see ServerTransportSecurityValidator + * @see ServerTransportSecurityException + */ +public class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator { + + private static final String ORIGIN_HEADER = "Origin"; + + private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403, + "Invalid Origin header"); + + private final List allowedOrigins; + + /** + * Creates a new validator with the specified allowed origins. + * @param allowedOrigins List of allowed origin patterns. Supports exact matches + * (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*") + */ + public DefaultServerTransportSecurityValidator(List allowedOrigins) { + Assert.notNull(allowedOrigins, "allowedOrigins must not be null"); + this.allowedOrigins = allowedOrigins; + } + + @Override + public void validateHeaders(Map> headers) throws ServerTransportSecurityException { + for (Map.Entry> entry : headers.entrySet()) { + if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) { + List values = entry.getValue(); + if (values != null && !values.isEmpty()) { + validateOrigin(values.get(0)); + } + break; + } + } + } + + /** + * Validates a single origin value against the allowed origins. Subclasses can + * override this method to customize origin validation logic. + * @param origin The origin header value, or null if not present + * @throws ServerTransportSecurityException if the origin is not allowed + */ + protected void validateOrigin(String origin) throws ServerTransportSecurityException { + // Origin absent = no validation needed (same-origin request) + if (origin == null || origin.isBlank()) { + return; + } + + for (String allowed : allowedOrigins) { + if (allowed.equals(origin)) { + return; + } + else if (allowed.endsWith(":*")) { + // Wildcard port pattern: "http://example.com:*" + String baseOrigin = allowed.substring(0, allowed.length() - 2); + if (origin.equals(baseOrigin) || origin.startsWith(baseOrigin + ":")) { + return; + } + } + + } + + throw INVALID_ORIGIN; + } + + /** + * Creates a new builder for constructing a DefaultServerTransportSecurityValidator. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link DefaultServerTransportSecurityValidator}. + */ + public static class Builder { + + private final List allowedOrigins = new ArrayList<>(); + + /** + * Adds an allowed origin pattern. + * @param origin The origin to allow (e.g., "http://localhost:8080" or + * "http://example.com:*") + * @return this builder instance + */ + public Builder allowedOrigin(String origin) { + this.allowedOrigins.add(origin); + return this; + } + + /** + * Adds multiple allowed origin patterns. + * @param origins The origins to allow + * @return this builder instance + */ + public Builder allowedOrigins(List origins) { + Assert.notNull(origins, "origins must not be null"); + this.allowedOrigins.addAll(origins); + return this; + } + + /** + * Builds the validator instance. + * @return A new DefaultServerTransportSecurityValidator + */ + public DefaultServerTransportSecurityValidator build() { + return new DefaultServerTransportSecurityValidator(allowedOrigins); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 96cebb74a..d12fb8c9e 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024 - 2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -8,6 +8,9 @@ import java.io.IOException; import java.io.PrintWriter; import java.time.Duration; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -142,6 +145,11 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement */ private KeepAliveScheduler keepAliveScheduler; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. @@ -153,23 +161,25 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement * @param keepAliveInterval The interval for keep-alive pings, or null to disable * keep-alive functionality * @param contextExtractor The extractor for transport context from the request. - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. + * @param securityValidator The security validator for validating HTTP requests. */ private HttpServletSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(messageEndpoint, "messageEndpoint must not be null"); Assert.notNull(sseEndpoint, "sseEndpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; if (keepAliveInterval != null) { @@ -246,6 +256,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + response.setContentType("text/event-stream"); response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); @@ -311,6 +330,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + // Get the session ID from the request parameter String sessionId = request.getParameter("sessionId"); if (sessionId == null) { @@ -411,6 +439,21 @@ private void sendEvent(PrintWriter writer, String eventType, String data) throws } } + /** + * Extracts all headers from the HTTP servlet request into a map. + * @param request The HTTP servlet request + * @return A map of header names to their values + */ + private Map> extractHeaders(HttpServletRequest request) { + Map> headers = new HashMap<>(); + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + headers.put(name, Collections.list(request.getHeaders(name))); + } + return headers; + } + /** * Cleans up resources when the servlet is being destroyed. *

@@ -547,6 +590,8 @@ public static class Builder { private Duration keepAliveInterval; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + /** * Sets the JsonMapper implementation to use for serialization/deserialization. If * not specified, a JacksonJsonMapper will be created from the configured @@ -621,6 +666,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return This builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. @@ -633,7 +690,7 @@ public HttpServletSseServerTransportProvider build() { } return new HttpServletSseServerTransportProvider( jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval, contextExtractor); + keepAliveInterval, contextExtractor, securityValidator); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 40767f416..106f834f5 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -7,6 +7,11 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,15 +63,23 @@ public class HttpServletStatelessServerTransport extends HttpServlet implements private volatile boolean isClosing = false; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); Assert.notNull(contextExtractor, "contextExtractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; } @Override @@ -122,6 +135,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + McpTransportContext transportContext = this.contextExtractor.extract(request); String accept = request.getHeader(ACCEPT); @@ -209,6 +231,21 @@ private void responseError(HttpServletResponse response, int httpCode, McpError writer.flush(); } + /** + * Extracts all headers from the HTTP servlet request into a map. + * @param request The HTTP servlet request + * @return A map of header names to their values + */ + private Map> extractHeaders(HttpServletRequest request) { + Map> headers = new HashMap<>(); + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + headers.put(name, Collections.list(request.getHeaders(name))); + } + return headers; + } + /** * Cleans up resources when the servlet is being destroyed. *

@@ -243,6 +280,8 @@ public static class Builder { private McpTransportContextExtractor contextExtractor = ( serverRequest) -> McpTransportContext.EMPTY; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + private Builder() { // used by a static method } @@ -288,6 +327,18 @@ public Builder contextExtractor(McpTransportContextExtractor return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link HttpServletStatelessServerTransport} with the * configured settings. @@ -297,7 +348,7 @@ public Builder contextExtractor(McpTransportContextExtractor public HttpServletStatelessServerTransport build() { Assert.notNull(mcpEndpoint, "Message endpoint must be set"); return new HttpServletStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - mcpEndpoint, contextExtractor); + mcpEndpoint, contextExtractor, securityValidator); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 34671c105..b7c8e7b23 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -9,7 +9,11 @@ import java.io.PrintWriter; import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; @@ -119,6 +123,11 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private KeepAliveScheduler keepAliveScheduler; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + /** * Constructs a new HttpServletStreamableServerTransportProvider instance. * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization of @@ -127,19 +136,24 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. * @param disallowDelete Whether to disallow DELETE requests on the endpoint. * @param contextExtractor The extractor for transport context from the request. + * @param keepAliveInterval The interval for keep-alive pings. If null, no keep-alive + * will be scheduled. + * @param securityValidator The security validator for validating HTTP requests. * @throws IllegalArgumentException if any parameter is null */ private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor contextExtractor, - Duration keepAliveInterval) { + Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; if (keepAliveInterval != null) { @@ -246,6 +260,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + List badRequestErrors = new ArrayList<>(); String accept = request.getHeader(ACCEPT); @@ -373,6 +396,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + List badRequestErrors = new ArrayList<>(); String accept = request.getHeader(ACCEPT); @@ -536,6 +568,15 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response return; } + try { + Map> headers = extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + if (this.disallowDelete) { response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); return; @@ -586,6 +627,21 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m return; } + /** + * Extracts all headers from the HTTP servlet request into a map. + * @param request The HTTP servlet request + * @return A map of header names to their values + */ + private Map> extractHeaders(HttpServletRequest request) { + Map> headers = new HashMap<>(); + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + headers.put(name, Collections.list(request.getHeaders(name))); + } + return headers; + } + /** * Sends an SSE event to a client with a specific ID. * @param writer The writer to send the event through @@ -774,6 +830,8 @@ public static class Builder { private Duration keepAliveInterval; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + /** * Sets the JsonMapper to use for JSON serialization/deserialization of MCP * messages. @@ -833,6 +891,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} * with the configured settings. @@ -843,7 +913,7 @@ public HttpServletStreamableServerTransportProvider build() { Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); return new HttpServletStreamableServerTransportProvider( jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, - contextExtractor, keepAliveInterval); + contextExtractor, keepAliveInterval, securityValidator); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java new file mode 100644 index 000000000..96a06d3bd --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +/** + * Exception thrown when security validation fails for an HTTP request. Contains HTTP + * status code and message. + * + * @author Daniel Garnier-Moiroux + * @see ServerTransportSecurityValidator + */ +public class ServerTransportSecurityException extends Exception { + + private final int statusCode; + + /** + * Creates a new ServerTransportSecurityException with the specified HTTP status code + * and message. + */ + public ServerTransportSecurityException(int statusCode, String message) { + super(message); + this.statusCode = statusCode; + } + + public int getStatusCode() { + return statusCode; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ServerTransportSecurityException that = (ServerTransportSecurityException) obj; + return statusCode == that.statusCode && java.util.Objects.equals(getMessage(), that.getMessage()); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(statusCode, getMessage()); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java new file mode 100644 index 000000000..ce805931f --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.List; +import java.util.Map; + +/** + * Interface for validating HTTP requests in server transports. Implementations can + * validate Origin headers, Host headers, or any other security-related headers according + * to the MCP specification. + * + * @author Daniel Garnier-Moiroux + * @see DefaultServerTransportSecurityValidator + * @see ServerTransportSecurityException + */ +@FunctionalInterface +public interface ServerTransportSecurityValidator { + + /** + * A no-op validator that accepts all requests without validation. + */ + ServerTransportSecurityValidator NOOP = headers -> { + }; + + /** + * Validates the HTTP headers from an incoming request. + * @param headers A map of header names to their values (multi-valued headers + * supported) + * @throws ServerTransportSecurityException if validation fails + */ + void validateHeaders(Map> headers) throws ServerTransportSecurityException; + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java new file mode 100644 index 000000000..7e1593e1b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java @@ -0,0 +1,197 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author Daniel Garnier-Moiroux + */ +class DefaultServerTransportSecurityValidatorTests { + + private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403, + "Invalid Origin header"); + + private final DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:8080") + .build(); + + @Test + void builder() { + assertThatCode(() -> DefaultServerTransportSecurityValidator.builder().build()).doesNotThrowAnyException(); + assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedOrigins(null).build()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void originHeaderMissing() { + assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + } + + @Test + void originHeaderListEmpty() { + assertThatCode(() -> validator.validateHeaders(Map.of("Origin", List.of()))).doesNotThrowAnyException(); + } + + @Test + void caseInsensitive() { + var headers = Map.of("origin", List.of("http://localhost:8080")); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void exactMatch() { + var headers = originHeader("http://localhost:8080"); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentPort() { + + var headers = originHeader("http://localhost:3000"); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentHost() { + + var headers = originHeader("http://example.com:8080"); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentScheme() { + + var headers = originHeader("https://localhost:8080"); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Nested + class WildcardPort { + + private final DefaultServerTransportSecurityValidator wildcardValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:*") + .build(); + + @Test + void anyPortWithWildcard() { + var headers = originHeader("http://localhost:3000"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void noPortWithWildcard() { + var headers = originHeader("http://localhost"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentPortWithWildcard() { + var headers = originHeader("http://localhost:8080"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentHostWithWildcard() { + var headers = originHeader("http://example.com:3000"); + + assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentSchemeWithWildcard() { + var headers = originHeader("https://localhost:3000"); + + assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + } + + @Nested + class MultipleOrigins { + + DefaultServerTransportSecurityValidator multipleOriginsValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigin("http://example.com:3000") + .allowedOrigin("http://myapp.com:*") + .build(); + + @Test + void matchingOneOfMultiple() { + var headers = originHeader("http://example.com:3000"); + + assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void matchingWildcardInMultiple() { + var headers = originHeader("http://myapp.com:9999"); + + assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void notMatchingAny() { + var headers = originHeader("http://malicious.example.com:1234"); + + assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + } + + @Nested + class BuilderTests { + + @Test + void shouldAddMultipleOriginsWithAllowedOriginsMethod() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*")) + .build(); + + var headers = originHeader("http://example.com:3000"); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void shouldCombineAllowedOriginMethods() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000")) + .build(); + + assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000"))) + .doesNotThrowAnyException(); + } + + } + + private static Map> originHeader(String origin) { + return Map.of("Origin", List.of(origin)); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java new file mode 100644 index 000000000..e9e64c0d0 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java @@ -0,0 +1,292 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.http.HttpServlet; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.BeforeParameterizedClassInvocation; +import org.junit.jupiter.params.Parameter; +import org.junit.jupiter.params.ParameterizedClass; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +/** + * Test the header security validation for all transport types. + * + * @author Daniel Garnier-Moiroux + */ +@ParameterizedClass +@MethodSource("transports") +class ServerTransportSecurityIntegrationTests { + + private static final String DISALLOWED_ORIGIN = "https://malicious.example.com"; + + @Parameter + private static Transport transport; + + private static Tomcat tomcat; + + private static String baseUrl; + + @BeforeParameterizedClassInvocation + static void createTransportAndStartTomcat(Transport transport) { + var port = TomcatTestUtil.findAvailablePort(); + baseUrl = "http://localhost:" + port; + startTomcat(transport.servlet(), port); + } + + @AfterAll + static void afterAll() { + stopTomcat(); + } + + private McpSyncClient mcpClient; + + private final TestRequestCustomizer requestCustomizer = new TestRequestCustomizer(); + + @BeforeEach + void setUp() { + mcpClient = transport.createMcpClient(baseUrl, requestCustomizer); + } + + @AfterEach + void tearDown() { + mcpClient.close(); + } + + @Test + void originAllowed() { + requestCustomizer.setOriginHeader(baseUrl); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void noOrigin() { + requestCustomizer.setOriginHeader(null); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectOriginNotAllowed() { + requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageOriginNotAllowed() { + requestCustomizer.setOriginHeader(baseUrl); + mcpClient.initialize(); + requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.listTools()); + } + + // ---------------------------------------------------- + // Tomcat management + // ---------------------------------------------------- + + private static void startTomcat(jakarta.servlet.Servlet servlet, int port) { + tomcat = TomcatTestUtil.createTomcatServer("", port, servlet); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private static void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // ---------------------------------------------------- + // Transport servers to test + // ---------------------------------------------------- + + /** + * All transport types we want to test. We use a {@link MethodSource} rather than a + * {@link org.junit.jupiter.params.provider.ValueSource} to provide a readable name. + */ + static Stream transports() { + //@formatter:off + return Stream.of( + arguments(named("SSE", new Sse())), + arguments(named("Streamable HTTP", new StreamableHttp())), + arguments(named("Stateless", new Stateless())) + ); + //@formatter:on + } + + /** + * Represents a server transport we want to test, and how to create a client for the + * resulting MCP Server. + */ + interface Transport { + + McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer); + + HttpServlet servlet(); + + } + + /** + * SSE-based transport. + */ + static class Sse implements Transport { + + private final HttpServletSseServerTransportProvider transport; + + public Sse() { + transport = HttpServletSseServerTransportProvider.builder() + .messageEndpoint("/mcp/message") + .securityValidator( + DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) { + var transport = HttpClientSseClientTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonMapper.getDefault()) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + + static class StreamableHttp implements Transport { + + private final HttpServletStreamableServerTransportProvider transport; + + public StreamableHttp() { + transport = HttpServletStreamableServerTransportProvider.builder() + .securityValidator( + DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) { + var transport = HttpClientStreamableHttpTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonMapper.getDefault()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + + static class Stateless implements Transport { + + private final HttpServletStatelessServerTransport transport; + + public Stateless() { + transport = HttpServletStatelessServerTransport.builder() + .securityValidator( + DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) { + var transport = HttpClientStreamableHttpTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonMapper.getDefault()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + + static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer { + + private String originHeader = null; + + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, + McpTransportContext context) { + if (originHeader != null) { + builder.header("Origin", originHeader); + } + } + + public void setOriginHeader(String originHeader) { + this.originHeader = originHeader; + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 0c80c5b8b..34d6e5085 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2025-2025 the original author or authors. + * Copyright 2025-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -7,6 +7,7 @@ import java.io.IOException; import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import io.modelcontextprotocol.common.McpTransportContext; @@ -132,6 +133,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private KeepAliveScheduler keepAliveScheduler; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + /** * Constructs a new WebFlux SSE server transport provider instance. * @param jsonMapper The ObjectMapper to use for JSON serialization/deserialization of @@ -144,22 +150,26 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv * @param keepAliveInterval The interval for sending keep-alive pings to clients. * @param contextExtractor The context extractor to use for extracting MCP transport * context from HTTP requests. Must not be null. + * @param securityValidator The security validator for validating HTTP requests. * @throws IllegalArgumentException if either parameter is null */ private WebFluxSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -273,6 +283,14 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage()); + } + McpTransportContext transportContext = this.contextExtractor.extract(request); return ServerResponse.ok() @@ -332,6 +350,14 @@ private Mono handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage()); + } + if (request.queryParam("sessionId").isEmpty()) { return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); } @@ -436,6 +462,8 @@ public static class Builder { private McpTransportContextExtractor contextExtractor = ( serverRequest) -> McpTransportContext.EMPTY; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + /** * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP * messages. @@ -513,6 +541,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -522,7 +562,7 @@ public Builder contextExtractor(McpTransportContextExtractor cont public WebFluxSseServerTransportProvider build() { Assert.notNull(messageEndpoint, "Message endpoint must be set"); return new WebFluxSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor, securityValidator); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java index 400be341e..b225ab61b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2025-2025 the original author or authors. + * Copyright 2025-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.List; +import java.util.Map; /** * Implementation of a WebFlux based {@link McpStatelessServerTransport}. @@ -46,15 +47,23 @@ public class WebFluxStatelessServerTransport implements McpStatelessServerTransp private volatile boolean isClosing = false; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); Assert.notNull(contextExtractor, "contextExtractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() .GET(this.mcpEndpoint, this::handleGet) .POST(this.mcpEndpoint, this::handlePost) @@ -96,6 +105,14 @@ private Mono handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage()); + } + McpTransportContext transportContext = this.contextExtractor.extract(request); List acceptHeaders = request.headers().asHttpHeaders().getAccept(); @@ -160,6 +177,8 @@ public static class Builder { private McpTransportContextExtractor contextExtractor = ( serverRequest) -> McpTransportContext.EMPTY; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + private Builder() { // used by a static method } @@ -205,6 +224,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link WebFluxStatelessServerTransport} with the * configured settings. @@ -214,7 +245,7 @@ public Builder contextExtractor(McpTransportContextExtractor cont public WebFluxStatelessServerTransport build() { Assert.notNull(mcpEndpoint, "Message endpoint must be set"); return new WebFluxStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - mcpEndpoint, contextExtractor); + mcpEndpoint, contextExtractor, securityValidator); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index deebfc616..762ee005d 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2025-2025 the original author or authors. + * Copyright 2025-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -36,6 +36,7 @@ import java.io.IOException; import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; /** @@ -67,17 +68,24 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe private KeepAliveScheduler keepAliveScheduler; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, McpTransportContextExtractor contextExtractor, boolean disallowDelete, - Duration keepAliveInterval) { + Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; this.disallowDelete = disallowDelete; + this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() .GET(this.mcpEndpoint, this::handleGet) .POST(this.mcpEndpoint, this::handlePost) @@ -166,6 +174,14 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage()); + } + McpTransportContext transportContext = this.contextExtractor.extract(request); return Mono.defer(() -> { @@ -221,6 +237,14 @@ private Mono handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage()); + } + McpTransportContext transportContext = this.contextExtractor.extract(request); List acceptHeaders = request.headers().asHttpHeaders().getAccept(); @@ -310,6 +334,14 @@ private Mono handleDelete(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage()); + } + McpTransportContext transportContext = this.contextExtractor.extract(request); return Mono.defer(() -> { @@ -410,6 +442,8 @@ public static class Builder { private Duration keepAliveInterval; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + private Builder() { // used by a static method } @@ -477,6 +511,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with * the configured settings. @@ -487,7 +533,7 @@ public WebFluxStreamableServerTransportProvider build() { Assert.notNull(mcpEndpoint, "Message endpoint must be set"); return new WebFluxStreamableServerTransportProvider( jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, contextExtractor, - disallowDelete, keepAliveInterval); + disallowDelete, keepAliveInterval, securityValidator); } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java new file mode 100644 index 000000000..06e1286d2 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java @@ -0,0 +1,291 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.security; + +import java.time.Duration; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.DefaultServerTransportSecurityValidator; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.BeforeParameterizedClassInvocation; +import org.junit.jupiter.params.Parameter; +import org.junit.jupiter.params.ParameterizedClass; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +/** + * Test the header security validation for all transport types. + * + * @author Daniel Garnier-Moiroux + */ +@ParameterizedClass +@MethodSource("transports") +public class WebFluxServerTransportSecurityIntegrationTests { + + private static final String DISALLOWED_ORIGIN = "https://malicious.example.com"; + + @Parameter + private static Transport transport; + + private static DisposableServer httpServer; + + private static String baseUrl; + + @BeforeParameterizedClassInvocation + static void createTransportAndStartServer(Transport transport) { + var port = TestUtil.findAvailablePort(); + baseUrl = "http://localhost:" + port; + startServer(transport.routerFunction(), port); + } + + @AfterAll + static void afterAll() { + stopServer(); + } + + private McpSyncClient mcpClient; + + private final TestOriginHeaderExchangeFilterFunction exchangeFilterFunction = new TestOriginHeaderExchangeFilterFunction(); + + @BeforeEach + void setUp() { + mcpClient = transport.createMcpClient(baseUrl, exchangeFilterFunction); + } + + @AfterEach + void tearDown() { + mcpClient.close(); + } + + @Test + void originAllowed() { + exchangeFilterFunction.setOriginHeader(baseUrl); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void noOrigin() { + exchangeFilterFunction.setOriginHeader(null); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectOriginNotAllowed() { + exchangeFilterFunction.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageOriginNotAllowed() { + exchangeFilterFunction.setOriginHeader(baseUrl); + mcpClient.initialize(); + exchangeFilterFunction.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.listTools()); + } + + // ---------------------------------------------------- + // Server management + // ---------------------------------------------------- + + private static void startServer(RouterFunction routerFunction, int port) { + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(port).handle(adapter).bindNow(); + } + + private static void stopServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + // ---------------------------------------------------- + // Transport servers to test + // ---------------------------------------------------- + + /** + * All transport types we want to test. We use a {@link MethodSource} rather than a + * {@link org.junit.jupiter.params.provider.ValueSource} to provide a readable name. + */ + static Stream transports() { + //@formatter:off + return Stream.of( + arguments(named("SSE", new Sse())), + arguments(named("Streamable HTTP", new StreamableHttp())), + arguments(named("Stateless", new Stateless())) + ); + //@formatter:on + } + + /** + * Represents a server transport we want to test, and how to create a client for the + * resulting MCP Server. + */ + interface Transport { + + McpSyncClient createMcpClient(String baseUrl, TestOriginHeaderExchangeFilterFunction customizer); + + RouterFunction routerFunction(); + + } + + /** + * SSE-based transport. + */ + static class Sse implements Transport { + + private final WebFluxSseServerTransportProvider transportProvider; + + public Sse() { + transportProvider = WebFluxSseServerTransportProvider.builder() + .messageEndpoint("/mcp/message") + .securityValidator( + DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .build(); + McpServer.sync(transportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, + TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) { + var transport = WebFluxSseClientTransport + .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction)) + .jsonMapper(McpJsonMapper.getDefault()) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public RouterFunction routerFunction() { + return transportProvider.getRouterFunction(); + } + + } + + static class StreamableHttp implements Transport { + + private final WebFluxStreamableServerTransportProvider transportProvider; + + public StreamableHttp() { + transportProvider = WebFluxStreamableServerTransportProvider.builder() + .securityValidator( + DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .build(); + McpServer.sync(transportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, + TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) { + var transport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction)) + .jsonMapper(McpJsonMapper.getDefault()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public RouterFunction routerFunction() { + return transportProvider.getRouterFunction(); + } + + } + + static class Stateless implements Transport { + + private final WebFluxStatelessServerTransport transportProvider; + + public Stateless() { + transportProvider = WebFluxStatelessServerTransport.builder() + .securityValidator( + DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) + .build(); + McpServer.sync(transportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, + TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) { + var transport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction)) + .jsonMapper(McpJsonMapper.getDefault()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public RouterFunction routerFunction() { + return transportProvider.getRouterFunction(); + } + + } + + static class TestOriginHeaderExchangeFilterFunction implements ExchangeFilterFunction { + + private String origin = null; + + public void setOriginHeader(String origin) { + this.origin = origin; + } + + @Override + public Mono filter(ClientRequest request, ExchangeFunction next) { + var updatedRequest = ClientRequest.from(request).header("origin", this.origin).build(); + return next.exchange(updatedRequest); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 6c35de56d..7e925a0af 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -7,6 +7,7 @@ import java.io.IOException; import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; @@ -118,6 +119,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private KeepAliveScheduler keepAliveScheduler; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + /** * Constructs a new WebMvcSseServerTransportProvider instance. * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization @@ -131,22 +137,26 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi * @param keepAliveInterval The interval for sending keep-alive messages to clients. * @param contextExtractor The contextExtractor to fill in a * {@link McpTransportContext}. + * @param securityValidator The security validator for validating HTTP requests. * @throws IllegalArgumentException if any parameter is null */ private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -255,6 +265,14 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).body(e.getMessage()); + } + // Send initial endpoint event return ServerResponse.sse(sseBuilder -> { WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sseBuilder); @@ -313,6 +331,14 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).body(e.getMessage()); + } + if (request.param(SESSION_ID).isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } @@ -474,6 +500,8 @@ public static class Builder { private McpTransportContextExtractor contextExtractor = ( serverRequest) -> McpTransportContext.EMPTY; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + /** * Sets the JSON object mapper to use for message serialization/deserialization. * @param jsonMapper The object mapper to use @@ -549,6 +577,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of WebMvcSseServerTransportProvider with the configured * settings. @@ -560,7 +600,7 @@ public WebMvcSseServerTransportProvider build() { throw new IllegalStateException("MessageEndpoint must be set"); } return new WebMvcSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor, securityValidator); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java index 67b5f571c..92a08a8f4 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.List; +import java.util.Map; /** * Implementation of a WebMVC based {@link McpStatelessServerTransport}. @@ -50,15 +51,23 @@ public class WebMvcStatelessServerTransport implements McpStatelessServerTranspo private volatile boolean isClosing = false; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); Assert.notNull(contextExtractor, "contextExtractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() .GET(this.mcpEndpoint, this::handleGet) .POST(this.mcpEndpoint, this::handlePost) @@ -100,6 +109,14 @@ private ServerResponse handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).body(e.getMessage()); + } + McpTransportContext transportContext = this.contextExtractor.extract(request); List acceptHeaders = request.headers().asHttpHeaders().getAccept(); @@ -179,6 +196,8 @@ public static class Builder { private McpTransportContextExtractor contextExtractor = ( serverRequest) -> McpTransportContext.EMPTY; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + private Builder() { // used by a static method } @@ -224,6 +243,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link WebMvcStatelessServerTransport} with the * configured settings. @@ -233,7 +264,7 @@ public Builder contextExtractor(McpTransportContextExtractor cont public WebMvcStatelessServerTransport build() { Assert.notNull(mcpEndpoint, "Message endpoint must be set"); return new WebMvcStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - mcpEndpoint, contextExtractor); + mcpEndpoint, contextExtractor, securityValidator); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index f2a58d4d8..7ca76f80b 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -7,6 +7,7 @@ import java.io.IOException; import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; @@ -102,6 +103,11 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer private KeepAliveScheduler keepAliveScheduler; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + /** * Constructs a new WebMvcStreamableServerTransportProvider instance. * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization @@ -111,19 +117,26 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @param contextExtractor The context extractor for transport context from the + * request. + * @param keepAliveInterval The interval for keep-alive pings. If null, no keep-alive + * will be scheduled. + * @param securityValidator The security validator for validating HTTP requests. * @throws IllegalArgumentException if any parameter is null */ private WebMvcStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor contextExtractor, - Duration keepAliveInterval) { + Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) { Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() .GET(this.mcpEndpoint, this::handleGet) .POST(this.mcpEndpoint, this::handlePost) @@ -233,6 +246,14 @@ private ServerResponse handleGet(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).body(e.getMessage()); + } + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); @@ -315,6 +336,14 @@ private ServerResponse handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).body(e.getMessage()); + } + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM) || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) { @@ -427,6 +456,14 @@ private ServerResponse handleDelete(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + try { + Map> headers = request.headers().asHttpHeaders(); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + return ServerResponse.status(e.getStatusCode()).body(e.getMessage()); + } + if (this.disallowDelete) { return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); } @@ -609,6 +646,8 @@ public static class Builder { private Duration keepAliveInterval; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + /** * Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP * messages. @@ -672,6 +711,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with * the configured settings. @@ -682,7 +733,7 @@ public WebMvcStreamableServerTransportProvider build() { Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); return new WebMvcStreamableServerTransportProvider( jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete, - contextExtractor, keepAliveInterval); + contextExtractor, keepAliveInterval, securityValidator); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java new file mode 100644 index 000000000..9615547d3 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java @@ -0,0 +1,334 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.security; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil.TomcatServer; +import io.modelcontextprotocol.server.transport.DefaultServerTransportSecurityValidator; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.BeforeParameterizedClassInvocation; +import org.junit.jupiter.params.Parameter; +import org.junit.jupiter.params.ParameterizedClass; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.context.annotation.Scope; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +/** + * Test the header security validation for all transport types. + * + * @author Daniel Garnier-Moiroux + */ +@ParameterizedClass +@MethodSource("transports") +public class ServerTransportSecurityIntegrationTests { + + private static final String DISALLOWED_ORIGIN = "https://malicious.example.com"; + + @Parameter + private static Class configClass; + + private static TomcatServer tomcatServer; + + private static String baseUrl; + + @BeforeParameterizedClassInvocation + static void createTransportAndStartTomcat(Class configClass) { + var port = TestUtil.findAvailablePort(); + baseUrl = "http://localhost:" + port; + startTomcat(configClass, port); + } + + @AfterAll + static void afterAll() { + stopTomcat(); + } + + private McpSyncClient mcpClient; + + private TestRequestCustomizer requestCustomizer; + + @BeforeEach + void setUp() { + mcpClient = tomcatServer.appContext().getBean(McpSyncClient.class); + requestCustomizer = tomcatServer.appContext().getBean(TestRequestCustomizer.class); + } + + @AfterEach + void tearDown() { + mcpClient.close(); + } + + @Test + void originAllowed() { + requestCustomizer.setOriginHeader(baseUrl); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void noOrigin() { + requestCustomizer.setOriginHeader(null); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectOriginNotAllowed() { + requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageOriginNotAllowed() { + requestCustomizer.setOriginHeader(baseUrl); + mcpClient.initialize(); + requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.listTools()); + } + + // ---------------------------------------------------- + // Tomcat management + // ---------------------------------------------------- + + private static void startTomcat(Class componentClass, int port) { + tomcatServer = TomcatTestUtil.createTomcatServer("", port, componentClass); + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private static void stopTomcat() { + if (tomcatServer != null) { + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + } + + // ---------------------------------------------------- + // Transport servers to test + // ---------------------------------------------------- + + /** + * All transport types we want to test. We use a {@link MethodSource} rather than a + * {@link org.junit.jupiter.params.provider.ValueSource} to provide a readable name. + */ + static Stream transports() { + //@formatter:off + return Stream.of( + arguments(named("SSE", SseConfig.class)), + arguments(named("Streamable HTTP", StreamableHttpConfig.class)), + arguments(named("Stateless", StatelessConfig.class)) + ); + //@formatter:on + } + + // ---------------------------------------------------- + // Spring Configuration classes + // ---------------------------------------------------- + + @Configuration + static class CommonConfig { + + @Bean + TestRequestCustomizer requestCustomizer() { + return new TestRequestCustomizer(); + } + + @Bean + DefaultServerTransportSecurityValidator validator() { + return DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build(); + } + + } + + @Configuration + @EnableWebMvc + @Import(CommonConfig.class) + static class SseConfig { + + @Bean + @Scope("prototype") + McpSyncClient createMcpClient(McpSyncHttpClientRequestCustomizer requestCustomizer) { + var transport = HttpClientSseClientTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonMapper.getDefault()) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransport( + DefaultServerTransportSecurityValidator validator) { + return WebMvcSseServerTransportProvider.builder() + .messageEndpoint("/mcp/message") + .securityValidator(validator) + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpServer(WebMvcSseServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + } + + @Configuration + @EnableWebMvc + @Import(CommonConfig.class) + static class StreamableHttpConfig { + + @Bean + @Scope("prototype") + McpSyncClient createMcpClient(McpSyncHttpClientRequestCustomizer requestCustomizer) { + var transport = HttpClientStreamableHttpTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonMapper.getDefault()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport( + DefaultServerTransportSecurityValidator validator) { + return WebMvcStreamableServerTransportProvider.builder().securityValidator(validator).build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpServer(WebMvcStreamableServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + } + + @Configuration + @EnableWebMvc + @Import(CommonConfig.class) + static class StatelessConfig { + + @Bean + @Scope("prototype") + McpSyncClient createMcpClient(McpSyncHttpClientRequestCustomizer requestCustomizer) { + var transport = HttpClientStreamableHttpTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonMapper.getDefault()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport( + DefaultServerTransportSecurityValidator validator) { + return WebMvcStatelessServerTransport.builder().securityValidator(validator).build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpStatelessSyncServer mcpStatelessServer(WebMvcStatelessServerTransport transportProvider) { + return McpServer.sync(transportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + } + + static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer { + + private String originHeader = null; + + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, + McpTransportContext context) { + if (originHeader != null) { + builder.header("Origin", originHeader); + } + } + + public void setOriginHeader(String originHeader) { + this.originHeader = originHeader; + } + + } + +} From 68ed7955b9c252ddda75af0d3d0f5ef875b7f5ae Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Wed, 4 Feb 2026 21:17:43 +0100 Subject: [PATCH 2/2] Update conformance tests with DNS rebinding protection Signed-off-by: Daniel Garnier-Moiroux --- conformance-tests/VALIDATION_RESULTS.md | 4 +--- conformance-tests/conformance-baseline.yml | 3 --- conformance-tests/server-servlet/README.md | 4 ++-- .../conformance/server/ConformanceServlet.java | 3 +++ 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/conformance-tests/VALIDATION_RESULTS.md b/conformance-tests/VALIDATION_RESULTS.md index f33ff4e81..80ce364c1 100644 --- a/conformance-tests/VALIDATION_RESULTS.md +++ b/conformance-tests/VALIDATION_RESULTS.md @@ -15,13 +15,12 @@ - **Resources (4/6):** list, read-text, read-binary, templates-read - **Prompts (4/4):** list, simple, with-args, embedded-resource, with-image - **SSE Transport (2/2):** Multiple streams -- **Security (1/2):** Localhost validation passes +- **Security (2/2):** Localhost validation passes, DNS rebinding protection ### Failing (3/40) 1. **resources-subscribe** - Not implemented in SDK 2. **resources-unsubscribe** - Not implemented in SDK -3. **dns-rebinding-protection** - Missing Host/Origin validation (1/2 checks) ## Client Test Results @@ -44,7 +43,6 @@ 1. **Resource Subscriptions:** SDK doesn't implement `resources/subscribe` and `resources/unsubscribe` handlers 2. **Client SSE Retry:** Client doesn't parse or respect the `retry:` field, reconnects immediately, and doesn't send Last-Event-ID header -3. **DNS Rebinding Protection:** Missing Host/Origin header validation in server transport ## Running Tests diff --git a/conformance-tests/conformance-baseline.yml b/conformance-tests/conformance-baseline.yml index 22c061590..920e8401c 100644 --- a/conformance-tests/conformance-baseline.yml +++ b/conformance-tests/conformance-baseline.yml @@ -6,9 +6,6 @@ server: # Resource subscription not implemented in SDK - resources-subscribe - resources-unsubscribe - - # DNS rebinding protection missing Host/Origin validation - - dns-rebinding-protection client: # SSE retry field handling not implemented diff --git a/conformance-tests/server-servlet/README.md b/conformance-tests/server-servlet/README.md index 2c69244fb..bd86636b6 100644 --- a/conformance-tests/server-servlet/README.md +++ b/conformance-tests/server-servlet/README.md @@ -32,8 +32,8 @@ The server has been validated against the official [MCP conformance test suite]( ✅ **SSE Transport** (2/2) - Multiple streams support -⚠️ **Security** (1/2) -- ⚠️ DNS rebinding protection (SDK limitation) +✅ **Security** (2/2) +- ✅ DNS rebinding protection ## Features diff --git a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java index ff127cd3d..411c8ecc5 100644 --- a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java +++ b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java @@ -7,6 +7,7 @@ import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.DefaultServerTransportSecurityValidator; import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema.AudioContent; import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; @@ -66,6 +67,8 @@ public static void main(String[] args) throws Exception { .builder() .mcpEndpoint(MCP_ENDPOINT) .keepAliveInterval(Duration.ofSeconds(30)) + .securityValidator( + DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build()) .build(); // Build server with all conformance test features