From 38cbae1bb77d087bb05e924d9a40e80af1b55bd0 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 15:42:50 +0900 Subject: [PATCH 01/15] test: add case for token request with basic auth header This test case only contains client_id and client_secret at Basic Authentication header, assuming that there is no client_id and client_secret at form data. --- tests/server/auth/handlers/test_token.py | 265 +++++++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 tests/server/auth/handlers/test_token.py diff --git a/tests/server/auth/handlers/test_token.py b/tests/server/auth/handlers/test_token.py new file mode 100644 index 000000000..594beb5d9 --- /dev/null +++ b/tests/server/auth/handlers/test_token.py @@ -0,0 +1,265 @@ +""" +Tests for the TokenHandler. +""" + +import base64 +import hashlib +import time +from typing import Any, cast + +import pytest +from pydantic import AnyUrl +from starlette.requests import Request +from starlette.types import Message, Scope + +from mcp.server.auth.handlers.token import TokenHandler +from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + OAuthAuthorizationServerProvider, + RefreshToken, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + + +class MockOAuthProvider: + """Mock OAuth provider for testing.""" + + def __init__(self): + self.clients: dict[str, OAuthClientInformationFull] = {} + self.authorization_codes: dict[str, AuthorizationCode] = {} + self.refresh_tokens: dict[str, RefreshToken] = {} + self.access_tokens: dict[str, AccessToken] = {} + + def add_client(self, client: OAuthClientInformationFull) -> None: + """Add a client to the provider.""" + if client.client_id: + self.clients[client.client_id] = client + + def add_authorization_code(self, code: str, auth_code: AuthorizationCode) -> None: + """Add an authorization code.""" + self.authorization_codes[code] = auth_code + + def add_refresh_token(self, token: str, refresh_token: RefreshToken) -> None: + """Add a refresh token.""" + self.refresh_tokens[token] = refresh_token + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """Get client by ID.""" + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + """Register a client (not used in these tests).""" + pass # pragma: no cover + + async def authorize(self, client: OAuthClientInformationFull, params: Any) -> str: + """Authorize a client (not used in these tests).""" + return "" # pragma: no cover + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + """Load authorization code.""" + return self.authorization_codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Exchange authorization code for tokens.""" + return OAuthToken( + access_token="mock_access_token", + token_type="Bearer", + expires_in=3600, + refresh_token="mock_refresh_token", + ) + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + """Load refresh token.""" + return self.refresh_tokens.get(refresh_token) + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + """Exchange refresh token for new tokens.""" + return OAuthToken( + access_token="mock_new_access_token", + token_type="Bearer", + expires_in=3600, + refresh_token="mock_new_refresh_token", + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + """Load an access token.""" + return self.access_tokens.get(token) # pragma: no cover + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + """Revoke a token (not used in these tests).""" + pass # pragma: no cover + + +class MockClientAuthenticator: + """Mock client authenticator for testing.""" + + def __init__(self): + self.should_fail = False + self.client_to_return: OAuthClientInformationFull | None = None + + async def authenticate_request(self, request: Request) -> OAuthClientInformationFull: + """Authenticate a client request.""" + if self.should_fail: + raise AuthenticationError("Authentication failed") + if self.client_to_return is None: + raise AuthenticationError("No client configured") + return self.client_to_return + + +def create_mock_request(form_data: dict[str, str], headers: dict[str, str] | None = None) -> Request: + """Create a mock Starlette Request with form data and headers.""" + raw_headers: list[tuple[bytes, bytes]] = [] + if headers: + for key, value in headers.items(): + raw_headers.append((key.lower().encode(), value.encode())) + + raw_headers.append((b"content-type", b"application/x-www-form-urlencoded")) + + scope: Scope = { + "type": "http", + "method": "POST", + "headers": raw_headers, + } + + # Create a simple receive callable that returns form data + messages: list[Message] = [] + + # Encode form data + encoded_body = "&".join(f"{k}={v}" for k, v in form_data.items()).encode() + messages.append( + { + "type": "http.request", + "body": encoded_body, + } + ) + + async def receive() -> Message: + if messages: + return messages.pop(0) + return {"type": "http.disconnect"} + + request = Request(scope, receive) + return request + + +def generate_code_verifier() -> str: + """Generate a PKCE code verifier.""" + return "test_code_verifier_with_sufficient_length_for_pkce_validation" + + +def generate_code_challenge(verifier: str) -> str: + """Generate a PKCE code challenge from a verifier.""" + sha256 = hashlib.sha256(verifier.encode()).digest() + return base64.urlsafe_b64encode(sha256).decode().rstrip("=") + + +@pytest.fixture +def mock_oauth_provider() -> OAuthAuthorizationServerProvider[Any, Any, Any]: + """Create a mock OAuth provider.""" + return cast(OAuthAuthorizationServerProvider[Any, Any, Any], MockOAuthProvider()) + + +@pytest.fixture +def mock_client_authenticator() -> ClientAuthenticator: + """Create a mock client authenticator.""" + return cast(ClientAuthenticator, MockClientAuthenticator()) + + +@pytest.fixture +def test_client() -> OAuthClientInformationFull: + """Create a test client.""" + return OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("https://example.com/callback")], + token_endpoint_auth_method="client_secret_basic", + grant_types=["authorization_code", "refresh_token"], + ) + + +@pytest.mark.anyio +class TestTokenHandlerAuthBasic: + """Tests for TokenHandler with Auth Basic header.""" + + async def test_auth_basic_without_form_client_credentials( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + test_client: OAuthClientInformationFull, + ): + """Test token request with Auth Basic header but no client_id/client_secret in form data. + + This test validates the scenario where: + - The client uses HTTP Basic authentication (Authorization: Basic header) + - The form data does NOT include client_id or client_secret fields + - The handler should response correctly + + Note: This test may fail if the current implementation does not properly + handle the case where client_id is missing from form_data. + """ + # Setup provider + provider = cast(MockOAuthProvider, mock_oauth_provider) + provider.add_client(test_client) + # Create REAL authenticator (not mock) to test actual behavior + authenticator = ClientAuthenticator(provider=provider) + + # Create handler + handler = TokenHandler(provider=provider, client_authenticator=authenticator) + + # Generate PKCE values + code_verifier = generate_code_verifier() + code_challenge = generate_code_challenge(code_verifier) + + # Add authorization code to provider + auth_code = AuthorizationCode( + code="test_auth_code", + scopes=["read", "write"], + expires_at=time.time() + 600, # 10 minutes from now + client_id="test_client", + code_challenge=code_challenge, + redirect_uri=AnyUrl("https://example.com/callback"), + redirect_uri_provided_explicitly=True, + ) + provider.add_authorization_code("test_auth_code", auth_code) + + # Create Basic auth header + credentials = f"{test_client.client_id}:{test_client.client_secret}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + auth_header = f"Basic {encoded_credentials}" + + # Create form data WITHOUT client_id and client_secret + form_data = { + "grant_type": "authorization_code", + "code": "test_auth_code", + "redirect_uri": "https://example.com/callback", + "code_verifier": code_verifier, + } + + # Create request with Auth Basic header + request = create_mock_request(form_data, headers={"Authorization": auth_header}) + + # Execute the handler + response = await handler.handle(request) + + + # Validate the response + # Note: This test may fail if client_id is not extracted from the Basic auth header + # or form_data, since the handler expects client_id in the form_data + assert response is not None + + if response.status_code != 200: + # If not successful, print the response body for debugging + body_bytes = bytes(response.body) + body = body_bytes.decode() + pytest.fail(f"Handler response error: {body}") + From 185d2ad9df5914fa1bcf783fb987aa890e766fe4 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 18:58:22 +0900 Subject: [PATCH 02/15] test: add case for basic auth only --- .../fastmcp/auth/test_auth_integration.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 953d59aa1..e9b8f45c4 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1325,6 +1325,51 @@ async def test_basic_auth_client_id_mismatch_fails( # RFC 6749: authentication failures return "invalid_client" assert error_response["error"] == "invalid_client" assert "Client ID mismatch" in error_response["error_description"] + + @pytest.mark.anyio + async def test_basic_auth_without_client_id_at_body( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that Basic auth works even if client_id is missing from body.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Basic Auth Client", + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + credentials = f"{client_info['client_id']}:{client_info['client_secret']}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {encoded_credentials}"}, + data={ + "grant_type": "authorization_code", + # client_id omitted from body + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 200 + token_response = response.json() + assert "access_token" in token_response @pytest.mark.anyio async def test_none_auth_method_public_client( From f3d84f7e4176482661bcd424372f2307d66f1379 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 18:58:48 +0900 Subject: [PATCH 03/15] fix: resolve client_id from client_info --- src/mcp/server/auth/handlers/token.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 7e8294ce6..d86494b8c 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -19,7 +19,7 @@ class AuthorizationCodeRequest(BaseModel): grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize") - client_id: str + client_id: str | None = Field(None, description="If none, client_id must be provided via basic auth header") # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 @@ -131,7 +131,7 @@ async def handle(self, request: Request): match token_request: case AuthorizationCodeRequest(): auth_code = await self.provider.load_authorization_code(client_info, token_request.code) - if auth_code is None or auth_code.client_id != token_request.client_id: + if auth_code is None or auth_code.client_id != client_info.client_id: # if code belongs to different client, pretend it doesn't exist return self.response( TokenErrorResponse( @@ -197,7 +197,7 @@ async def handle(self, request: Request): case RefreshTokenRequest(): # pragma: no cover refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) - if refresh_token is None or refresh_token.client_id != token_request.client_id: + if refresh_token is None or refresh_token.client_id != client_info.client_id: # if token belongs to different client, pretend it doesn't exist return self.response( TokenErrorResponse( From cc4a70ff32836b71851a3c8dfac32da47ed0caff Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 18:59:58 +0900 Subject: [PATCH 04/15] feat: add ClientCredentials This data class contains creds plus auth method --- src/mcp/server/auth/middleware/client_auth.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 6126c6e4f..c66db3b5c 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -2,7 +2,8 @@ import binascii import hmac import time -from typing import Any +from dataclasses import dataclass +from typing import Any, Literal from urllib.parse import unquote from starlette.requests import Request @@ -16,6 +17,13 @@ def __init__(self, message: str): self.message = message # pragma: no cover +@dataclass +class ClientCredentials: + auth_method: Literal["client_secret_basic", "client_secret_post"] + client_id: str + client_secret: str | None = None + + class ClientAuthenticator: """ ClientAuthenticator is a callable which validates requests from a client From ea1f8c7f70e1d119a46c7104657eff0fad788507 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 19:01:07 +0900 Subject: [PATCH 05/15] feat: seperate credential extraction into a method --- src/mcp/server/auth/middleware/client_auth.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index c66db3b5c..0b4772b3f 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -121,3 +121,49 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation raise AuthenticationError("Client secret has expired") # pragma: no cover return client + + async def _get_credentials(self, request: Request) -> ClientCredentials: + """ + Extract client credentials from request, either from form data or Basic auth header. + + Basic auth header takes precedence over form data. + + Args: + request: The HTTP request containing client credentials + Returns: + The extracted client credentials + """ + # First, check for Basic auth header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Basic "): + try: + encoded_credentials = auth_header[6:] # Remove "Basic " prefix + decoded = base64.b64decode(encoded_credentials).decode("utf-8") + if ":" not in decoded: + raise ValueError("Invalid Basic auth format") + client_id, client_secret = decoded.split(":", 1) + + # URL-decode the client_id per RFC 6749 Section 2.3.1 + client_id = unquote(client_id) + client_secret = unquote(client_secret) + return ClientCredentials( + auth_method="client_secret_basic", + client_id=client_id, + client_secret=client_secret, + ) + except (ValueError, UnicodeDecodeError, binascii.Error): + raise AuthenticationError("Invalid Basic authentication header") + + # If not, check for client_id and client_secret in form data + form_data = await request.form() + client_id = form_data.get("client_id") + if not client_id: + raise AuthenticationError("Missing client_id") + + raw_client_secret = form_data.get("client_secret") + client_secret = str(raw_client_secret) if isinstance(raw_client_secret, str) else None + return ClientCredentials( + auth_method="client_secret_post", + client_id=str(client_id), + client_secret=client_secret, + ) From 85ce44b8ee8cead0b7de63e6eec25923d74fc8de Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 19:01:43 +0900 Subject: [PATCH 06/15] fix: retrieve client_id first then compare auth method Retrieve client_id from auth header or the body, then retrieve client via that client_id. After that, compare auth method. --- src/mcp/server/auth/middleware/client_auth.py | 58 +++++-------------- 1 file changed, 15 insertions(+), 43 deletions(-) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 0b4772b3f..d1112caf1 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -60,61 +60,33 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation Raises: AuthenticationError: If authentication fails """ - form_data = await request.form() - client_id = form_data.get("client_id") - if not client_id: - raise AuthenticationError("Missing client_id") - - client = await self.provider.get_client(str(client_id)) + client_credentials = await self._get_credentials(request) + client = await self.provider.get_client(str(client_credentials.client_id)) if not client: raise AuthenticationError("Invalid client_id") # pragma: no cover - request_client_secret: str | None = None - auth_header = request.headers.get("Authorization", "") - - if client.token_endpoint_auth_method == "client_secret_basic": - if not auth_header.startswith("Basic "): - raise AuthenticationError("Missing or invalid Basic authentication in Authorization header") - - try: - encoded_credentials = auth_header[6:] # Remove "Basic " prefix - decoded = base64.b64decode(encoded_credentials).decode("utf-8") - if ":" not in decoded: - raise ValueError("Invalid Basic auth format") - basic_client_id, request_client_secret = decoded.split(":", 1) - - # URL-decode both parts per RFC 6749 Section 2.3.1 - basic_client_id = unquote(basic_client_id) - request_client_secret = unquote(request_client_secret) - - if basic_client_id != client_id: - raise AuthenticationError("Client ID mismatch in Basic auth") - except (ValueError, UnicodeDecodeError, binascii.Error): - raise AuthenticationError("Invalid Basic authentication header") - - elif client.token_endpoint_auth_method == "client_secret_post": - raw_form_data = form_data.get("client_secret") - # form_data.get() can return a UploadFile or None, so we need to check if it's a string - if isinstance(raw_form_data, str): - request_client_secret = str(raw_form_data) - - elif client.token_endpoint_auth_method == "none": - request_client_secret = None - else: - raise AuthenticationError( # pragma: no cover - f"Unsupported auth method: {client.token_endpoint_auth_method}" - ) + match client.token_endpoint_auth_method: + case "client_secret_basic": + if client_credentials.auth_method != "client_secret_basic": + raise AuthenticationError(f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}") + case "client_secret_post": + if client_credentials.auth_method != "client_secret_post": + raise AuthenticationError(f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}") + case "none": + pass + case _: + raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") # pragma: no cover # If client from the store expects a secret, validate that the request provides # that secret if client.client_secret: # pragma: no branch - if not request_client_secret: + if not client_credentials.client_secret: raise AuthenticationError("Client secret is required") # pragma: no cover # hmac.compare_digest requires that both arguments are either bytes or a `str` containing # only ASCII characters. Since we do not control `request_client_secret`, we encode both # arguments to bytes. - if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()): + if not hmac.compare_digest(client.client_secret.encode(), client_credentials.client_secret.encode()): raise AuthenticationError("Invalid client_secret") # pragma: no cover if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()): From 57886929351a5e798a65d8c4338f916bfff6c6c4 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 19:03:07 +0900 Subject: [PATCH 07/15] test: update unit test --- .../fastmcp/auth/test_auth_integration.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index e9b8f45c4..d0eb436f6 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1117,7 +1117,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails( ) # Try to use Basic auth when client_secret_post is registered (without secret in body) - # This should fail because the secret is missing from the expected location + # This should fail despite that credentials are provided via Basic auth, because the method is wrong credentials = f"{client_info['client_id']}:{client_info['client_secret']}" encoded_credentials = base64.b64encode(credentials.encode()).decode() @@ -1138,7 +1138,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails( error_response = response.json() # RFC 6749: authentication failures return "invalid_client" assert error_response["error"] == "invalid_client" - assert "Client secret is required" in error_response["error_description"] + assert "Expected client_secret_post authentication method" in error_response["error_description"] @pytest.mark.anyio async def test_basic_auth_without_header_fails( @@ -1183,7 +1183,7 @@ async def test_basic_auth_without_header_fails( error_response = response.json() # RFC 6749: authentication failures return "invalid_client" assert error_response["error"] == "invalid_client" - assert "Missing or invalid Basic authentication" in error_response["error_description"] + assert "Expected client_secret_basic authentication method" in error_response["error_description"] @pytest.mark.anyio async def test_basic_auth_invalid_base64_fails( @@ -1279,10 +1279,10 @@ async def test_basic_auth_no_colon_fails( assert "Invalid Basic authentication header" in error_response["error_description"] @pytest.mark.anyio - async def test_basic_auth_client_id_mismatch_fails( + async def test_basic_auth_takes_precedence( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] ): - """Test that client_id mismatch between body and Basic auth fails.""" + """Test that even client_id at body is invalid, Basic auth passes because of the priority.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Basic Auth Client", @@ -1308,23 +1308,21 @@ async def test_basic_auth_client_id_mismatch_fails( # Send different client_id in Basic auth header import base64 - wrong_creds = base64.b64encode(f"wrong-client-id:{client_info['client_secret']}".encode()).decode() + creds = base64.b64encode(f"{client_info['client_id']}:{client_info['client_secret']}".encode()).decode() response = await test_client.post( "/token", - headers={"Authorization": f"Basic {wrong_creds}"}, + headers={"Authorization": f"Basic {creds}"}, data={ "grant_type": "authorization_code", - "client_id": client_info["client_id"], # Correct client_id in body + "client_id": "wrong-client-id", # Wrong client_id in body "code": auth_code, "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": "https://client.example.com/callback", }, ) - assert response.status_code == 401 - error_response = response.json() - # RFC 6749: authentication failures return "invalid_client" - assert error_response["error"] == "invalid_client" - assert "Client ID mismatch" in error_response["error_description"] + + # Header takes precedence, so this should succeed + assert response.status_code == 200 @pytest.mark.anyio async def test_basic_auth_without_client_id_at_body( From c7e46f339f8346c8db85513dbd43cbd2e07eed92 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 19:06:06 +0900 Subject: [PATCH 08/15] test: rm duplicated test --- tests/server/auth/handlers/test_token.py | 265 ----------------------- 1 file changed, 265 deletions(-) delete mode 100644 tests/server/auth/handlers/test_token.py diff --git a/tests/server/auth/handlers/test_token.py b/tests/server/auth/handlers/test_token.py deleted file mode 100644 index 594beb5d9..000000000 --- a/tests/server/auth/handlers/test_token.py +++ /dev/null @@ -1,265 +0,0 @@ -""" -Tests for the TokenHandler. -""" - -import base64 -import hashlib -import time -from typing import Any, cast - -import pytest -from pydantic import AnyUrl -from starlette.requests import Request -from starlette.types import Message, Scope - -from mcp.server.auth.handlers.token import TokenHandler -from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator -from mcp.server.auth.provider import ( - AccessToken, - AuthorizationCode, - OAuthAuthorizationServerProvider, - RefreshToken, -) -from mcp.shared.auth import OAuthClientInformationFull, OAuthToken - - -class MockOAuthProvider: - """Mock OAuth provider for testing.""" - - def __init__(self): - self.clients: dict[str, OAuthClientInformationFull] = {} - self.authorization_codes: dict[str, AuthorizationCode] = {} - self.refresh_tokens: dict[str, RefreshToken] = {} - self.access_tokens: dict[str, AccessToken] = {} - - def add_client(self, client: OAuthClientInformationFull) -> None: - """Add a client to the provider.""" - if client.client_id: - self.clients[client.client_id] = client - - def add_authorization_code(self, code: str, auth_code: AuthorizationCode) -> None: - """Add an authorization code.""" - self.authorization_codes[code] = auth_code - - def add_refresh_token(self, token: str, refresh_token: RefreshToken) -> None: - """Add a refresh token.""" - self.refresh_tokens[token] = refresh_token - - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: - """Get client by ID.""" - return self.clients.get(client_id) - - async def register_client(self, client_info: OAuthClientInformationFull) -> None: - """Register a client (not used in these tests).""" - pass # pragma: no cover - - async def authorize(self, client: OAuthClientInformationFull, params: Any) -> str: - """Authorize a client (not used in these tests).""" - return "" # pragma: no cover - - async def load_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: str - ) -> AuthorizationCode | None: - """Load authorization code.""" - return self.authorization_codes.get(authorization_code) - - async def exchange_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode - ) -> OAuthToken: - """Exchange authorization code for tokens.""" - return OAuthToken( - access_token="mock_access_token", - token_type="Bearer", - expires_in=3600, - refresh_token="mock_refresh_token", - ) - - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: - """Load refresh token.""" - return self.refresh_tokens.get(refresh_token) - - async def exchange_refresh_token( - self, - client: OAuthClientInformationFull, - refresh_token: RefreshToken, - scopes: list[str], - ) -> OAuthToken: - """Exchange refresh token for new tokens.""" - return OAuthToken( - access_token="mock_new_access_token", - token_type="Bearer", - expires_in=3600, - refresh_token="mock_new_refresh_token", - ) - - async def load_access_token(self, token: str) -> AccessToken | None: - """Load an access token.""" - return self.access_tokens.get(token) # pragma: no cover - - async def revoke_token(self, token: AccessToken | RefreshToken) -> None: - """Revoke a token (not used in these tests).""" - pass # pragma: no cover - - -class MockClientAuthenticator: - """Mock client authenticator for testing.""" - - def __init__(self): - self.should_fail = False - self.client_to_return: OAuthClientInformationFull | None = None - - async def authenticate_request(self, request: Request) -> OAuthClientInformationFull: - """Authenticate a client request.""" - if self.should_fail: - raise AuthenticationError("Authentication failed") - if self.client_to_return is None: - raise AuthenticationError("No client configured") - return self.client_to_return - - -def create_mock_request(form_data: dict[str, str], headers: dict[str, str] | None = None) -> Request: - """Create a mock Starlette Request with form data and headers.""" - raw_headers: list[tuple[bytes, bytes]] = [] - if headers: - for key, value in headers.items(): - raw_headers.append((key.lower().encode(), value.encode())) - - raw_headers.append((b"content-type", b"application/x-www-form-urlencoded")) - - scope: Scope = { - "type": "http", - "method": "POST", - "headers": raw_headers, - } - - # Create a simple receive callable that returns form data - messages: list[Message] = [] - - # Encode form data - encoded_body = "&".join(f"{k}={v}" for k, v in form_data.items()).encode() - messages.append( - { - "type": "http.request", - "body": encoded_body, - } - ) - - async def receive() -> Message: - if messages: - return messages.pop(0) - return {"type": "http.disconnect"} - - request = Request(scope, receive) - return request - - -def generate_code_verifier() -> str: - """Generate a PKCE code verifier.""" - return "test_code_verifier_with_sufficient_length_for_pkce_validation" - - -def generate_code_challenge(verifier: str) -> str: - """Generate a PKCE code challenge from a verifier.""" - sha256 = hashlib.sha256(verifier.encode()).digest() - return base64.urlsafe_b64encode(sha256).decode().rstrip("=") - - -@pytest.fixture -def mock_oauth_provider() -> OAuthAuthorizationServerProvider[Any, Any, Any]: - """Create a mock OAuth provider.""" - return cast(OAuthAuthorizationServerProvider[Any, Any, Any], MockOAuthProvider()) - - -@pytest.fixture -def mock_client_authenticator() -> ClientAuthenticator: - """Create a mock client authenticator.""" - return cast(ClientAuthenticator, MockClientAuthenticator()) - - -@pytest.fixture -def test_client() -> OAuthClientInformationFull: - """Create a test client.""" - return OAuthClientInformationFull( - client_id="test_client", - client_secret="test_secret", - redirect_uris=[AnyUrl("https://example.com/callback")], - token_endpoint_auth_method="client_secret_basic", - grant_types=["authorization_code", "refresh_token"], - ) - - -@pytest.mark.anyio -class TestTokenHandlerAuthBasic: - """Tests for TokenHandler with Auth Basic header.""" - - async def test_auth_basic_without_form_client_credentials( - self, - mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], - test_client: OAuthClientInformationFull, - ): - """Test token request with Auth Basic header but no client_id/client_secret in form data. - - This test validates the scenario where: - - The client uses HTTP Basic authentication (Authorization: Basic header) - - The form data does NOT include client_id or client_secret fields - - The handler should response correctly - - Note: This test may fail if the current implementation does not properly - handle the case where client_id is missing from form_data. - """ - # Setup provider - provider = cast(MockOAuthProvider, mock_oauth_provider) - provider.add_client(test_client) - # Create REAL authenticator (not mock) to test actual behavior - authenticator = ClientAuthenticator(provider=provider) - - # Create handler - handler = TokenHandler(provider=provider, client_authenticator=authenticator) - - # Generate PKCE values - code_verifier = generate_code_verifier() - code_challenge = generate_code_challenge(code_verifier) - - # Add authorization code to provider - auth_code = AuthorizationCode( - code="test_auth_code", - scopes=["read", "write"], - expires_at=time.time() + 600, # 10 minutes from now - client_id="test_client", - code_challenge=code_challenge, - redirect_uri=AnyUrl("https://example.com/callback"), - redirect_uri_provided_explicitly=True, - ) - provider.add_authorization_code("test_auth_code", auth_code) - - # Create Basic auth header - credentials = f"{test_client.client_id}:{test_client.client_secret}" - encoded_credentials = base64.b64encode(credentials.encode()).decode() - auth_header = f"Basic {encoded_credentials}" - - # Create form data WITHOUT client_id and client_secret - form_data = { - "grant_type": "authorization_code", - "code": "test_auth_code", - "redirect_uri": "https://example.com/callback", - "code_verifier": code_verifier, - } - - # Create request with Auth Basic header - request = create_mock_request(form_data, headers={"Authorization": auth_header}) - - # Execute the handler - response = await handler.handle(request) - - - # Validate the response - # Note: This test may fail if client_id is not extracted from the Basic auth header - # or form_data, since the handler expects client_id in the form_data - assert response is not None - - if response.status_code != 200: - # If not successful, print the response body for debugging - body_bytes = bytes(response.body) - body = body_bytes.decode() - pytest.fail(f"Handler response error: {body}") - From 4f789a7f64bab2897010fd6acbbc3513c6f4f633 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 19:13:20 +0900 Subject: [PATCH 09/15] chore: address ruff --- src/mcp/server/auth/middleware/client_auth.py | 18 ++++++++++++------ .../fastmcp/auth/test_auth_integration.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index d1112caf1..6d1c256b1 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -68,14 +68,20 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation match client.token_endpoint_auth_method: case "client_secret_basic": if client_credentials.auth_method != "client_secret_basic": - raise AuthenticationError(f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}") + raise AuthenticationError( + f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}" + ) case "client_secret_post": if client_credentials.auth_method != "client_secret_post": - raise AuthenticationError(f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}") + raise AuthenticationError( + f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}" + ) case "none": pass case _: - raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") # pragma: no cover + raise AuthenticationError( # pragma: no cover + f"Unsupported auth method: {client.token_endpoint_auth_method}" + ) # If client from the store expects a secret, validate that the request provides # that secret @@ -97,9 +103,9 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation async def _get_credentials(self, request: Request) -> ClientCredentials: """ Extract client credentials from request, either from form data or Basic auth header. - + Basic auth header takes precedence over form data. - + Args: request: The HTTP request containing client credentials Returns: @@ -131,7 +137,7 @@ async def _get_credentials(self, request: Request) -> ClientCredentials: client_id = form_data.get("client_id") if not client_id: raise AuthenticationError("Missing client_id") - + raw_client_secret = form_data.get("client_secret") client_secret = str(raw_client_secret) if isinstance(raw_client_secret, str) else None return ClientCredentials( diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index d0eb436f6..6faeda4ce 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1323,7 +1323,7 @@ async def test_basic_auth_takes_precedence( # Header takes precedence, so this should succeed assert response.status_code == 200 - + @pytest.mark.anyio async def test_basic_auth_without_client_id_at_body( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] From 368c2a882468da4750fe13c18745086f800cfa0d Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 20:13:13 +0900 Subject: [PATCH 10/15] test: add case for refresh token --- .../fastmcp/auth/test_auth_integration.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 6faeda4ce..4ce837d26 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1089,7 +1089,7 @@ async def test_client_secret_basic_authentication( assert "access_token" in token_response @pytest.mark.anyio - async def test_wrong_auth_method_without_valid_credentials_fails( + async def test_wrong_auth_method_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] ): """Test that using the wrong authentication method fails when credentials are missing.""" @@ -1368,6 +1368,23 @@ async def test_basic_auth_without_client_id_at_body( assert response.status_code == 200 token_response = response.json() assert "access_token" in token_response + assert "refresh_token" in token_response + + refresh_token = token_response["refresh_token"] + + # Now, use the refresh token without client_id in body + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {encoded_credentials}"}, + data={ + "grant_type": "refresh_token", + # client_id omitted from body + "refresh_token": refresh_token, + }, + ) + assert response.status_code == 200 + new_token_response = response.json() + assert "access_token" in new_token_response @pytest.mark.anyio async def test_none_auth_method_public_client( From c6639a60716fbf97fc1b1941830c053b471a454e Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 20:13:21 +0900 Subject: [PATCH 11/15] fix: apply changes to RefreshTokenRequest --- src/mcp/server/auth/handlers/token.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index d86494b8c..234c01934 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -33,7 +33,7 @@ class RefreshTokenRequest(BaseModel): grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") scope: str | None = Field(None, description="Optional scope parameter") - client_id: str + client_id: str | None = Field(None, description="If none, client_id must be provided via basic auth header") # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None # RFC 8707 resource indicator From b2641a574643b1c0e0492ae2fdf4cfa22dd24bf5 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 20:17:41 +0900 Subject: [PATCH 12/15] refactor: rm redundant str casting --- src/mcp/server/auth/middleware/client_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 6d1c256b1..2c3ad92a0 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -61,7 +61,7 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation AuthenticationError: If authentication fails """ client_credentials = await self._get_credentials(request) - client = await self.provider.get_client(str(client_credentials.client_id)) + client = await self.provider.get_client(client_credentials.client_id) if not client: raise AuthenticationError("Invalid client_id") # pragma: no cover From 64b4446680d7884193355f189843549d5147b58b Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 21:03:39 +0900 Subject: [PATCH 13/15] chore: fix coverage --- src/mcp/server/auth/middleware/client_auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 2c3ad92a0..6fbb9e0ad 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -78,8 +78,8 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation ) case "none": pass - case _: - raise AuthenticationError( # pragma: no cover + case _: # pragma: no cover + raise AuthenticationError( f"Unsupported auth method: {client.token_endpoint_auth_method}" ) From 7d59fa7cd472e74e777a3ea88dfa39b6c5ac7b9c Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 21:06:09 +0900 Subject: [PATCH 14/15] chore: address ruff --- src/mcp/server/auth/middleware/client_auth.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 6fbb9e0ad..75da747af 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -79,9 +79,7 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation case "none": pass case _: # pragma: no cover - raise AuthenticationError( - f"Unsupported auth method: {client.token_endpoint_auth_method}" - ) + raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") # If client from the store expects a secret, validate that the request provides # that secret From d284f6f458dbd6bd4cccf5ba0ae7de8ad3c7cd09 Mon Sep 17 00:00:00 2001 From: Mason Oh Date: Sun, 11 Jan 2026 21:29:04 +0900 Subject: [PATCH 15/15] refactor: use if rather than match `match` causes coverage error on python 3.10. --- src/mcp/server/auth/middleware/client_auth.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 75da747af..48176d8fa 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -62,24 +62,24 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation """ client_credentials = await self._get_credentials(request) client = await self.provider.get_client(client_credentials.client_id) + if not client: raise AuthenticationError("Invalid client_id") # pragma: no cover - match client.token_endpoint_auth_method: - case "client_secret_basic": - if client_credentials.auth_method != "client_secret_basic": - raise AuthenticationError( - f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}" - ) - case "client_secret_post": - if client_credentials.auth_method != "client_secret_post": - raise AuthenticationError( - f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}" - ) - case "none": - pass - case _: # pragma: no cover - raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") + if client.token_endpoint_auth_method == "client_secret_basic": + if client_credentials.auth_method != "client_secret_basic": + raise AuthenticationError( + f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}" + ) + elif client.token_endpoint_auth_method == "client_secret_post": + if client_credentials.auth_method != "client_secret_post": + raise AuthenticationError( + f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}" + ) + elif client.token_endpoint_auth_method == "none": + pass + else: # pragma: no cover + raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") # If client from the store expects a secret, validate that the request provides # that secret