diff --git a/conformance-tests/VALIDATION_RESULTS.md b/conformance-tests/VALIDATION_RESULTS.md
index f33ff4e81..80ce364c1 100644
--- a/conformance-tests/VALIDATION_RESULTS.md
+++ b/conformance-tests/VALIDATION_RESULTS.md
@@ -15,13 +15,12 @@
- **Resources (4/6):** list, read-text, read-binary, templates-read
- **Prompts (4/4):** list, simple, with-args, embedded-resource, with-image
- **SSE Transport (2/2):** Multiple streams
-- **Security (1/2):** Localhost validation passes
+- **Security (2/2):** Localhost validation passes, DNS rebinding protection
### Failing (3/40)
1. **resources-subscribe** - Not implemented in SDK
2. **resources-unsubscribe** - Not implemented in SDK
-3. **dns-rebinding-protection** - Missing Host/Origin validation (1/2 checks)
## Client Test Results
@@ -44,7 +43,6 @@
1. **Resource Subscriptions:** SDK doesn't implement `resources/subscribe` and `resources/unsubscribe` handlers
2. **Client SSE Retry:** Client doesn't parse or respect the `retry:` field, reconnects immediately, and doesn't send Last-Event-ID header
-3. **DNS Rebinding Protection:** Missing Host/Origin header validation in server transport
## Running Tests
diff --git a/conformance-tests/conformance-baseline.yml b/conformance-tests/conformance-baseline.yml
index 22c061590..920e8401c 100644
--- a/conformance-tests/conformance-baseline.yml
+++ b/conformance-tests/conformance-baseline.yml
@@ -6,9 +6,6 @@ server:
# Resource subscription not implemented in SDK
- resources-subscribe
- resources-unsubscribe
-
- # DNS rebinding protection missing Host/Origin validation
- - dns-rebinding-protection
client:
# SSE retry field handling not implemented
diff --git a/conformance-tests/server-servlet/README.md b/conformance-tests/server-servlet/README.md
index 2c69244fb..bd86636b6 100644
--- a/conformance-tests/server-servlet/README.md
+++ b/conformance-tests/server-servlet/README.md
@@ -32,8 +32,8 @@ The server has been validated against the official [MCP conformance test suite](
✅ **SSE Transport** (2/2)
- Multiple streams support
-⚠️ **Security** (1/2)
-- ⚠️ DNS rebinding protection (SDK limitation)
+✅ **Security** (2/2)
+- ✅ DNS rebinding protection
## Features
diff --git a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java
index ca09e55e4..411c8ecc5 100644
--- a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java
+++ b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java
@@ -7,8 +7,36 @@
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
+import io.modelcontextprotocol.server.transport.DefaultServerTransportSecurityValidator;
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
-import io.modelcontextprotocol.spec.McpSchema.*;
+import io.modelcontextprotocol.spec.McpSchema.AudioContent;
+import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents;
+import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
+import io.modelcontextprotocol.spec.McpSchema.CompleteResult;
+import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
+import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
+import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
+import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
+import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource;
+import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
+import io.modelcontextprotocol.spec.McpSchema.ImageContent;
+import io.modelcontextprotocol.spec.McpSchema.JsonSchema;
+import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
+import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
+import io.modelcontextprotocol.spec.McpSchema.ProgressNotification;
+import io.modelcontextprotocol.spec.McpSchema.Prompt;
+import io.modelcontextprotocol.spec.McpSchema.PromptArgument;
+import io.modelcontextprotocol.spec.McpSchema.PromptMessage;
+import io.modelcontextprotocol.spec.McpSchema.PromptReference;
+import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult;
+import io.modelcontextprotocol.spec.McpSchema.Resource;
+import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate;
+import io.modelcontextprotocol.spec.McpSchema.Role;
+import io.modelcontextprotocol.spec.McpSchema.SamplingMessage;
+import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
+import io.modelcontextprotocol.spec.McpSchema.TextContent;
+import io.modelcontextprotocol.spec.McpSchema.TextResourceContents;
+import io.modelcontextprotocol.spec.McpSchema.Tool;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.startup.Tomcat;
@@ -39,6 +67,8 @@ public static void main(String[] args) throws Exception {
.builder()
.mcpEndpoint(MCP_ENDPOINT)
.keepAliveInterval(Duration.ofSeconds(30))
+ .securityValidator(
+ DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
.build();
// Build server with all conformance test features
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java
new file mode 100644
index 000000000..5321aada7
--- /dev/null
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright 2026-2026 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import io.modelcontextprotocol.util.Assert;
+
+/**
+ * Default implementation of {@link ServerTransportSecurityValidator} that validates the
+ * Origin header against a list of allowed origins.
+ *
+ *
+ * Supports exact matches and wildcard port patterns (e.g., "http://example.com:*").
+ *
+ * @author Daniel Garnier-Moiroux
+ * @see ServerTransportSecurityValidator
+ * @see ServerTransportSecurityException
+ */
+public class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator {
+
+ private static final String ORIGIN_HEADER = "Origin";
+
+ private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
+ "Invalid Origin header");
+
+ private final List allowedOrigins;
+
+ /**
+ * Creates a new validator with the specified allowed origins.
+ * @param allowedOrigins List of allowed origin patterns. Supports exact matches
+ * (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*")
+ */
+ public DefaultServerTransportSecurityValidator(List allowedOrigins) {
+ Assert.notNull(allowedOrigins, "allowedOrigins must not be null");
+ this.allowedOrigins = allowedOrigins;
+ }
+
+ @Override
+ public void validateHeaders(Map> headers) throws ServerTransportSecurityException {
+ for (Map.Entry> entry : headers.entrySet()) {
+ if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
+ List values = entry.getValue();
+ if (values != null && !values.isEmpty()) {
+ validateOrigin(values.get(0));
+ }
+ break;
+ }
+ }
+ }
+
+ /**
+ * Validates a single origin value against the allowed origins. Subclasses can
+ * override this method to customize origin validation logic.
+ * @param origin The origin header value, or null if not present
+ * @throws ServerTransportSecurityException if the origin is not allowed
+ */
+ protected void validateOrigin(String origin) throws ServerTransportSecurityException {
+ // Origin absent = no validation needed (same-origin request)
+ if (origin == null || origin.isBlank()) {
+ return;
+ }
+
+ for (String allowed : allowedOrigins) {
+ if (allowed.equals(origin)) {
+ return;
+ }
+ else if (allowed.endsWith(":*")) {
+ // Wildcard port pattern: "http://example.com:*"
+ String baseOrigin = allowed.substring(0, allowed.length() - 2);
+ if (origin.equals(baseOrigin) || origin.startsWith(baseOrigin + ":")) {
+ return;
+ }
+ }
+
+ }
+
+ throw INVALID_ORIGIN;
+ }
+
+ /**
+ * Creates a new builder for constructing a DefaultServerTransportSecurityValidator.
+ * @return A new builder instance
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Builder for creating instances of {@link DefaultServerTransportSecurityValidator}.
+ */
+ public static class Builder {
+
+ private final List allowedOrigins = new ArrayList<>();
+
+ /**
+ * Adds an allowed origin pattern.
+ * @param origin The origin to allow (e.g., "http://localhost:8080" or
+ * "http://example.com:*")
+ * @return this builder instance
+ */
+ public Builder allowedOrigin(String origin) {
+ this.allowedOrigins.add(origin);
+ return this;
+ }
+
+ /**
+ * Adds multiple allowed origin patterns.
+ * @param origins The origins to allow
+ * @return this builder instance
+ */
+ public Builder allowedOrigins(List origins) {
+ Assert.notNull(origins, "origins must not be null");
+ this.allowedOrigins.addAll(origins);
+ return this;
+ }
+
+ /**
+ * Builds the validator instance.
+ * @return A new DefaultServerTransportSecurityValidator
+ */
+ public DefaultServerTransportSecurityValidator build() {
+ return new DefaultServerTransportSecurityValidator(allowedOrigins);
+ }
+
+ }
+
+}
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java
index 96cebb74a..d12fb8c9e 100644
--- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2024 - 2024 the original author or authors.
+ * Copyright 2024 - 2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -8,6 +8,9 @@
import java.io.IOException;
import java.io.PrintWriter;
import java.time.Duration;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@@ -142,6 +145,11 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
*/
private KeepAliveScheduler keepAliveScheduler;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
/**
* Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
* endpoint.
@@ -153,23 +161,25 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
* @param keepAliveInterval The interval for keep-alive pings, or null to disable
* keep-alive functionality
* @param contextExtractor The extractor for transport context from the request.
- * @deprecated Use the builder {@link #builder()} instead for better configuration
- * options.
+ * @param securityValidator The security validator for validating HTTP requests.
*/
private HttpServletSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval,
- McpTransportContextExtractor contextExtractor) {
+ McpTransportContextExtractor contextExtractor,
+ ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "JsonMapper must not be null");
Assert.notNull(messageEndpoint, "messageEndpoint must not be null");
Assert.notNull(sseEndpoint, "sseEndpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
if (keepAliveInterval != null) {
@@ -246,6 +256,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
return;
}
+ try {
+ Map> headers = extractHeaders(request);
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ response.sendError(e.getStatusCode(), e.getMessage());
+ return;
+ }
+
response.setContentType("text/event-stream");
response.setCharacterEncoding(UTF_8);
response.setHeader("Cache-Control", "no-cache");
@@ -311,6 +330,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
return;
}
+ try {
+ Map> headers = extractHeaders(request);
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ response.sendError(e.getStatusCode(), e.getMessage());
+ return;
+ }
+
// Get the session ID from the request parameter
String sessionId = request.getParameter("sessionId");
if (sessionId == null) {
@@ -411,6 +439,21 @@ private void sendEvent(PrintWriter writer, String eventType, String data) throws
}
}
+ /**
+ * Extracts all headers from the HTTP servlet request into a map.
+ * @param request The HTTP servlet request
+ * @return A map of header names to their values
+ */
+ private Map> extractHeaders(HttpServletRequest request) {
+ Map> headers = new HashMap<>();
+ Enumeration names = request.getHeaderNames();
+ while (names.hasMoreElements()) {
+ String name = names.nextElement();
+ headers.put(name, Collections.list(request.getHeaders(name)));
+ }
+ return headers;
+ }
+
/**
* Cleans up resources when the servlet is being destroyed.
*
@@ -547,6 +590,8 @@ public static class Builder {
private Duration keepAliveInterval;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
/**
* Sets the JsonMapper implementation to use for serialization/deserialization. If
* not specified, a JacksonJsonMapper will be created from the configured
@@ -621,6 +666,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return This builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of HttpServletSseServerTransportProvider with the
* configured settings.
@@ -633,7 +690,7 @@ public HttpServletSseServerTransportProvider build() {
}
return new HttpServletSseServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint,
- keepAliveInterval, contextExtractor);
+ keepAliveInterval, contextExtractor, securityValidator);
}
}
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java
index 40767f416..106f834f5 100644
--- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2024-2024 the original author or authors.
+ * Copyright 2024-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -7,6 +7,11 @@
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -58,15 +63,23 @@ public class HttpServletStatelessServerTransport extends HttpServlet implements
private volatile boolean isClosing = false;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint,
- McpTransportContextExtractor contextExtractor) {
+ McpTransportContextExtractor contextExtractor,
+ ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "jsonMapper must not be null");
Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null");
Assert.notNull(contextExtractor, "contextExtractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.mcpEndpoint = mcpEndpoint;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
}
@Override
@@ -122,6 +135,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
return;
}
+ try {
+ Map> headers = extractHeaders(request);
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ response.sendError(e.getStatusCode(), e.getMessage());
+ return;
+ }
+
McpTransportContext transportContext = this.contextExtractor.extract(request);
String accept = request.getHeader(ACCEPT);
@@ -209,6 +231,21 @@ private void responseError(HttpServletResponse response, int httpCode, McpError
writer.flush();
}
+ /**
+ * Extracts all headers from the HTTP servlet request into a map.
+ * @param request The HTTP servlet request
+ * @return A map of header names to their values
+ */
+ private Map> extractHeaders(HttpServletRequest request) {
+ Map> headers = new HashMap<>();
+ Enumeration names = request.getHeaderNames();
+ while (names.hasMoreElements()) {
+ String name = names.nextElement();
+ headers.put(name, Collections.list(request.getHeaders(name)));
+ }
+ return headers;
+ }
+
/**
* Cleans up resources when the servlet is being destroyed.
*
@@ -243,6 +280,8 @@ public static class Builder {
private McpTransportContextExtractor contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
private Builder() {
// used by a static method
}
@@ -288,6 +327,18 @@ public Builder contextExtractor(McpTransportContextExtractor
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of {@link HttpServletStatelessServerTransport} with the
* configured settings.
@@ -297,7 +348,7 @@ public Builder contextExtractor(McpTransportContextExtractor
public HttpServletStatelessServerTransport build() {
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
return new HttpServletStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
- mcpEndpoint, contextExtractor);
+ mcpEndpoint, contextExtractor, securityValidator);
}
}
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java
index 34671c105..b7c8e7b23 100644
--- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2024-2024 the original author or authors.
+ * Copyright 2024-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -9,7 +9,11 @@
import java.io.PrintWriter;
import java.time.Duration;
import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
@@ -119,6 +123,11 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
*/
private KeepAliveScheduler keepAliveScheduler;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
/**
* Constructs a new HttpServletStreamableServerTransportProvider instance.
* @param jsonMapper The JsonMapper to use for JSON serialization/deserialization of
@@ -127,19 +136,24 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
* messages via HTTP. This endpoint will handle GET, POST, and DELETE requests.
* @param disallowDelete Whether to disallow DELETE requests on the endpoint.
* @param contextExtractor The extractor for transport context from the request.
+ * @param keepAliveInterval The interval for keep-alive pings. If null, no keep-alive
+ * will be scheduled.
+ * @param securityValidator The security validator for validating HTTP requests.
* @throws IllegalArgumentException if any parameter is null
*/
private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint,
boolean disallowDelete, McpTransportContextExtractor contextExtractor,
- Duration keepAliveInterval) {
+ Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "JsonMapper must not be null");
Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.mcpEndpoint = mcpEndpoint;
this.disallowDelete = disallowDelete;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
if (keepAliveInterval != null) {
@@ -246,6 +260,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
return;
}
+ try {
+ Map> headers = extractHeaders(request);
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ response.sendError(e.getStatusCode(), e.getMessage());
+ return;
+ }
+
List badRequestErrors = new ArrayList<>();
String accept = request.getHeader(ACCEPT);
@@ -373,6 +396,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
return;
}
+ try {
+ Map> headers = extractHeaders(request);
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ response.sendError(e.getStatusCode(), e.getMessage());
+ return;
+ }
+
List badRequestErrors = new ArrayList<>();
String accept = request.getHeader(ACCEPT);
@@ -536,6 +568,15 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
return;
}
+ try {
+ Map> headers = extractHeaders(request);
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ response.sendError(e.getStatusCode(), e.getMessage());
+ return;
+ }
+
if (this.disallowDelete) {
response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
return;
@@ -586,6 +627,21 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m
return;
}
+ /**
+ * Extracts all headers from the HTTP servlet request into a map.
+ * @param request The HTTP servlet request
+ * @return A map of header names to their values
+ */
+ private Map> extractHeaders(HttpServletRequest request) {
+ Map> headers = new HashMap<>();
+ Enumeration names = request.getHeaderNames();
+ while (names.hasMoreElements()) {
+ String name = names.nextElement();
+ headers.put(name, Collections.list(request.getHeaders(name)));
+ }
+ return headers;
+ }
+
/**
* Sends an SSE event to a client with a specific ID.
* @param writer The writer to send the event through
@@ -774,6 +830,8 @@ public static class Builder {
private Duration keepAliveInterval;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
/**
* Sets the JsonMapper to use for JSON serialization/deserialization of MCP
* messages.
@@ -833,6 +891,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
* with the configured settings.
@@ -843,7 +913,7 @@ public HttpServletStreamableServerTransportProvider build() {
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
return new HttpServletStreamableServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete,
- contextExtractor, keepAliveInterval);
+ contextExtractor, keepAliveInterval, securityValidator);
}
}
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java
new file mode 100644
index 000000000..96a06d3bd
--- /dev/null
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2026-2026 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+/**
+ * Exception thrown when security validation fails for an HTTP request. Contains HTTP
+ * status code and message.
+ *
+ * @author Daniel Garnier-Moiroux
+ * @see ServerTransportSecurityValidator
+ */
+public class ServerTransportSecurityException extends Exception {
+
+ private final int statusCode;
+
+ /**
+ * Creates a new ServerTransportSecurityException with the specified HTTP status code
+ * and message.
+ */
+ public ServerTransportSecurityException(int statusCode, String message) {
+ super(message);
+ this.statusCode = statusCode;
+ }
+
+ public int getStatusCode() {
+ return statusCode;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+ ServerTransportSecurityException that = (ServerTransportSecurityException) obj;
+ return statusCode == that.statusCode && java.util.Objects.equals(getMessage(), that.getMessage());
+ }
+
+ @Override
+ public int hashCode() {
+ return java.util.Objects.hash(statusCode, getMessage());
+ }
+
+}
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java
new file mode 100644
index 000000000..ce805931f
--- /dev/null
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2026-2026 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Interface for validating HTTP requests in server transports. Implementations can
+ * validate Origin headers, Host headers, or any other security-related headers according
+ * to the MCP specification.
+ *
+ * @author Daniel Garnier-Moiroux
+ * @see DefaultServerTransportSecurityValidator
+ * @see ServerTransportSecurityException
+ */
+@FunctionalInterface
+public interface ServerTransportSecurityValidator {
+
+ /**
+ * A no-op validator that accepts all requests without validation.
+ */
+ ServerTransportSecurityValidator NOOP = headers -> {
+ };
+
+ /**
+ * Validates the HTTP headers from an incoming request.
+ * @param headers A map of header names to their values (multi-valued headers
+ * supported)
+ * @throws ServerTransportSecurityException if validation fails
+ */
+ void validateHeaders(Map> headers) throws ServerTransportSecurityException;
+
+}
diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java
new file mode 100644
index 000000000..7e1593e1b
--- /dev/null
+++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java
@@ -0,0 +1,197 @@
+/*
+ * Copyright 2026-2026 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * @author Daniel Garnier-Moiroux
+ */
+class DefaultServerTransportSecurityValidatorTests {
+
+ private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
+ "Invalid Origin header");
+
+ private final DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:8080")
+ .build();
+
+ @Test
+ void builder() {
+ assertThatCode(() -> DefaultServerTransportSecurityValidator.builder().build()).doesNotThrowAnyException();
+ assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedOrigins(null).build())
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ void originHeaderMissing() {
+ assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException();
+ }
+
+ @Test
+ void originHeaderListEmpty() {
+ assertThatCode(() -> validator.validateHeaders(Map.of("Origin", List.of()))).doesNotThrowAnyException();
+ }
+
+ @Test
+ void caseInsensitive() {
+ var headers = Map.of("origin", List.of("http://localhost:8080"));
+
+ assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void exactMatch() {
+ var headers = originHeader("http://localhost:8080");
+
+ assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void differentPort() {
+
+ var headers = originHeader("http://localhost:3000");
+
+ assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ @Test
+ void differentHost() {
+
+ var headers = originHeader("http://example.com:8080");
+
+ assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ @Test
+ void differentScheme() {
+
+ var headers = originHeader("https://localhost:8080");
+
+ assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ @Nested
+ class WildcardPort {
+
+ private final DefaultServerTransportSecurityValidator wildcardValidator = DefaultServerTransportSecurityValidator
+ .builder()
+ .allowedOrigin("http://localhost:*")
+ .build();
+
+ @Test
+ void anyPortWithWildcard() {
+ var headers = originHeader("http://localhost:3000");
+
+ assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void noPortWithWildcard() {
+ var headers = originHeader("http://localhost");
+
+ assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void differentPortWithWildcard() {
+ var headers = originHeader("http://localhost:8080");
+
+ assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void differentHostWithWildcard() {
+ var headers = originHeader("http://example.com:3000");
+
+ assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ @Test
+ void differentSchemeWithWildcard() {
+ var headers = originHeader("https://localhost:3000");
+
+ assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ }
+
+ @Nested
+ class MultipleOrigins {
+
+ DefaultServerTransportSecurityValidator multipleOriginsValidator = DefaultServerTransportSecurityValidator
+ .builder()
+ .allowedOrigin("http://localhost:8080")
+ .allowedOrigin("http://example.com:3000")
+ .allowedOrigin("http://myapp.com:*")
+ .build();
+
+ @Test
+ void matchingOneOfMultiple() {
+ var headers = originHeader("http://example.com:3000");
+
+ assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void matchingWildcardInMultiple() {
+ var headers = originHeader("http://myapp.com:9999");
+
+ assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void notMatchingAny() {
+ var headers = originHeader("http://malicious.example.com:1234");
+
+ assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN);
+ }
+
+ }
+
+ @Nested
+ class BuilderTests {
+
+ @Test
+ void shouldAddMultipleOriginsWithAllowedOriginsMethod() {
+ DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*"))
+ .build();
+
+ var headers = originHeader("http://example.com:3000");
+
+ assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException();
+ }
+
+ @Test
+ void shouldCombineAllowedOriginMethods() {
+ DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder()
+ .allowedOrigin("http://localhost:8080")
+ .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000"))
+ .build();
+
+ assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000")))
+ .doesNotThrowAnyException();
+ }
+
+ }
+
+ private static Map> originHeader(String origin) {
+ return Map.of("Origin", List.of(origin));
+ }
+
+}
diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java
new file mode 100644
index 000000000..e9e64c0d0
--- /dev/null
+++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java
@@ -0,0 +1,292 @@
+/*
+ * Copyright 2026-2026 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import java.net.URI;
+import java.net.http.HttpRequest;
+import java.time.Duration;
+import java.util.stream.Stream;
+
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.McpSyncClient;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
+import io.modelcontextprotocol.common.McpTransportContext;
+import io.modelcontextprotocol.json.McpJsonMapper;
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.spec.McpSchema;
+import jakarta.servlet.http.HttpServlet;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.apache.catalina.startup.Tomcat;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.BeforeParameterizedClassInvocation;
+import org.junit.jupiter.params.Parameter;
+import org.junit.jupiter.params.ParameterizedClass;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.junit.jupiter.api.Named.named;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+/**
+ * Test the header security validation for all transport types.
+ *
+ * @author Daniel Garnier-Moiroux
+ */
+@ParameterizedClass
+@MethodSource("transports")
+class ServerTransportSecurityIntegrationTests {
+
+ private static final String DISALLOWED_ORIGIN = "https://malicious.example.com";
+
+ @Parameter
+ private static Transport transport;
+
+ private static Tomcat tomcat;
+
+ private static String baseUrl;
+
+ @BeforeParameterizedClassInvocation
+ static void createTransportAndStartTomcat(Transport transport) {
+ var port = TomcatTestUtil.findAvailablePort();
+ baseUrl = "http://localhost:" + port;
+ startTomcat(transport.servlet(), port);
+ }
+
+ @AfterAll
+ static void afterAll() {
+ stopTomcat();
+ }
+
+ private McpSyncClient mcpClient;
+
+ private final TestRequestCustomizer requestCustomizer = new TestRequestCustomizer();
+
+ @BeforeEach
+ void setUp() {
+ mcpClient = transport.createMcpClient(baseUrl, requestCustomizer);
+ }
+
+ @AfterEach
+ void tearDown() {
+ mcpClient.close();
+ }
+
+ @Test
+ void originAllowed() {
+ requestCustomizer.setOriginHeader(baseUrl);
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void noOrigin() {
+ requestCustomizer.setOriginHeader(null);
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void connectOriginNotAllowed() {
+ requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN);
+ assertThatThrownBy(() -> mcpClient.initialize());
+ }
+
+ @Test
+ void messageOriginNotAllowed() {
+ requestCustomizer.setOriginHeader(baseUrl);
+ mcpClient.initialize();
+ requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN);
+ assertThatThrownBy(() -> mcpClient.listTools());
+ }
+
+ // ----------------------------------------------------
+ // Tomcat management
+ // ----------------------------------------------------
+
+ private static void startTomcat(jakarta.servlet.Servlet servlet, int port) {
+ tomcat = TomcatTestUtil.createTomcatServer("", port, servlet);
+ try {
+ tomcat.start();
+ assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+ }
+
+ private static void stopTomcat() {
+ if (tomcat != null) {
+ try {
+ tomcat.stop();
+ tomcat.destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+
+ // ----------------------------------------------------
+ // Transport servers to test
+ // ----------------------------------------------------
+
+ /**
+ * All transport types we want to test. We use a {@link MethodSource} rather than a
+ * {@link org.junit.jupiter.params.provider.ValueSource} to provide a readable name.
+ */
+ static Stream transports() {
+ //@formatter:off
+ return Stream.of(
+ arguments(named("SSE", new Sse())),
+ arguments(named("Streamable HTTP", new StreamableHttp())),
+ arguments(named("Stateless", new Stateless()))
+ );
+ //@formatter:on
+ }
+
+ /**
+ * Represents a server transport we want to test, and how to create a client for the
+ * resulting MCP Server.
+ */
+ interface Transport {
+
+ McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer);
+
+ HttpServlet servlet();
+
+ }
+
+ /**
+ * SSE-based transport.
+ */
+ static class Sse implements Transport {
+
+ private final HttpServletSseServerTransportProvider transport;
+
+ public Sse() {
+ transport = HttpServletSseServerTransportProvider.builder()
+ .messageEndpoint("/mcp/message")
+ .securityValidator(
+ DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .build();
+ McpServer.sync(transport)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ @Override
+ public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) {
+ var transport = HttpClientSseClientTransport.builder(baseUrl)
+ .httpRequestCustomizer(requestCustomizer)
+ .jsonMapper(McpJsonMapper.getDefault())
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Override
+ public HttpServlet servlet() {
+ return transport;
+ }
+
+ }
+
+ static class StreamableHttp implements Transport {
+
+ private final HttpServletStreamableServerTransportProvider transport;
+
+ public StreamableHttp() {
+ transport = HttpServletStreamableServerTransportProvider.builder()
+ .securityValidator(
+ DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .build();
+ McpServer.sync(transport)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ @Override
+ public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) {
+ var transport = HttpClientStreamableHttpTransport.builder(baseUrl)
+ .httpRequestCustomizer(requestCustomizer)
+ .jsonMapper(McpJsonMapper.getDefault())
+ .openConnectionOnStartup(true)
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Override
+ public HttpServlet servlet() {
+ return transport;
+ }
+
+ }
+
+ static class Stateless implements Transport {
+
+ private final HttpServletStatelessServerTransport transport;
+
+ public Stateless() {
+ transport = HttpServletStatelessServerTransport.builder()
+ .securityValidator(
+ DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .build();
+ McpServer.sync(transport)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ @Override
+ public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) {
+ var transport = HttpClientStreamableHttpTransport.builder(baseUrl)
+ .httpRequestCustomizer(requestCustomizer)
+ .jsonMapper(McpJsonMapper.getDefault())
+ .openConnectionOnStartup(true)
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Override
+ public HttpServlet servlet() {
+ return transport;
+ }
+
+ }
+
+ static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer {
+
+ private String originHeader = null;
+
+ @Override
+ public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body,
+ McpTransportContext context) {
+ if (originHeader != null) {
+ builder.header("Origin", originHeader);
+ }
+ }
+
+ public void setOriginHeader(String originHeader) {
+ this.originHeader = originHeader;
+ }
+
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java
index 0c80c5b8b..34d6e5085 100644
--- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java
+++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2025-2025 the original author or authors.
+ * Copyright 2025-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -7,6 +7,7 @@
import java.io.IOException;
import java.time.Duration;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import io.modelcontextprotocol.common.McpTransportContext;
@@ -132,6 +133,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
*/
private KeepAliveScheduler keepAliveScheduler;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
/**
* Constructs a new WebFlux SSE server transport provider instance.
* @param jsonMapper The ObjectMapper to use for JSON serialization/deserialization of
@@ -144,22 +150,26 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
* @param keepAliveInterval The interval for sending keep-alive pings to clients.
* @param contextExtractor The context extractor to use for extracting MCP transport
* context from HTTP requests. Must not be null.
+ * @param securityValidator The security validator for validating HTTP requests.
* @throws IllegalArgumentException if either parameter is null
*/
private WebFluxSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval,
- McpTransportContextExtractor contextExtractor) {
+ McpTransportContextExtractor contextExtractor,
+ ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
@@ -273,6 +283,14 @@ private Mono handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage());
+ }
+
McpTransportContext transportContext = this.contextExtractor.extract(request);
return ServerResponse.ok()
@@ -332,6 +350,14 @@ private Mono handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage());
+ }
+
if (request.queryParam("sessionId").isEmpty()) {
return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
}
@@ -436,6 +462,8 @@ public static class Builder {
private McpTransportContextExtractor contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
/**
* Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP
* messages.
@@ -513,6 +541,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
* configured settings.
@@ -522,7 +562,7 @@ public Builder contextExtractor(McpTransportContextExtractor cont
public WebFluxSseServerTransportProvider build() {
Assert.notNull(messageEndpoint, "Message endpoint must be set");
return new WebFluxSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
- baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor);
+ baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor, securityValidator);
}
}
diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java
index 400be341e..b225ab61b 100644
--- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java
+++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2025-2025 the original author or authors.
+ * Copyright 2025-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -24,6 +24,7 @@
import java.io.IOException;
import java.util.List;
+import java.util.Map;
/**
* Implementation of a WebFlux based {@link McpStatelessServerTransport}.
@@ -46,15 +47,23 @@ public class WebFluxStatelessServerTransport implements McpStatelessServerTransp
private volatile boolean isClosing = false;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint,
- McpTransportContextExtractor contextExtractor) {
+ McpTransportContextExtractor contextExtractor,
+ ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "jsonMapper must not be null");
Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null");
Assert.notNull(contextExtractor, "contextExtractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.mcpEndpoint = mcpEndpoint;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
@@ -96,6 +105,14 @@ private Mono handlePost(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage());
+ }
+
McpTransportContext transportContext = this.contextExtractor.extract(request);
List acceptHeaders = request.headers().asHttpHeaders().getAccept();
@@ -160,6 +177,8 @@ public static class Builder {
private McpTransportContextExtractor contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
private Builder() {
// used by a static method
}
@@ -205,6 +224,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of {@link WebFluxStatelessServerTransport} with the
* configured settings.
@@ -214,7 +245,7 @@ public Builder contextExtractor(McpTransportContextExtractor cont
public WebFluxStatelessServerTransport build() {
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
return new WebFluxStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
- mcpEndpoint, contextExtractor);
+ mcpEndpoint, contextExtractor, securityValidator);
}
}
diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java
index deebfc616..762ee005d 100644
--- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java
+++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2025-2025 the original author or authors.
+ * Copyright 2025-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -36,6 +36,7 @@
import java.io.IOException;
import java.time.Duration;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
@@ -67,17 +68,24 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe
private KeepAliveScheduler keepAliveScheduler;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint,
McpTransportContextExtractor contextExtractor, boolean disallowDelete,
- Duration keepAliveInterval) {
+ Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "JsonMapper must not be null");
Assert.notNull(mcpEndpoint, "Message endpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.mcpEndpoint = mcpEndpoint;
this.contextExtractor = contextExtractor;
this.disallowDelete = disallowDelete;
+ this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
@@ -166,6 +174,14 @@ private Mono handleGet(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage());
+ }
+
McpTransportContext transportContext = this.contextExtractor.extract(request);
return Mono.defer(() -> {
@@ -221,6 +237,14 @@ private Mono handlePost(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage());
+ }
+
McpTransportContext transportContext = this.contextExtractor.extract(request);
List acceptHeaders = request.headers().asHttpHeaders().getAccept();
@@ -310,6 +334,14 @@ private Mono handleDelete(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).bodyValue(e.getMessage());
+ }
+
McpTransportContext transportContext = this.contextExtractor.extract(request);
return Mono.defer(() -> {
@@ -410,6 +442,8 @@ public static class Builder {
private Duration keepAliveInterval;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
private Builder() {
// used by a static method
}
@@ -477,6 +511,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with
* the configured settings.
@@ -487,7 +533,7 @@ public WebFluxStreamableServerTransportProvider build() {
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
return new WebFluxStreamableServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, contextExtractor,
- disallowDelete, keepAliveInterval);
+ disallowDelete, keepAliveInterval, securityValidator);
}
}
diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java
new file mode 100644
index 000000000..06e1286d2
--- /dev/null
+++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/security/WebFluxServerTransportSecurityIntegrationTests.java
@@ -0,0 +1,291 @@
+/*
+ * Copyright 2026-2026 the original author or authors.
+ */
+
+package io.modelcontextprotocol.security;
+
+import java.time.Duration;
+import java.util.stream.Stream;
+
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.McpSyncClient;
+import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
+import io.modelcontextprotocol.json.McpJsonMapper;
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.server.TestUtil;
+import io.modelcontextprotocol.server.transport.DefaultServerTransportSecurityValidator;
+import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
+import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport;
+import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.BeforeParameterizedClassInvocation;
+import org.junit.jupiter.params.Parameter;
+import org.junit.jupiter.params.ParameterizedClass;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import reactor.core.publisher.Mono;
+import reactor.netty.DisposableServer;
+import reactor.netty.http.server.HttpServer;
+
+import org.springframework.http.server.reactive.HttpHandler;
+import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.reactive.function.server.RouterFunction;
+import org.springframework.web.reactive.function.server.RouterFunctions;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.junit.jupiter.api.Named.named;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+/**
+ * Test the header security validation for all transport types.
+ *
+ * @author Daniel Garnier-Moiroux
+ */
+@ParameterizedClass
+@MethodSource("transports")
+public class WebFluxServerTransportSecurityIntegrationTests {
+
+ private static final String DISALLOWED_ORIGIN = "https://malicious.example.com";
+
+ @Parameter
+ private static Transport transport;
+
+ private static DisposableServer httpServer;
+
+ private static String baseUrl;
+
+ @BeforeParameterizedClassInvocation
+ static void createTransportAndStartServer(Transport transport) {
+ var port = TestUtil.findAvailablePort();
+ baseUrl = "http://localhost:" + port;
+ startServer(transport.routerFunction(), port);
+ }
+
+ @AfterAll
+ static void afterAll() {
+ stopServer();
+ }
+
+ private McpSyncClient mcpClient;
+
+ private final TestOriginHeaderExchangeFilterFunction exchangeFilterFunction = new TestOriginHeaderExchangeFilterFunction();
+
+ @BeforeEach
+ void setUp() {
+ mcpClient = transport.createMcpClient(baseUrl, exchangeFilterFunction);
+ }
+
+ @AfterEach
+ void tearDown() {
+ mcpClient.close();
+ }
+
+ @Test
+ void originAllowed() {
+ exchangeFilterFunction.setOriginHeader(baseUrl);
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void noOrigin() {
+ exchangeFilterFunction.setOriginHeader(null);
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void connectOriginNotAllowed() {
+ exchangeFilterFunction.setOriginHeader(DISALLOWED_ORIGIN);
+ assertThatThrownBy(() -> mcpClient.initialize());
+ }
+
+ @Test
+ void messageOriginNotAllowed() {
+ exchangeFilterFunction.setOriginHeader(baseUrl);
+ mcpClient.initialize();
+ exchangeFilterFunction.setOriginHeader(DISALLOWED_ORIGIN);
+ assertThatThrownBy(() -> mcpClient.listTools());
+ }
+
+ // ----------------------------------------------------
+ // Server management
+ // ----------------------------------------------------
+
+ private static void startServer(RouterFunction> routerFunction, int port) {
+ HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction);
+ ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler);
+ httpServer = HttpServer.create().port(port).handle(adapter).bindNow();
+ }
+
+ private static void stopServer() {
+ if (httpServer != null) {
+ httpServer.disposeNow();
+ }
+ }
+
+ // ----------------------------------------------------
+ // Transport servers to test
+ // ----------------------------------------------------
+
+ /**
+ * All transport types we want to test. We use a {@link MethodSource} rather than a
+ * {@link org.junit.jupiter.params.provider.ValueSource} to provide a readable name.
+ */
+ static Stream transports() {
+ //@formatter:off
+ return Stream.of(
+ arguments(named("SSE", new Sse())),
+ arguments(named("Streamable HTTP", new StreamableHttp())),
+ arguments(named("Stateless", new Stateless()))
+ );
+ //@formatter:on
+ }
+
+ /**
+ * Represents a server transport we want to test, and how to create a client for the
+ * resulting MCP Server.
+ */
+ interface Transport {
+
+ McpSyncClient createMcpClient(String baseUrl, TestOriginHeaderExchangeFilterFunction customizer);
+
+ RouterFunction> routerFunction();
+
+ }
+
+ /**
+ * SSE-based transport.
+ */
+ static class Sse implements Transport {
+
+ private final WebFluxSseServerTransportProvider transportProvider;
+
+ public Sse() {
+ transportProvider = WebFluxSseServerTransportProvider.builder()
+ .messageEndpoint("/mcp/message")
+ .securityValidator(
+ DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .build();
+ McpServer.sync(transportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ @Override
+ public McpSyncClient createMcpClient(String baseUrl,
+ TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) {
+ var transport = WebFluxSseClientTransport
+ .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction))
+ .jsonMapper(McpJsonMapper.getDefault())
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Override
+ public RouterFunction> routerFunction() {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+ static class StreamableHttp implements Transport {
+
+ private final WebFluxStreamableServerTransportProvider transportProvider;
+
+ public StreamableHttp() {
+ transportProvider = WebFluxStreamableServerTransportProvider.builder()
+ .securityValidator(
+ DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .build();
+ McpServer.sync(transportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ @Override
+ public McpSyncClient createMcpClient(String baseUrl,
+ TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) {
+ var transport = WebClientStreamableHttpTransport
+ .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction))
+ .jsonMapper(McpJsonMapper.getDefault())
+ .openConnectionOnStartup(true)
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Override
+ public RouterFunction> routerFunction() {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+ static class Stateless implements Transport {
+
+ private final WebFluxStatelessServerTransport transportProvider;
+
+ public Stateless() {
+ transportProvider = WebFluxStatelessServerTransport.builder()
+ .securityValidator(
+ DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build())
+ .build();
+ McpServer.sync(transportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ @Override
+ public McpSyncClient createMcpClient(String baseUrl,
+ TestOriginHeaderExchangeFilterFunction exchangeFilterFunction) {
+ var transport = WebClientStreamableHttpTransport
+ .builder(WebClient.builder().baseUrl(baseUrl).filter(exchangeFilterFunction))
+ .jsonMapper(McpJsonMapper.getDefault())
+ .openConnectionOnStartup(true)
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Override
+ public RouterFunction> routerFunction() {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+ static class TestOriginHeaderExchangeFilterFunction implements ExchangeFilterFunction {
+
+ private String origin = null;
+
+ public void setOriginHeader(String origin) {
+ this.origin = origin;
+ }
+
+ @Override
+ public Mono filter(ClientRequest request, ExchangeFunction next) {
+ var updatedRequest = ClientRequest.from(request).header("origin", this.origin).build();
+ return next.exchange(updatedRequest);
+ }
+
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
index 6c35de56d..7e925a0af 100644
--- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
+++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2024-2024 the original author or authors.
+ * Copyright 2024-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -7,6 +7,7 @@
import java.io.IOException;
import java.time.Duration;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
@@ -118,6 +119,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
private KeepAliveScheduler keepAliveScheduler;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
/**
* Constructs a new WebMvcSseServerTransportProvider instance.
* @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization
@@ -131,22 +137,26 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
* @param keepAliveInterval The interval for sending keep-alive messages to clients.
* @param contextExtractor The contextExtractor to fill in a
* {@link McpTransportContext}.
+ * @param securityValidator The security validator for validating HTTP requests.
* @throws IllegalArgumentException if any parameter is null
*/
private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval,
- McpTransportContextExtractor contextExtractor) {
+ McpTransportContextExtractor contextExtractor,
+ ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "McpJsonMapper must not be null");
Assert.notNull(baseUrl, "Message base URL must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
@@ -255,6 +265,14 @@ private ServerResponse handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).body(e.getMessage());
+ }
+
// Send initial endpoint event
return ServerResponse.sse(sseBuilder -> {
WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sseBuilder);
@@ -313,6 +331,14 @@ private ServerResponse handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).body(e.getMessage());
+ }
+
if (request.param(SESSION_ID).isEmpty()) {
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
}
@@ -474,6 +500,8 @@ public static class Builder {
private McpTransportContextExtractor contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
/**
* Sets the JSON object mapper to use for message serialization/deserialization.
* @param jsonMapper The object mapper to use
@@ -549,6 +577,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of WebMvcSseServerTransportProvider with the configured
* settings.
@@ -560,7 +600,7 @@ public WebMvcSseServerTransportProvider build() {
throw new IllegalStateException("MessageEndpoint must be set");
}
return new WebMvcSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
- baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor);
+ baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor, securityValidator);
}
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java
index 67b5f571c..92a08a8f4 100644
--- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java
+++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2024-2024 the original author or authors.
+ * Copyright 2024-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -24,6 +24,7 @@
import java.io.IOException;
import java.util.List;
+import java.util.Map;
/**
* Implementation of a WebMVC based {@link McpStatelessServerTransport}.
@@ -50,15 +51,23 @@ public class WebMvcStatelessServerTransport implements McpStatelessServerTranspo
private volatile boolean isClosing = false;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint,
- McpTransportContextExtractor contextExtractor) {
+ McpTransportContextExtractor contextExtractor,
+ ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "jsonMapper must not be null");
Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null");
Assert.notNull(contextExtractor, "contextExtractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.mcpEndpoint = mcpEndpoint;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
@@ -100,6 +109,14 @@ private ServerResponse handlePost(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).body(e.getMessage());
+ }
+
McpTransportContext transportContext = this.contextExtractor.extract(request);
List acceptHeaders = request.headers().asHttpHeaders().getAccept();
@@ -179,6 +196,8 @@ public static class Builder {
private McpTransportContextExtractor contextExtractor = (
serverRequest) -> McpTransportContext.EMPTY;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
private Builder() {
// used by a static method
}
@@ -224,6 +243,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of {@link WebMvcStatelessServerTransport} with the
* configured settings.
@@ -233,7 +264,7 @@ public Builder contextExtractor(McpTransportContextExtractor cont
public WebMvcStatelessServerTransport build() {
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
return new WebMvcStatelessServerTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
- mcpEndpoint, contextExtractor);
+ mcpEndpoint, contextExtractor, securityValidator);
}
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java
index f2a58d4d8..7ca76f80b 100644
--- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java
+++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2024-2024 the original author or authors.
+ * Copyright 2024-2026 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;
@@ -7,6 +7,7 @@
import java.io.IOException;
import java.time.Duration;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
@@ -102,6 +103,11 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer
private KeepAliveScheduler keepAliveScheduler;
+ /**
+ * Security validator for validating HTTP requests.
+ */
+ private final ServerTransportSecurityValidator securityValidator;
+
/**
* Constructs a new WebMvcStreamableServerTransportProvider instance.
* @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization
@@ -111,19 +117,26 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer
* @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC
* messages via HTTP. This endpoint will handle GET, POST, and DELETE requests.
* @param disallowDelete Whether to disallow DELETE requests on the endpoint.
+ * @param contextExtractor The context extractor for transport context from the
+ * request.
+ * @param keepAliveInterval The interval for keep-alive pings. If null, no keep-alive
+ * will be scheduled.
+ * @param securityValidator The security validator for validating HTTP requests.
* @throws IllegalArgumentException if any parameter is null
*/
private WebMvcStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint,
boolean disallowDelete, McpTransportContextExtractor contextExtractor,
- Duration keepAliveInterval) {
+ Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) {
Assert.notNull(jsonMapper, "McpJsonMapper must not be null");
Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null");
+ Assert.notNull(securityValidator, "Security validator must not be null");
this.jsonMapper = jsonMapper;
this.mcpEndpoint = mcpEndpoint;
this.disallowDelete = disallowDelete;
this.contextExtractor = contextExtractor;
+ this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
@@ -233,6 +246,14 @@ private ServerResponse handleGet(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).body(e.getMessage());
+ }
+
List acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM");
@@ -315,6 +336,14 @@ private ServerResponse handlePost(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).body(e.getMessage());
+ }
+
List acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)
|| !acceptHeaders.contains(MediaType.APPLICATION_JSON)) {
@@ -427,6 +456,14 @@ private ServerResponse handleDelete(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
+ try {
+ Map> headers = request.headers().asHttpHeaders();
+ this.securityValidator.validateHeaders(headers);
+ }
+ catch (ServerTransportSecurityException e) {
+ return ServerResponse.status(e.getStatusCode()).body(e.getMessage());
+ }
+
if (this.disallowDelete) {
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
}
@@ -609,6 +646,8 @@ public static class Builder {
private Duration keepAliveInterval;
+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
+
/**
* Sets the McpJsonMapper to use for JSON serialization/deserialization of MCP
* messages.
@@ -672,6 +711,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
return this;
}
+ /**
+ * Sets the security validator for validating HTTP requests.
+ * @param securityValidator The security validator to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if securityValidator is null
+ */
+ public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
+ Assert.notNull(securityValidator, "Security validator must not be null");
+ this.securityValidator = securityValidator;
+ return this;
+ }
+
/**
* Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with
* the configured settings.
@@ -682,7 +733,7 @@ public WebMvcStreamableServerTransportProvider build() {
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
return new WebMvcStreamableServerTransportProvider(
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, mcpEndpoint, disallowDelete,
- contextExtractor, keepAliveInterval);
+ contextExtractor, keepAliveInterval, securityValidator);
}
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java
new file mode 100644
index 000000000..9615547d3
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/security/ServerTransportSecurityIntegrationTests.java
@@ -0,0 +1,334 @@
+/*
+ * Copyright 2026-2026 the original author or authors.
+ */
+
+package io.modelcontextprotocol.security;
+
+import java.net.URI;
+import java.net.http.HttpRequest;
+import java.time.Duration;
+import java.util.stream.Stream;
+
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.McpSyncClient;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
+import io.modelcontextprotocol.common.McpTransportContext;
+import io.modelcontextprotocol.json.McpJsonMapper;
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.server.McpStatelessSyncServer;
+import io.modelcontextprotocol.server.McpSyncServer;
+import io.modelcontextprotocol.server.TestUtil;
+import io.modelcontextprotocol.server.TomcatTestUtil;
+import io.modelcontextprotocol.server.TomcatTestUtil.TomcatServer;
+import io.modelcontextprotocol.server.transport.DefaultServerTransportSecurityValidator;
+import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
+import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport;
+import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.BeforeParameterizedClassInvocation;
+import org.junit.jupiter.params.Parameter;
+import org.junit.jupiter.params.ParameterizedClass;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.context.annotation.Import;
+import org.springframework.context.annotation.Scope;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.ServerResponse;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.junit.jupiter.api.Named.named;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+/**
+ * Test the header security validation for all transport types.
+ *
+ * @author Daniel Garnier-Moiroux
+ */
+@ParameterizedClass
+@MethodSource("transports")
+public class ServerTransportSecurityIntegrationTests {
+
+ private static final String DISALLOWED_ORIGIN = "https://malicious.example.com";
+
+ @Parameter
+ private static Class> configClass;
+
+ private static TomcatServer tomcatServer;
+
+ private static String baseUrl;
+
+ @BeforeParameterizedClassInvocation
+ static void createTransportAndStartTomcat(Class> configClass) {
+ var port = TestUtil.findAvailablePort();
+ baseUrl = "http://localhost:" + port;
+ startTomcat(configClass, port);
+ }
+
+ @AfterAll
+ static void afterAll() {
+ stopTomcat();
+ }
+
+ private McpSyncClient mcpClient;
+
+ private TestRequestCustomizer requestCustomizer;
+
+ @BeforeEach
+ void setUp() {
+ mcpClient = tomcatServer.appContext().getBean(McpSyncClient.class);
+ requestCustomizer = tomcatServer.appContext().getBean(TestRequestCustomizer.class);
+ }
+
+ @AfterEach
+ void tearDown() {
+ mcpClient.close();
+ }
+
+ @Test
+ void originAllowed() {
+ requestCustomizer.setOriginHeader(baseUrl);
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void noOrigin() {
+ requestCustomizer.setOriginHeader(null);
+ var result = mcpClient.initialize();
+ var tools = mcpClient.listTools();
+
+ assertThat(result.protocolVersion()).isNotEmpty();
+ assertThat(tools.tools()).isEmpty();
+ }
+
+ @Test
+ void connectOriginNotAllowed() {
+ requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN);
+ assertThatThrownBy(() -> mcpClient.initialize());
+ }
+
+ @Test
+ void messageOriginNotAllowed() {
+ requestCustomizer.setOriginHeader(baseUrl);
+ mcpClient.initialize();
+ requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN);
+ assertThatThrownBy(() -> mcpClient.listTools());
+ }
+
+ // ----------------------------------------------------
+ // Tomcat management
+ // ----------------------------------------------------
+
+ private static void startTomcat(Class> componentClass, int port) {
+ tomcatServer = TomcatTestUtil.createTomcatServer("", port, componentClass);
+ try {
+ tomcatServer.tomcat().start();
+ assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+ }
+
+ private static void stopTomcat() {
+ if (tomcatServer != null) {
+ if (tomcatServer.appContext() != null) {
+ tomcatServer.appContext().close();
+ }
+ if (tomcatServer.tomcat() != null) {
+ try {
+ tomcatServer.tomcat().stop();
+ tomcatServer.tomcat().destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+ }
+
+ // ----------------------------------------------------
+ // Transport servers to test
+ // ----------------------------------------------------
+
+ /**
+ * All transport types we want to test. We use a {@link MethodSource} rather than a
+ * {@link org.junit.jupiter.params.provider.ValueSource} to provide a readable name.
+ */
+ static Stream transports() {
+ //@formatter:off
+ return Stream.of(
+ arguments(named("SSE", SseConfig.class)),
+ arguments(named("Streamable HTTP", StreamableHttpConfig.class)),
+ arguments(named("Stateless", StatelessConfig.class))
+ );
+ //@formatter:on
+ }
+
+ // ----------------------------------------------------
+ // Spring Configuration classes
+ // ----------------------------------------------------
+
+ @Configuration
+ static class CommonConfig {
+
+ @Bean
+ TestRequestCustomizer requestCustomizer() {
+ return new TestRequestCustomizer();
+ }
+
+ @Bean
+ DefaultServerTransportSecurityValidator validator() {
+ return DefaultServerTransportSecurityValidator.builder().allowedOrigin("http://localhost:*").build();
+ }
+
+ }
+
+ @Configuration
+ @EnableWebMvc
+ @Import(CommonConfig.class)
+ static class SseConfig {
+
+ @Bean
+ @Scope("prototype")
+ McpSyncClient createMcpClient(McpSyncHttpClientRequestCustomizer requestCustomizer) {
+ var transport = HttpClientSseClientTransport.builder(baseUrl)
+ .httpRequestCustomizer(requestCustomizer)
+ .jsonMapper(McpJsonMapper.getDefault())
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Bean
+ public WebMvcSseServerTransportProvider webMvcSseServerTransport(
+ DefaultServerTransportSecurityValidator validator) {
+ return WebMvcSseServerTransportProvider.builder()
+ .messageEndpoint("/mcp/message")
+ .securityValidator(validator)
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ @Bean
+ public McpSyncServer mcpServer(WebMvcSseServerTransportProvider transportProvider) {
+ return McpServer.sync(transportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ }
+
+ @Configuration
+ @EnableWebMvc
+ @Import(CommonConfig.class)
+ static class StreamableHttpConfig {
+
+ @Bean
+ @Scope("prototype")
+ McpSyncClient createMcpClient(McpSyncHttpClientRequestCustomizer requestCustomizer) {
+ var transport = HttpClientStreamableHttpTransport.builder(baseUrl)
+ .httpRequestCustomizer(requestCustomizer)
+ .jsonMapper(McpJsonMapper.getDefault())
+ .openConnectionOnStartup(true)
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Bean
+ public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport(
+ DefaultServerTransportSecurityValidator validator) {
+ return WebMvcStreamableServerTransportProvider.builder().securityValidator(validator).build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(
+ WebMvcStreamableServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ @Bean
+ public McpSyncServer mcpServer(WebMvcStreamableServerTransportProvider transportProvider) {
+ return McpServer.sync(transportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ }
+
+ @Configuration
+ @EnableWebMvc
+ @Import(CommonConfig.class)
+ static class StatelessConfig {
+
+ @Bean
+ @Scope("prototype")
+ McpSyncClient createMcpClient(McpSyncHttpClientRequestCustomizer requestCustomizer) {
+ var transport = HttpClientStreamableHttpTransport.builder(baseUrl)
+ .httpRequestCustomizer(requestCustomizer)
+ .jsonMapper(McpJsonMapper.getDefault())
+ .openConnectionOnStartup(true)
+ .build();
+ return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build();
+ }
+
+ @Bean
+ public WebMvcStatelessServerTransport webMvcStatelessServerTransport(
+ DefaultServerTransportSecurityValidator validator) {
+ return WebMvcStatelessServerTransport.builder().securityValidator(validator).build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ @Bean
+ public McpStatelessSyncServer mcpStatelessServer(WebMvcStatelessServerTransport transportProvider) {
+ return McpServer.sync(transportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
+ .build();
+ }
+
+ }
+
+ static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer {
+
+ private String originHeader = null;
+
+ @Override
+ public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body,
+ McpTransportContext context) {
+ if (originHeader != null) {
+ builder.header("Origin", originHeader);
+ }
+ }
+
+ public void setOriginHeader(String originHeader) {
+ this.originHeader = originHeader;
+ }
+
+ }
+
+}