Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ class TransportSecuritySettings(BaseModel):
allowed_hosts: list[str] = Field(default_factory=list)
"""List of allowed Host header values.

Supports:
- Exact match: ``example.com``, ``127.0.0.1:8080``
- Wildcard port: ``example.com:*`` matches ``example.com`` with any port
- Subdomain wildcard: ``*.mysite.com`` matches ``mysite.com`` and any subdomain
(e.g. ``app.mysite.com``, ``api.mysite.com``). Optionally use ``*.mysite.com:*``
to also allow any port.

Only applies when `enable_dns_rebinding_protection` is `True`.
"""

Expand All @@ -40,6 +47,15 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)

def _hostname_from_host(self, host: str) -> str:
"""Extract hostname from Host header (strip optional port)."""
if host.startswith("["):
idx = host.find("]:")
if idx != -1:
return host[: idx + 1]
return host
return host.split(":", 1)[0]

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
"""Validate the Host header against allowed values."""
if not host:
Expand All @@ -50,15 +66,27 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
if host in self.settings.allowed_hosts:
return True

# Check wildcard port patterns
# Check wildcard port patterns (e.g. example.com:*)
for allowed in self.settings.allowed_hosts:
if allowed.endswith(":*"):
# Extract base host from pattern
base_host = allowed[:-2]
# Check if the actual host starts with base host and has a port
# Subdomain pattern *.domain.com:* is handled below; skip here
if base_host.startswith("*."):
continue
if host.startswith(base_host + ":"):
return True

# Check subdomain wildcard patterns (e.g. *.mysite.com or *.mysite.com:*)
hostname = self._hostname_from_host(host)
for allowed in self.settings.allowed_hosts:
if allowed.startswith("*."):
pattern = allowed[:-2] if allowed.endswith(":*") else allowed
base_domain = pattern[2:]
if not base_domain:
continue
if hostname == base_domain or hostname.endswith("." + base_domain):
return True

logger.warning(f"Invalid Host header: {host}")
return False

Expand Down
53 changes: 53 additions & 0 deletions tests/server/test_sse_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,59 @@ async def test_sse_security_wildcard_ports(server_port: int):
process.join()


@pytest.mark.anyio
async def test_sse_security_ipv6_host_header(server_port: int):
"""Test SSE with IPv6 Host header ([::1] and [::1]:port) to cover _hostname_from_host."""
settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=["127.0.0.1:*", "[::1]:*", "[::1]"],
allowed_origins=["http://127.0.0.1:*", "http://[::1]:*"],
)
process = start_server_process(server_port, settings)

try:
async with httpx.AsyncClient(timeout=5.0) as client:
async with client.stream(
"GET", f"http://127.0.0.1:{server_port}/sse", headers={"Host": "[::1]:8080"}
) as response:
assert response.status_code == 200
async with client.stream(
"GET", f"http://127.0.0.1:{server_port}/sse", headers={"Host": "[::1]"}
) as response:
assert response.status_code == 200
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_sse_security_subdomain_wildcard_host(server_port: int):
"""Test SSE with *.domain subdomain wildcard in allowed_hosts (issue #2141)."""
settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=["*.mysite.com", "127.0.0.1:*"],
allowed_origins=["http://127.0.0.1:*", "http://app.mysite.com:*"],
)
process = start_server_process(server_port, settings)

try:
# Allowed: subdomain and base domain
for host in ["app.mysite.com", "api.mysite.com", "mysite.com"]:
headers = {"Host": host}
async with httpx.AsyncClient(timeout=5.0) as client:
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
assert response.status_code == 200, f"Host {host} should be allowed"

# Rejected: other domain
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers={"Host": "other.com"})
assert response.status_code == 421
assert response.text == "Invalid Host header"
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_sse_security_post_valid_content_type(server_port: int):
"""Test POST endpoint with valid Content-Type headers."""
Expand Down
41 changes: 41 additions & 0 deletions tests/server/test_streamable_http_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,47 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int):
process.join()


@pytest.mark.anyio
async def test_streamable_http_security_subdomain_wildcard_host(server_port: int):
"""Test StreamableHTTP with *.domain subdomain wildcard in allowed_hosts (issue #2141)."""
settings = TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=["*.mysite.com", "127.0.0.1:*"],
allowed_origins=["http://127.0.0.1:*", "http://app.mysite.com:*"],
)
process = start_server_process(server_port, settings)

try:
headers = {
"Host": "app.mysite.com",
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
}
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.post(
f"http://127.0.0.1:{server_port}/",
json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}},
headers=headers,
)
assert response.status_code == 200

async with httpx.AsyncClient() as client:
response = await client.post(
f"http://127.0.0.1:{server_port}/",
json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}},
headers={
"Host": "other.com",
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
},
)
assert response.status_code == 421
assert response.text == "Invalid Host header"
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_streamable_http_security_get_request(server_port: int):
"""Test StreamableHTTP GET request with security."""
Expand Down
21 changes: 21 additions & 0 deletions tests/server/test_transport_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Tests for transport security (DNS rebinding protection)."""

from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings


def test_hostname_from_host_ipv6_with_port():
"""_hostname_from_host strips port from [::1]:port (coverage for lines 52-55)."""
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
assert m._hostname_from_host("[::1]:8080") == "[::1]"


def test_hostname_from_host_ipv6_no_port():
"""_hostname_from_host returns [::1] as-is when no port (coverage for line 56)."""
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
assert m._hostname_from_host("[::1]") == "[::1]"


def test_hostname_from_host_plain_with_port():
"""_hostname_from_host strips port from hostname (coverage for line 57)."""
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
assert m._hostname_from_host("app.mysite.com:8080") == "app.mysite.com"