diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 7e8294ce6e..234c019346 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -19,7 +19,7 @@ class AuthorizationCodeRequest(BaseModel): grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize") - client_id: str + client_id: str | None = Field(None, description="If none, client_id must be provided via basic auth header") # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 @@ -33,7 +33,7 @@ class RefreshTokenRequest(BaseModel): grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") scope: str | None = Field(None, description="Optional scope parameter") - client_id: str + client_id: str | None = Field(None, description="If none, client_id must be provided via basic auth header") # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None # RFC 8707 resource indicator @@ -131,7 +131,7 @@ async def handle(self, request: Request): match token_request: case AuthorizationCodeRequest(): auth_code = await self.provider.load_authorization_code(client_info, token_request.code) - if auth_code is None or auth_code.client_id != token_request.client_id: + if auth_code is None or auth_code.client_id != client_info.client_id: # if code belongs to different client, pretend it doesn't exist return self.response( TokenErrorResponse( @@ -197,7 +197,7 @@ async def handle(self, request: Request): case RefreshTokenRequest(): # pragma: no cover refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) - if refresh_token is None or refresh_token.client_id != token_request.client_id: + if refresh_token is None or refresh_token.client_id != client_info.client_id: # if token belongs to different client, pretend it doesn't exist return self.response( TokenErrorResponse( diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 6126c6e4f9..48176d8fa3 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -2,7 +2,8 @@ import binascii import hmac import time -from typing import Any +from dataclasses import dataclass +from typing import Any, Literal from urllib.parse import unquote from starlette.requests import Request @@ -16,6 +17,13 @@ def __init__(self, message: str): self.message = message # pragma: no cover +@dataclass +class ClientCredentials: + auth_method: Literal["client_secret_basic", "client_secret_post"] + client_id: str + client_secret: str | None = None + + class ClientAuthenticator: """ ClientAuthenticator is a callable which validates requests from a client @@ -52,64 +60,86 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation Raises: AuthenticationError: If authentication fails """ - form_data = await request.form() - client_id = form_data.get("client_id") - if not client_id: - raise AuthenticationError("Missing client_id") + client_credentials = await self._get_credentials(request) + client = await self.provider.get_client(client_credentials.client_id) - client = await self.provider.get_client(str(client_id)) if not client: raise AuthenticationError("Invalid client_id") # pragma: no cover - request_client_secret: str | None = None - auth_header = request.headers.get("Authorization", "") - if client.token_endpoint_auth_method == "client_secret_basic": - if not auth_header.startswith("Basic "): - raise AuthenticationError("Missing or invalid Basic authentication in Authorization header") - - try: - encoded_credentials = auth_header[6:] # Remove "Basic " prefix - decoded = base64.b64decode(encoded_credentials).decode("utf-8") - if ":" not in decoded: - raise ValueError("Invalid Basic auth format") - basic_client_id, request_client_secret = decoded.split(":", 1) - - # URL-decode both parts per RFC 6749 Section 2.3.1 - basic_client_id = unquote(basic_client_id) - request_client_secret = unquote(request_client_secret) - - if basic_client_id != client_id: - raise AuthenticationError("Client ID mismatch in Basic auth") - except (ValueError, UnicodeDecodeError, binascii.Error): - raise AuthenticationError("Invalid Basic authentication header") - + if client_credentials.auth_method != "client_secret_basic": + raise AuthenticationError( + f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}" + ) elif client.token_endpoint_auth_method == "client_secret_post": - raw_form_data = form_data.get("client_secret") - # form_data.get() can return a UploadFile or None, so we need to check if it's a string - if isinstance(raw_form_data, str): - request_client_secret = str(raw_form_data) - + if client_credentials.auth_method != "client_secret_post": + raise AuthenticationError( + f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}" + ) elif client.token_endpoint_auth_method == "none": - request_client_secret = None - else: - raise AuthenticationError( # pragma: no cover - f"Unsupported auth method: {client.token_endpoint_auth_method}" - ) + pass + else: # pragma: no cover + raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") # If client from the store expects a secret, validate that the request provides # that secret if client.client_secret: # pragma: no branch - if not request_client_secret: + if not client_credentials.client_secret: raise AuthenticationError("Client secret is required") # pragma: no cover # hmac.compare_digest requires that both arguments are either bytes or a `str` containing # only ASCII characters. Since we do not control `request_client_secret`, we encode both # arguments to bytes. - if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()): + if not hmac.compare_digest(client.client_secret.encode(), client_credentials.client_secret.encode()): raise AuthenticationError("Invalid client_secret") # pragma: no cover if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()): raise AuthenticationError("Client secret has expired") # pragma: no cover return client + + async def _get_credentials(self, request: Request) -> ClientCredentials: + """ + Extract client credentials from request, either from form data or Basic auth header. + + Basic auth header takes precedence over form data. + + Args: + request: The HTTP request containing client credentials + Returns: + The extracted client credentials + """ + # First, check for Basic auth header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Basic "): + try: + encoded_credentials = auth_header[6:] # Remove "Basic " prefix + decoded = base64.b64decode(encoded_credentials).decode("utf-8") + if ":" not in decoded: + raise ValueError("Invalid Basic auth format") + client_id, client_secret = decoded.split(":", 1) + + # URL-decode the client_id per RFC 6749 Section 2.3.1 + client_id = unquote(client_id) + client_secret = unquote(client_secret) + return ClientCredentials( + auth_method="client_secret_basic", + client_id=client_id, + client_secret=client_secret, + ) + except (ValueError, UnicodeDecodeError, binascii.Error): + raise AuthenticationError("Invalid Basic authentication header") + + # If not, check for client_id and client_secret in form data + form_data = await request.form() + client_id = form_data.get("client_id") + if not client_id: + raise AuthenticationError("Missing client_id") + + raw_client_secret = form_data.get("client_secret") + client_secret = str(raw_client_secret) if isinstance(raw_client_secret, str) else None + return ClientCredentials( + auth_method="client_secret_post", + client_id=str(client_id), + client_secret=client_secret, + ) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 953d59aa14..4ce837d268 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1089,7 +1089,7 @@ async def test_client_secret_basic_authentication( assert "access_token" in token_response @pytest.mark.anyio - async def test_wrong_auth_method_without_valid_credentials_fails( + async def test_wrong_auth_method_fails( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] ): """Test that using the wrong authentication method fails when credentials are missing.""" @@ -1117,7 +1117,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails( ) # Try to use Basic auth when client_secret_post is registered (without secret in body) - # This should fail because the secret is missing from the expected location + # This should fail despite that credentials are provided via Basic auth, because the method is wrong credentials = f"{client_info['client_id']}:{client_info['client_secret']}" encoded_credentials = base64.b64encode(credentials.encode()).decode() @@ -1138,7 +1138,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails( error_response = response.json() # RFC 6749: authentication failures return "invalid_client" assert error_response["error"] == "invalid_client" - assert "Client secret is required" in error_response["error_description"] + assert "Expected client_secret_post authentication method" in error_response["error_description"] @pytest.mark.anyio async def test_basic_auth_without_header_fails( @@ -1183,7 +1183,7 @@ async def test_basic_auth_without_header_fails( error_response = response.json() # RFC 6749: authentication failures return "invalid_client" assert error_response["error"] == "invalid_client" - assert "Missing or invalid Basic authentication" in error_response["error_description"] + assert "Expected client_secret_basic authentication method" in error_response["error_description"] @pytest.mark.anyio async def test_basic_auth_invalid_base64_fails( @@ -1279,10 +1279,10 @@ async def test_basic_auth_no_colon_fails( assert "Invalid Basic authentication header" in error_response["error_description"] @pytest.mark.anyio - async def test_basic_auth_client_id_mismatch_fails( + async def test_basic_auth_takes_precedence( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] ): - """Test that client_id mismatch between body and Basic auth fails.""" + """Test that even client_id at body is invalid, Basic auth passes because of the priority.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Basic Auth Client", @@ -1308,23 +1308,83 @@ async def test_basic_auth_client_id_mismatch_fails( # Send different client_id in Basic auth header import base64 - wrong_creds = base64.b64encode(f"wrong-client-id:{client_info['client_secret']}".encode()).decode() + creds = base64.b64encode(f"{client_info['client_id']}:{client_info['client_secret']}".encode()).decode() response = await test_client.post( "/token", - headers={"Authorization": f"Basic {wrong_creds}"}, + headers={"Authorization": f"Basic {creds}"}, data={ "grant_type": "authorization_code", - "client_id": client_info["client_id"], # Correct client_id in body + "client_id": "wrong-client-id", # Wrong client_id in body "code": auth_code, "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": "https://client.example.com/callback", }, ) - assert response.status_code == 401 - error_response = response.json() - # RFC 6749: authentication failures return "invalid_client" - assert error_response["error"] == "invalid_client" - assert "Client ID mismatch" in error_response["error_description"] + + # Header takes precedence, so this should succeed + assert response.status_code == 200 + + @pytest.mark.anyio + async def test_basic_auth_without_client_id_at_body( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that Basic auth works even if client_id is missing from body.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Basic Auth Client", + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + credentials = f"{client_info['client_id']}:{client_info['client_secret']}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {encoded_credentials}"}, + data={ + "grant_type": "authorization_code", + # client_id omitted from body + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 200 + token_response = response.json() + assert "access_token" in token_response + assert "refresh_token" in token_response + + refresh_token = token_response["refresh_token"] + + # Now, use the refresh token without client_id in body + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {encoded_credentials}"}, + data={ + "grant_type": "refresh_token", + # client_id omitted from body + "refresh_token": refresh_token, + }, + ) + assert response.status_code == 200 + new_token_response = response.json() + assert "access_token" in new_token_response @pytest.mark.anyio async def test_none_auth_method_public_client(