diff --git a/msal/application.py b/msal/application.py index ba16df83..e4dac87e 100644 --- a/msal/application.py +++ b/msal/application.py @@ -15,7 +15,7 @@ from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request from .wstrust_response import * -from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER +from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key import msal.telemetry from .region import _detect_region from .throttled_http_client import ThrottledHttpClient @@ -1571,6 +1571,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( key_id = kwargs.get("data", {}).get("key_id") if key_id: # Some token types (SSH-certs, POP) are bound to a key query["key_id"] = key_id + ext_cache_key = _compute_ext_cache_key(kwargs.get("data", {})) + if ext_cache_key: # FMI tokens need cache isolation by path + query["ext_cache_key"] = ext_cache_key now = time.time() refresh_reason = msal.telemetry.AT_ABSENT for entry in self.token_cache.search( # A generator allows us to @@ -2424,7 +2427,7 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app except that ``allow_broker`` parameter shall remain ``None``. """ - def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): + def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, **kwargs): """Acquires token for the current confidential client, not for an end user. Since MSAL Python 1.23, it will automatically look for token from cache, @@ -2437,7 +2440,17 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): in the form of a claims_challenge directive in the www-authenticate header to be returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. It is a string of a JSON object which contains lists of claims being requested from these locations. - + :param str fmi_path: + Optional. 
The Federated Managed Identity (FMI) credential path. + When provided, it is sent as the ``fmi_path`` parameter in the + token request body, and the resulting token is cached separately + so that different FMI paths do not share cached tokens. + Example usage:: + + result = cca.acquire_token_for_client( + scopes=["api://resource/.default"], + fmi_path="SomeFmiPath/FmiCredentialPath", + ) :return: A dict representing the json response from Microsoft Entra: - A successful response would contain "access_token" key, @@ -2447,6 +2460,12 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): raise ValueError( # We choose to disallow force_refresh "Historically, this method does not support force_refresh behavior. " ) + if fmi_path is not None: + if not isinstance(fmi_path, str): + raise ValueError( + "fmi_path must be a string, got {}".format(type(fmi_path).__name__)) + # Build a new dict so that the caller's own ``data`` dict is never mutated + kwargs["data"] = dict(kwargs.get("data") or {}, fmi_path=fmi_path) return _clean_up(self._acquire_token_silent_with_error( scopes, None, claims_challenge=claims_challenge, **kwargs)) diff --git a/msal/authority.py b/msal/authority.py index b114831f..4a3a56ee 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -92,11 +92,9 @@ def __init__( self._http_client = http_client self._oidc_authority_url = oidc_authority_url if oidc_authority_url: - logger.debug("Initializing with OIDC authority: %s", oidc_authority_url) tenant_discovery_endpoint = self._initialize_oidc_authority( oidc_authority_url) else: - logger.debug("Initializing with Entra authority: %s", authority_url) tenant_discovery_endpoint = self._initialize_entra_authority( authority_url, validate_authority, instance_discovery) try: @@ -117,8 +115,6 @@ def __init__( .format(authority_url) ) + " Also please double check your tenant name or GUID is correct."
raise ValueError(error_message) - logger.debug( - 'openid_config("%s") = %s', tenant_discovery_endpoint, openid_config) self._issuer = openid_config.get('issuer') self.authorization_endpoint = openid_config['authorization_endpoint'] self.token_endpoint = openid_config['token_endpoint'] diff --git a/msal/token_cache.py b/msal/token_cache.py index 846c8132..d6e2a2b1 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -1,4 +1,6 @@ -import json +import base64 +import hashlib +import json import threading import time import logging @@ -12,6 +14,89 @@ logger = logging.getLogger(__name__) _GRANT_TYPE_BROKER = "broker" +# Fields in the request data dict that should NOT be included in the extended +# cache key hash. Everything else in data IS included, because those are extra +# body parameters going on the wire and must differentiate cached tokens. +# +# Excluded fields and reasons: +# - "client_id" : Standard OAuth2 client identifier, same for every request +# - "grant_type" : It is possible to combine grants to get tokens, e.g. obo + refresh_token, auth_code + refresh_token etc. +# - "scope" : Already represented as "target" in the AT cache key +# - "claims" : Handled separately; its presence forces a token refresh +# - "username" : Standard ROPC grant parameter. Tokens are cached by user ID (subject or oid+tid) instead +# - "password" : Standard ROPC grant parameter. Tokens are tied to credentials. 
+# - "refresh_token" : Standard refresh grant parameter +# - "code" : Standard authorization code grant parameter +# - "redirect_uri" : Standard authorization code grant parameter +# - "code_verifier" : Standard PKCE parameter +# - "device_code" : Standard device flow parameter +# - "assertion" : Standard OBO/SAML assertion (RFC 7521) +# - "requested_token_use" : OBO indicator ("on_behalf_of"), not an extra param +# - "client_assertion" : Client authentication credential (RFC 7521 §4.2) +# - "client_assertion_type" : Client authentication type (RFC 7521 §4.2) +# - "client_secret" : Client authentication secret +# - "token_type" : Used for SSH-cert/POP detection; AT entry stores separately +# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request +# - "key_id" : Already handled as a separate cache lookup field +# +# Included fields (examples — anything NOT in this set is included): +# - "fmi_path" : Federated Managed Identity credential path +# - any future non-standard body parameter that should isolate cache entries +_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset({ + # Standard OAuth2 body parameters — these appear in every token request + # and must NOT influence the extended cache key. + # Only non-standard fields (e.g. fmi_path) should contribute to the hash. + "client_id", + "grant_type", + "scope", + "claims", + "username", + "password", + "refresh_token", + "code", + "redirect_uri", + "code_verifier", + "device_code", + "assertion", + "requested_token_use", + "client_assertion", + "client_assertion_type", + "client_secret", + "token_type", + "req_cnf", + "key_id", +}) + + +def _compute_ext_cache_key(data): + """Compute an extended cache key hash from extra body parameters in *data*. + + All fields in *data* that go on the wire are included in the hash, + EXCEPT those listed in ``_EXT_CACHE_KEY_EXCLUDED_FIELDS``. + This ensures tokens acquired with different parameter values + (e.g., different FMI paths) are cached separately. 
+ + Returns an empty string when *data* has no hashable fields. + + The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator): + sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded. + """ + if not data: + return "" + cache_components = { + k: str(v) for k, v in data.items() + if k not in _EXT_CACHE_KEY_EXCLUDED_FIELDS and v + } + if not cache_components: + return "" + # Sort keys for consistent hashing (matches Go implementation) + key_str = "".join( + k + cache_components[k] for k in sorted(cache_components.keys()) + ) + hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest() + return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower() + + def is_subdict_of(small, big): return dict(big, **small) == big @@ -30,6 +115,7 @@ class TokenCache(object): class CredentialType: ACCESS_TOKEN = "AccessToken" + ACCESS_TOKEN_EXTENDED = "atext" # Used when ext_cache_key is present (matches Go/dotnet) REFRESH_TOKEN = "RefreshToken" ACCOUNT = "Account" # Not exactly a credential type, but we put it here ID_TOKEN = "IdToken" @@ -59,18 +145,22 @@ def __init__(self): self.CredentialType.ACCESS_TOKEN: lambda home_account_id=None, environment=None, client_id=None, realm=None, target=None, + ext_cache_key=None, # Note: New field(s) can be added here #key_id=None, **ignored_payload_from_a_real_token: "-".join([ # Note: Could use a hash here to shorten key length home_account_id or "", environment or "", - self.CredentialType.ACCESS_TOKEN, + # Use "atext" credential type when ext_cache_key is + # present, matching MSAL Go and MSAL .NET behaviour. 
+ "atext" if ext_cache_key else "AccessToken", client_id or "", realm or "", target or "", #key_id or "", # So ATs of different key_id can coexist - ]).lower(), + ] + ([ext_cache_key] if ext_cache_key else []) + ).lower(), self.CredentialType.ID_TOKEN: lambda home_account_id=None, environment=None, client_id=None, realm=None, **ignored_payload_from_a_real_token: @@ -98,6 +188,7 @@ def __init__(self): def _get_access_token( self, home_account_id, environment, client_id, realm, target, # Together they form a compound key + ext_cache_key=None, default=None, ): # O(1) return self._get( @@ -108,6 +199,7 @@ def _get_access_token( client_id=client_id, realm=realm, target=" ".join(target), + ext_cache_key=ext_cache_key, ), default=default) @@ -153,7 +245,8 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n) ): # Special case for O(1) AT lookup preferred_result = self._get_access_token( query["home_account_id"], query["environment"], - query["client_id"], query["realm"], target) + query["client_id"], query["realm"], target, + ext_cache_key=query.get("ext_cache_key")) if preferred_result and self._is_matching( preferred_result, query, # Needs no target_set here because it is satisfied by dict key @@ -179,6 +272,13 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n) if (entry != preferred_result # Avoid yielding the same entry twice and self._is_matching(entry, query, target_set=target_set) ): + # Cache isolation for extended cache keys (e.g., FMI path). + # Entries with ext_cache_key must not match queries without one. + if (credential_type == self.CredentialType.ACCESS_TOKEN + and "ext_cache_key" in entry + and "ext_cache_key" not in (query or {}) + ): + continue yield entry for at in expired_access_tokens: self.remove_at(at) @@ -278,6 +378,12 @@ def __add(self, event, now=None): # So that we won't accidentally store a user's password etc. 
"key_id", # It happens in SSH-cert or POP scenario }}) + # Compute and store extended cache key for cache isolation + # (e.g., different FMI paths should have separate cache entries) + ext_cache_key = _compute_ext_cache_key(data) + + if ext_cache_key: + at["ext_cache_key"] = ext_cache_key if "refresh_in" in response: refresh_in = response["refresh_in"] # It is an integer at["refresh_on"] = str(now + refresh_in) # Schema wants a string diff --git a/tests/test_application.py b/tests/test_application.py index a31c8580..54da96c0 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -707,6 +707,170 @@ def test_organizations_authority_should_emit_warning(self): authority="https://login.microsoftonline.com/organizations") +@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) +class TestAcquireTokenForClientWithFmiPath(unittest.TestCase): + """Test that acquire_token_for_client(fmi_path=...) attaches fmi_path to HTTP body.""" + + def test_fmi_path_rejects_non_string_types(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + for bad_value in [123, True, ["path"], {"path": "value"}, b"bytes"]: + with self.assertRaises(ValueError, msg="fmi_path={!r} should raise".format(bad_value)): + app.acquire_token_for_client(["scope"], fmi_path=bad_value) + + def test_fmi_path_is_included_in_request_body(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "SomeFmiPath/FmiCredentialPath" + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result = app.acquire_token_for_client( + ["scope"], fmi_path=fmi_path, post=mock_post) + self.assertIn("access_token", result) + 
self.assertIn("fmi_path", captured_data, + "fmi_path should be present in the HTTP request body") + self.assertEqual(fmi_path, captured_data["fmi_path"], + "fmi_path value should match the input") + + def test_fmi_path_coexists_with_other_data(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "another/fmi/path" + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result = app.acquire_token_for_client( + ["scope"], fmi_path=fmi_path, post=mock_post) + self.assertIn("access_token", result) + self.assertEqual(fmi_path, captured_data["fmi_path"]) + self.assertEqual("client_credentials", captured_data.get("grant_type")) + + def test_fmi_path_preserves_existing_data_params(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "my/fmi/path" + captured_data = {} + + def mock_post(url, headers=None, data=None, *args, **kwargs): + captured_data.update(data or {}) + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result = app.acquire_token_for_client( + ["scope"], fmi_path=fmi_path, + data={"extra_key": "extra_value"}, + post=mock_post) + self.assertIn("access_token", result) + self.assertEqual(fmi_path, captured_data["fmi_path"]) + self.assertEqual("extra_value", captured_data.get("extra_key"), + "Pre-existing data params should be preserved") + + def test_cached_token_is_returned_on_second_call(self): + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + fmi_path = "SomeFmiPath/FmiCredentialPath" + call_count = [0] 
+ + def mock_post(url, headers=None, data=None, *args, **kwargs): + call_count[0] += 1 + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "an AT", + "expires_in": 3600, + })) + + result1 = app.acquire_token_for_client( + ["scope"], fmi_path=fmi_path, post=mock_post) + self.assertIn("access_token", result1) + self.assertEqual(result1[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP) + + result2 = app.acquire_token_for_client( + ["scope"], fmi_path=fmi_path, post=mock_post) + self.assertIn("access_token", result2) + self.assertEqual(result2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE, + "Second call should return token from cache") + + def test_different_fmi_paths_are_cached_separately(self): + """Tokens acquired with different fmi_path values must NOT share cache entries.""" + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + + def mock_post_factory(token_value): + def mock_post(url, headers=None, data=None, *args, **kwargs): + return MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": token_value, + "expires_in": 3600, + })) + return mock_post + + # Acquire token with path A + result_a = app.acquire_token_for_client( + ["scope"], fmi_path="PathA/credential", post=mock_post_factory("AT_for_path_A")) + self.assertEqual("AT_for_path_A", result_a["access_token"]) + + # Acquire token with path B (should NOT get path A's cached token) + result_b = app.acquire_token_for_client( + ["scope"], fmi_path="PathB/credential", post=mock_post_factory("AT_for_path_B")) + self.assertEqual("AT_for_path_B", result_b["access_token"]) + self.assertEqual(result_b[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP, + "Different FMI path should NOT return a cached token from another path") + + # Verify path A still returns its own cached token + result_a2 = app.acquire_token_for_client( + ["scope"], fmi_path="PathA/credential", 
post=mock_post_factory("should_not_be_used")) + self.assertEqual("AT_for_path_A", result_a2["access_token"]) + self.assertEqual(result_a2[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE, + "Same FMI path should return cached token") + + def test_fmi_token_does_not_interfere_with_non_fmi_token(self): + """FMI-cached tokens must not be returned for non-FMI acquire_token_for_client.""" + app = ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/my_tenant") + + # First, cache a token via FMI path + app.acquire_token_for_client( + ["scope"], fmi_path="some/fmi/path", + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "FMI_AT", "expires_in": 3600}))) + + # Now call regular acquire_token_for_client — should NOT get FMI token + result = app.acquire_token_for_client( + ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({ + "access_token": "regular_AT", "expires_in": 3600}))) + self.assertEqual("regular_AT", result["access_token"]) + self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP, + "Non-FMI call should not return FMI-cached token") + + @patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK) class TestRemoveTokensForClient(unittest.TestCase): def test_remove_tokens_for_client_should_remove_client_tokens_only(self): diff --git a/tests/test_ccs.py b/tests/test_ccs.py index 8b801773..9bbc2787 100644 --- a/tests/test_ccs.py +++ b/tests/test_ccs.py @@ -61,11 +61,14 @@ def test_acquire_token_silent(self): "CSS routing info should be derived from home_account_id") def test_acquire_token_by_username_password(self): + import warnings app = msal.ClientApplication("client_id") username = "johndoe@contoso.com" with patch.object(app.http_client, "post", return_value=MinimalResponse( status_code=400, text='{"error": "mock"}')) as mocked_method: - app.acquire_token_by_username_password(username, "password", ["scope"]) + 
with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + app.acquire_token_by_username_password(username, "password", ["scope"]) self.assertEqual( "upn:" + username, mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), diff --git a/tests/test_fmi_e2e.py b/tests/test_fmi_e2e.py new file mode 100644 index 00000000..d2ba5bc0 --- /dev/null +++ b/tests/test_fmi_e2e.py @@ -0,0 +1,283 @@ +"""End-to-end tests for Federated Managed Identity (FMI) functionality. + +These tests verify: +1. Tokens can be acquired using certificate authentication with FMI path +2. Tokens are properly cached and returned from cache on subsequent calls +3. Tokens can be acquired using an assertion callback (RMA pattern) with FMI path + +""" + +import logging +import os +import sys +import unittest + +import msal +from tests.http_client import MinimalHttpClient +from tests.lab_config import get_client_certificate +from tests.test_e2e import LabBasedTestCase + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG if "-v" in sys.argv else logging.INFO) + +# Test configuration +_FMI_TENANT_ID = "f645ad92-e38d-4d1a-b510-d1b09a74a8ca" +_FMI_CLIENT_ID = "4df2cbbb-8612-49c1-87c8-f334d6d065ad" +_FMI_SCOPE = "3091264c-7afb-45d4-b527-39737ee86187/.default" +_FMI_PATH = "SomeFmiPath/FmiCredentialPath" +_FMI_CLIENT_ID_URN = "urn:microsoft:identity:fmi" +_FMI_SCOPE_FOR_RMA = "api://AzureFMITokenExchange/.default" +_AUTHORITY_URL = "https://login.microsoftonline.com/" + _FMI_TENANT_ID + + +def _get_fmi_credential_from_rma(): + """Acquire an FMI token from RMA service using certificate credentials. + + This mirrors the Go function GetFmiCredentialFromRma: + 1. Create a confidential client with certificate credential + 2. Acquire a token for the FMI scope with the FMI path + 3. 
Return the access token as an assertion string + """ + + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + result = app.acquire_token_for_client( + [_FMI_SCOPE_FOR_RMA], fmi_path=_FMI_PATH) + if "access_token" not in result: + raise RuntimeError( + "Failed to acquire FMI token from RMA: {}: {}".format( + result.get("error"), result.get("error_description"))) + return result["access_token"] + + +class TestFMIBasicFunctionality(LabBasedTestCase): + """Test basic FMI token acquisition with certificate credential. + + Mirrors TestFMIBasicFunctionality from Go: + 1. Acquire token by credential with FMI path + 2. Verify silent (cached) token acquisition works + 3. Validate tokens match (proving cache was used) + """ + + def test_acquire_and_cache_with_fmi_path(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + + # 1. Acquire token by credential with FMI path + result = app.acquire_token_for_client(scopes, fmi_path=_FMI_PATH) + self.assertIn("access_token", result, + "acquire_token_for_client(fmi_path=...) failed: {}: {}".format( + result.get("error"), result.get("error_description"))) + self.assertNotEqual("", result["access_token"], + "acquire_token_for_client(fmi_path=...) returned empty access token") + + first_token = result["access_token"] + + # 2. 
Verify silent token acquisition works (should retrieve from cache) + cache_result = app.acquire_token_for_client(scopes, fmi_path=_FMI_PATH) + self.assertIn("access_token", cache_result, + "Second call failed: {}: {}".format( + cache_result.get("error"), cache_result.get("error_description"))) + self.assertNotEqual("", cache_result["access_token"], + "Second call returned empty access token") + self.assertEqual( + cache_result.get("token_source"), "cache", + "Second call should return token from cache") + + # 3. Validate tokens match (proving cache was used) + self.assertEqual(first_token, cache_result["access_token"], + "Token comparison failed - tokens don't match, " + "cache might not be working correctly") + +class TestFMIIntegration(LabBasedTestCase): + """Test FMI with assertion callback (RMA pattern). + + Mirrors TestFMIIntegration from Go: + 1. Get credentials from RMA via assertion callback + 2. Acquire token by credential with FMI path + 3. Verify cached token acquisition works + 4. Compare tokens to verify cache was used + """ + + def test_acquire_with_assertion_callback_and_fmi_path(self): + # Create credential from assertion callback (mirrors Go's NewCredFromAssertionCallback) + client_credential = { + "client_assertion": lambda: _get_fmi_credential_from_rma(), + } + + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID_URN, + client_credential=client_credential, + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + fmi_path = "SomeFmiPath/Path" + + # 1. Acquire token by credential with FMI path + result = app.acquire_token_for_client(scopes, fmi_path=fmi_path) + self.assertIn("access_token", result, + "acquire_token_for_client(fmi_path=...) failed: {}: {}".format( + result.get("error"), result.get("error_description"))) + self.assertNotEqual("", result["access_token"], + "acquire_token_for_client(fmi_path=...) returned empty access token") + first_token = result["access_token"] + + # 2. 
Verify cached token acquisition works + cache_result = app.acquire_token_for_client(scopes, fmi_path=fmi_path) + self.assertIn("access_token", cache_result, + "Second call failed: {}: {}".format( + cache_result.get("error"), cache_result.get("error_description"))) + self.assertNotEqual("", cache_result["access_token"], + "Second call returned empty access token") + self.assertEqual( + cache_result.get("token_source"), "cache", + "Second call should return token from cache") + + # 3. Compare tokens to verify cache was used + self.assertEqual(first_token, cache_result["access_token"], + "Token comparison failed - tokens don't match, " + "cache might not be working correctly") + + +class TestFMICacheIsolation(LabBasedTestCase): + """Test that tokens acquired with different FMI paths are cached separately. + + This verifies the cache key extensibility: two calls with different fmi_path + values should NOT return each other's cached tokens. + """ + + def test_different_fmi_paths_are_cached_separately(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + + # Acquire token with path A + result_a = app.acquire_token_for_client( + scopes, fmi_path="PathA/credential") + self.assertIn("access_token", result_a, + "Path A acquisition failed: {}: {}".format( + result_a.get("error"), result_a.get("error_description"))) + + # Acquire token with path B — should NOT get path A's cached token + result_b = app.acquire_token_for_client( + scopes, fmi_path="PathB/credential") + self.assertIn("access_token", result_b, + "Path B acquisition failed: {}: {}".format( + result_b.get("error"), result_b.get("error_description"))) + self.assertNotEqual( + result_b.get("token_source"), "cache", + "Different FMI path should NOT return cached token from another path") + + # Verify path A still returns its own cached token + result_a2 = 
app.acquire_token_for_client( + scopes, fmi_path="PathA/credential") + self.assertIn("access_token", result_a2) + self.assertEqual( + result_a2.get("token_source"), "cache", + "Same FMI path should return cached token") + self.assertEqual(result_a["access_token"], result_a2["access_token"]) + + def test_fmi_token_does_not_interfere_with_non_fmi_token(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + + # Cache a token via FMI path + fmi_result = app.acquire_token_for_client(scopes, fmi_path=_FMI_PATH) + self.assertIn("access_token", fmi_result) + + # Regular acquire_token_for_client should NOT get the FMI token + regular_result = app.acquire_token_for_client(scopes) + self.assertIn("access_token", regular_result, + "Regular call failed: {}: {}".format( + regular_result.get("error"), regular_result.get("error_description"))) + self.assertNotEqual( + regular_result.get("token_source"), "cache", + "Non-FMI call should not return FMI-cached token") + + +class TestFMICacheInspection(LabBasedTestCase): + """Acquire tokens with two different FMI paths and inspect the underlying + cache to verify the entries are correctly isolated.""" + + def test_two_fmi_paths_produce_separate_cache_entries(self): + app = msal.ConfidentialClientApplication( + _FMI_CLIENT_ID, + client_credential=get_client_certificate(), + authority=_AUTHORITY_URL, + http_client=MinimalHttpClient(), + ) + scopes = [_FMI_SCOPE] + path_a = "PathAlpha/Credential" + path_b = "PathBeta/Credential" + + # 1. Acquire token with path A + result_a = app.acquire_token_for_client(scopes, fmi_path=path_a) + self.assertIn("access_token", result_a, + "Path A acquisition failed: {}: {}".format( + result_a.get("error"), result_a.get("error_description"))) + token_a = result_a["access_token"] + + # 2. 
Acquire token with path B + result_b = app.acquire_token_for_client(scopes, fmi_path=path_b) + self.assertIn("access_token", result_b, + "Path B acquisition failed: {}: {}".format( + result_b.get("error"), result_b.get("error_description"))) + token_b = result_b["access_token"] + + # Tokens should be different (different paths go to different resources) + self.assertNotEqual(token_a, token_b, + "Tokens for different FMI paths should differ") + + # 3. Inspect cache: there should be exactly 2 AccessToken entries + cache = app.token_cache._cache + at_entries = cache.get("AccessToken", {}) + # Filter to our client_id + scope to avoid noise + our_entries = { + k: v for k, v in at_entries.items() + if v.get("client_id") == _FMI_CLIENT_ID + and _FMI_SCOPE.split("/")[0] in v.get("target", "") + } + self.assertEqual(2, len(our_entries), + "Cache should contain exactly 2 AT entries for our client, " + "got {}: {}".format(len(our_entries), list(our_entries.keys()))) + + # 4. Each entry must have a non-empty ext_cache_key, and they must differ + ext_keys = [v.get("ext_cache_key") for v in our_entries.values()] + for ek in ext_keys: + self.assertTrue(ek, "Each FMI cache entry must have a non-empty ext_cache_key") + self.assertNotEqual(ext_keys[0], ext_keys[1], + "ext_cache_key values for different FMI paths must differ") + + # 5. 
Verify each path still returns its own cached token + cached_a = app.acquire_token_for_client(scopes, fmi_path=path_a) + self.assertEqual("cache", cached_a.get("token_source")) + self.assertEqual(token_a, cached_a["access_token"], + "Path A should return its own cached token") + + cached_b = app.acquire_token_for_client(scopes, fmi_path=path_b) + self.assertEqual("cache", cached_b.get("token_source")) + self.assertEqual(token_b, cached_b["access_token"], + "Path B should return its own cached token") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 5310b789..d7dfe8de 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -4,7 +4,7 @@ import time import warnings -from msal.token_cache import TokenCache, SerializableTokenCache +from msal.token_cache import TokenCache, SerializableTokenCache, _compute_ext_cache_key from tests import unittest @@ -321,3 +321,242 @@ def tearDown(self): output.get("AccessToken", {}).get("an-entry"), {"foo": "bar"}, "Undefined token keys and their values should be intact") + +class TestComputeExtCacheKey(unittest.TestCase): + """Tests for the _compute_ext_cache_key hash function.""" + + def test_empty_data_returns_empty_string(self): + self.assertEqual("", _compute_ext_cache_key(None)) + self.assertEqual("", _compute_ext_cache_key({})) + + def test_excluded_fields_are_ignored(self): + self.assertEqual("", _compute_ext_cache_key({"key_id": "k1", "token_type": "ssh-cert", "req_cnf": "nonce", "claims": "{}"}), + "Fields in _EXT_CACHE_KEY_EXCLUDED_FIELDS should produce an empty hash") + + def test_fmi_path_produces_non_empty_hash(self): + result = _compute_ext_cache_key({"fmi_path": "SomePath/Credential"}) + self.assertNotEqual("", result) + self.assertIsInstance(result, str) + + def test_same_input_produces_same_hash(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/a"}) + 
self.assertEqual(h1, h2) + + def test_different_fmi_paths_produce_different_hashes(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/b"}) + self.assertNotEqual(h1, h2) + + def test_empty_fmi_path_value_is_ignored(self): + self.assertEqual("", _compute_ext_cache_key({"fmi_path": ""})) + + def test_excluded_fields_dont_affect_hash(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/a", "key_id": "k1", "req_cnf": "nonce"}) + self.assertEqual(h1, h2, "Excluded fields should not affect the hash") + + def test_non_excluded_fields_are_included_in_hash(self): + h1 = _compute_ext_cache_key({"fmi_path": "path/a"}) + h2 = _compute_ext_cache_key({"fmi_path": "path/a", "custom_param": "val"}) + self.assertNotEqual(h1, h2, "Non-excluded fields should change the hash") + + +class TestExtCacheKeyIsolation(unittest.TestCase): + """Tests that ext_cache_key provides proper cache isolation in TokenCache.""" + + def _build_event(self, client_id, scope, token_endpoint, access_token, data=None, **kwargs): + return { + "client_id": client_id, + "scope": scope, + "token_endpoint": token_endpoint, + "response": build_response(access_token=access_token, expires_in=3600), + "data": data or {}, + **kwargs, + } + + def test_at_key_includes_ext_cache_key_when_present(self): + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + key_without = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope") + key_with = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope", ext_cache_key="somehash") + self.assertNotEqual(key_without, key_with, + "Keys with and without ext_cache_key should differ") + self.assertIn("somehash", key_with) + + def test_different_ext_cache_keys_produce_different_at_keys(self): + cache = TokenCache() + key_maker = 
cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + key_a = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope", ext_cache_key="hash_a") + key_b = key_maker( + home_account_id="", environment="env", client_id="cid", + realm="realm", target="scope", ext_cache_key="hash_b") + self.assertNotEqual(key_a, key_b) + + def test_fmi_tokens_are_stored_with_ext_cache_key(self): + cache = TokenCache() + event = self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "fmi_token", data={"fmi_path": "some/path"}) + cache.add(event) + at_entries = list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"])) + self.assertEqual(0, len(at_entries), + "FMI tokens should NOT be found by a query without ext_cache_key") + + def test_fmi_tokens_found_with_matching_ext_cache_key_query(self): + cache = TokenCache() + ext_key = _compute_ext_cache_key({"fmi_path": "some/path"}) + event = self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "fmi_token", data={"fmi_path": "some/path"}) + cache.add(event) + at_entries = list(cache.search( + TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"], + query={"client_id": "cid", "environment": "login.example.com", + "realm": "tenant", "home_account_id": None, + "ext_cache_key": ext_key})) + self.assertEqual(1, len(at_entries)) + self.assertEqual("fmi_token", at_entries[0]["secret"]) + + def test_non_fmi_tokens_not_affected_by_fmi_cache(self): + cache = TokenCache() + # Add FMI token + cache.add(self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "fmi_token", data={"fmi_path": "some/path"})) + # Add regular token + cache.add(self._build_event( + "cid", ["s1"], "https://login.example.com/tenant/v2/token", + "regular_token")) + # Search without ext_cache_key should find only regular token + at_entries = list(cache.search( + TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"], + query={"client_id": 
"cid", "environment": "login.example.com", + "realm": "tenant", "home_account_id": None})) + self.assertEqual(1, len(at_entries)) + self.assertEqual("regular_token", at_entries[0]["secret"]) + + +class TestCrossMsalCacheKeyCompatibility(unittest.TestCase): + """Verify that _compute_ext_cache_key produces hashes identical to MSAL Go + (CacheExtKeyGenerator) and MSAL .NET (CoreHelpers.ComputeAccessTokenExtCacheKey). + + All three libraries use the same algorithm: + 1. Sort key-value pairs alphabetically by key (ordinal / case-sensitive) + 2. Concatenate them: "key1value1key2value2…" + 3. SHA-256 hash + 4. Base64url encode (no padding), lowercased + + The expected hashes below are copied from: + - MSAL Go: authority_ext_cachekey_test.go (TestAppKeyWithCacheKeyComponent) + - MSAL .NET: CacheKeyExtensionTests.cs (RunHappyPathTest, CacheExtEnsurePopKeysFunctionAsync) + """ + + def test_two_params_hash_matches_go_and_dotnet(self): + """Go/dotnet expected: bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi""" + result = _compute_ext_cache_key({"key1": "value1", "key2": "value2"}) + self.assertEqual("bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", result) + + def test_two_different_params_hash_matches_go_and_dotnet(self): + """Go/dotnet expected: 3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u""" + result = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) + self.assertEqual("3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u", result) + + def test_five_params_hash_matches_go_and_dotnet(self): + """Go/dotnet expected (full hash): rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e + Go test uses substring match 'gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e'.""" + result = _compute_ext_cache_key({ + "key3": "value3", "key4": "value4", + "key5": "value5", "key6": "value6", "key7": "value7", + }) + self.assertEqual("rn_gkpxxkkqjxcqnvnmr2duvxg66xanvkz6qfqpwp2e", result) + + def test_order_independence_matches_go_and_dotnet(self): + """Same keys in different insertion order must produce the same 
hash + (mirrors TestCacheKeyComponentHashConsistency in Go).""" + h1 = _compute_ext_cache_key({"key3": "value3", "key4": "value4", + "key5": "value5", "key6": "value6", "key7": "value7"}) + h2 = _compute_ext_cache_key({"key7": "value7", "key4": "value4", + "key6": "value6", "key5": "value5", "key3": "value3"}) + self.assertEqual(h1, h2) + + def test_at_cache_key_uses_atext_credential_type(self): + """When ext_cache_key is present the credential type segment of the + AT cache key must be 'atext' (not 'accesstoken'), matching Go/dotnet. + + Go: {hid}-{env}-atext-{clientID}-{realm}-{scopes}-{hash} + .NET: {hid}-{env}-atext-{clientID}-{tenantId}-{scopes}-{hash} + """ + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + key = key_maker( + home_account_id="hid", environment="env", client_id="cid", + realm="realm", target="scope", + ext_cache_key="bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi") + self.assertEqual( + "hid-env-atext-cid-realm-scope-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi", + key) + + def test_at_cache_key_without_ext_uses_accesstoken(self): + """Regular ATs (no ext_cache_key) must keep 'accesstoken' credential type.""" + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + key = key_maker( + home_account_id="hid", environment="env", client_id="cid", + realm="realm", target="scope") + self.assertEqual("hid-env-accesstoken-cid-realm-scope", key) + + def test_dotnet_style_full_at_cache_key(self): + """Reproduce the exact cache key from MSAL .NET CacheKeyExtensionTests: + expectedCacheKey1 = '-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi' + """ + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + ext_hash = _compute_ext_cache_key({"key1": "value1", "key2": "value2"}) + key = key_maker( + home_account_id="", + environment="login.windows.net", + 
client_id="d3adb33f-c0de-ed0c-c0de-deadb33fc0d3", + realm="common", + target="r1/scope1 r1/scope2", + ext_cache_key=ext_hash) + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" + self.assertEqual(expected, key) + + def test_dotnet_style_second_cache_key(self): + """Reproduce CacheKeyExtensionTests expectedCacheKey2.""" + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + ext_hash = _compute_ext_cache_key({"key3": "value3", "key4": "value4"}) + key = key_maker( + home_account_id="", + environment="login.windows.net", + client_id="d3adb33f-c0de-ed0c-c0de-deadb33fc0d3", + realm="common", + target="r1/scope1 r1/scope2", + ext_cache_key=ext_hash) + expected = "-login.windows.net-atext-d3adb33f-c0de-ed0c-c0de-deadb33fc0d3-common-r1/scope1 r1/scope2-3-rg6_wyjx5bcy0c3cqq7gajtzgsqy3oxqpwj4y8k4u" + self.assertEqual(expected, key) + + def test_go_style_at_cache_key(self): + """Reproduce the Go AccessToken.Key() format: + Go test: 'testhid-env-atext-clientid-realm-user.read-{hash}' + """ + cache = TokenCache() + key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN] + ext_hash = _compute_ext_cache_key({"key1": "value1", "key2": "value2"}) + key = key_maker( + home_account_id="testhid", + environment="env", + client_id="clientid", + realm="realm", + target="user.read", + ext_cache_key=ext_hash) + expected = "testhid-env-atext-clientid-realm-user.read-bns2ytmx5hxkh4fnfixridmezpbbayhnmuh6t4bbghi" + self.assertEqual(expected, key)