Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER
from .token_cache import TokenCache, _get_username, _GRANT_TYPE_BROKER, _compute_ext_cache_key
import msal.telemetry
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
Expand Down Expand Up @@ -1571,6 +1571,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
key_id = kwargs.get("data", {}).get("key_id")
if key_id: # Some token types (SSH-certs, POP) are bound to a key
query["key_id"] = key_id
ext_cache_key = _compute_ext_cache_key(kwargs.get("data", {}))
if ext_cache_key: # FMI tokens need cache isolation by path
query["ext_cache_key"] = ext_cache_key
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in self.token_cache.search( # A generator allows us to
Expand Down Expand Up @@ -2424,7 +2427,7 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
except that ``allow_broker`` parameter shall remain ``None``.
"""

def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
def acquire_token_for_client(self, scopes, claims_challenge=None, fmi_path=None, **kwargs):
"""Acquires token for the current confidential client, not for an end user.

Since MSAL Python 1.23, it will automatically look for token from cache,
Expand All @@ -2437,7 +2440,17 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
in the form of a claims_challenge directive in the www-authenticate header to be
returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token.
It is a string of a JSON object which contains lists of claims being requested from these locations.

:param str fmi_path:
Optional. The Federated Managed Identity (FMI) credential path.
When provided, it is sent as the ``fmi_path`` parameter in the
token request body, and the resulting token is cached separately
so that different FMI paths do not share cached tokens.
Example usage::

result = cca.acquire_token_for_client(
scopes=["api://resource/.default"],
fmi_path="SomeFmiPath/FmiCredentialPath",
)
:return: A dict representing the json response from Microsoft Entra:

- A successful response would contain "access_token" key,
Expand All @@ -2447,6 +2460,12 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
raise ValueError( # We choose to disallow force_refresh
"Historically, this method does not support force_refresh behavior. "
)
if fmi_path is not None:
if not isinstance(fmi_path, str):
raise ValueError(
"fmi_path must be a string, got {}".format(type(fmi_path).__name__))
kwargs["data"] = kwargs.get("data", {})
kwargs["data"]["fmi_path"] = fmi_path
return _clean_up(self._acquire_token_silent_with_error(
scopes, None, claims_challenge=claims_challenge, **kwargs))

Expand Down
4 changes: 0 additions & 4 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,9 @@ def __init__(
self._http_client = http_client
self._oidc_authority_url = oidc_authority_url
if oidc_authority_url:
logger.debug("Initializing with OIDC authority: %s", oidc_authority_url)
tenant_discovery_endpoint = self._initialize_oidc_authority(
oidc_authority_url)
else:
logger.debug("Initializing with Entra authority: %s", authority_url)
tenant_discovery_endpoint = self._initialize_entra_authority(
authority_url, validate_authority, instance_discovery)
try:
Expand All @@ -117,8 +115,6 @@ def __init__(
.format(authority_url)
) + " Also please double check your tenant name or GUID is correct."
raise ValueError(error_message)
logger.debug(
'openid_config("%s") = %s', tenant_discovery_endpoint, openid_config)
self._issuer = openid_config.get('issuer')
self.authorization_endpoint = openid_config['authorization_endpoint']
self.token_endpoint = openid_config['token_endpoint']
Expand Down
114 changes: 110 additions & 4 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import base64
import hashlib
import json
import threading
import time
import logging
Expand All @@ -12,6 +14,89 @@
logger = logging.getLogger(__name__)
_GRANT_TYPE_BROKER = "broker"

# Fields in the request data dict that should NOT be included in the extended
# cache key hash. Everything else in data IS included, because those are extra
# body parameters going on the wire and must differentiate cached tokens.
#
# Excluded fields and reasons:
# - "client_id" : Standard OAuth2 client identifier, same for every request
# - "grant_type" : It is possible to combine grants to get tokens, e.g. obo + refresh_token, auth_code + refresh_token etc.
# - "scope" : Already represented as "target" in the AT cache key
# - "claims" : Handled separately; its presence forces a token refresh
# - "username" : Standard ROPC grant parameter. Tokens are cached by user ID (subject or oid+tid) instead
# - "password" : Standard ROPC grant parameter. Tokens are tied to credentials.
# - "refresh_token" : Standard refresh grant parameter
# - "code" : Standard authorization code grant parameter
# - "redirect_uri" : Standard authorization code grant parameter
# - "code_verifier" : Standard PKCE parameter
# - "device_code" : Standard device flow parameter
# - "assertion" : Standard OBO/SAML assertion (RFC 7521)
# - "requested_token_use" : OBO indicator ("on_behalf_of"), not an extra param
# - "client_assertion" : Client authentication credential (RFC 7521 §4.2)
# - "client_assertion_type" : Client authentication type (RFC 7521 §4.2)
# - "client_secret" : Client authentication secret
# - "token_type" : Used for SSH-cert/POP detection; AT entry stores separately
# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request
# - "key_id" : Already handled as a separate cache lookup field
#
# Included fields (examples — anything NOT in this set is included):
# - "fmi_path" : Federated Managed Identity credential path
# - any future non-standard body parameter that should isolate cache entries
_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset({
# Standard OAuth2 body parameters — these appear in every token request
# and must NOT influence the extended cache key.
# Only non-standard fields (e.g. fmi_path) should contribute to the hash.
"client_id",
"grant_type",
"scope",
"claims",
"username",
"password",
"refresh_token",
"code",
"redirect_uri",
"code_verifier",
"device_code",
"assertion",
"requested_token_use",
"client_assertion",
"client_assertion_type",
"client_secret",
"token_type",
"req_cnf",
"key_id",
})


def _compute_ext_cache_key(data):
"""Compute an extended cache key hash from extra body parameters in *data*.

All fields in *data* that go on the wire are included in the hash,
EXCEPT those listed in ``_EXT_CACHE_KEY_EXCLUDED_FIELDS``.
This ensures tokens acquired with different parameter values
(e.g., different FMI paths) are cached separately.

Returns an empty string when *data* has no hashable fields.

The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator):
sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded.
"""
if not data:
return ""
cache_components = {
k: str(v) for k, v in data.items()
if k not in _EXT_CACHE_KEY_EXCLUDED_FIELDS and v
}
if not cache_components:
return ""
# Sort keys for consistent hashing (matches Go implementation)
key_str = "".join(
k + cache_components[k] for k in sorted(cache_components.keys())
)
hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest()
return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower()


def is_subdict_of(small, big):
    """Return True when every key/value pair of *small* also exists in *big*.

    Implemented by overlaying *small* onto a copy of *big*: the overlay
    leaves *big* unchanged exactly when *small* is already contained in it.
    """
    overlaid = dict(big, **small)
    return overlaid == big

Expand All @@ -30,6 +115,7 @@ class TokenCache(object):

class CredentialType:
ACCESS_TOKEN = "AccessToken"
ACCESS_TOKEN_EXTENDED = "atext" # Used when ext_cache_key is present (matches Go/dotnet)
REFRESH_TOKEN = "RefreshToken"
ACCOUNT = "Account" # Not exactly a credential type, but we put it here
ID_TOKEN = "IdToken"
Expand Down Expand Up @@ -59,18 +145,22 @@ def __init__(self):
self.CredentialType.ACCESS_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
realm=None, target=None,
ext_cache_key=None,
# Note: New field(s) can be added here
#key_id=None,
**ignored_payload_from_a_real_token:
"-".join([ # Note: Could use a hash here to shorten key length
home_account_id or "",
environment or "",
self.CredentialType.ACCESS_TOKEN,
# Use "atext" credential type when ext_cache_key is
# present, matching MSAL Go and MSAL .NET behaviour.
"atext" if ext_cache_key else "AccessToken",
client_id or "",
realm or "",
target or "",
#key_id or "", # So ATs of different key_id can coexist
]).lower(),
] + ([ext_cache_key] if ext_cache_key else [])
).lower(),
self.CredentialType.ID_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
realm=None, **ignored_payload_from_a_real_token:
Expand Down Expand Up @@ -98,6 +188,7 @@ def __init__(self):
def _get_access_token(
self,
home_account_id, environment, client_id, realm, target, # Together they form a compound key
ext_cache_key=None,
default=None,
): # O(1)
return self._get(
Expand All @@ -108,6 +199,7 @@ def _get_access_token(
client_id=client_id,
realm=realm,
target=" ".join(target),
ext_cache_key=ext_cache_key,
),
default=default)

Expand Down Expand Up @@ -153,7 +245,8 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
): # Special case for O(1) AT lookup
preferred_result = self._get_access_token(
query["home_account_id"], query["environment"],
query["client_id"], query["realm"], target)
query["client_id"], query["realm"], target,
ext_cache_key=query.get("ext_cache_key"))
if preferred_result and self._is_matching(
preferred_result, query,
# Needs no target_set here because it is satisfied by dict key
Expand All @@ -179,6 +272,13 @@ def search(self, credential_type, target=None, query=None, *, now=None): # O(n)
if (entry != preferred_result # Avoid yielding the same entry twice
and self._is_matching(entry, query, target_set=target_set)
):
# Cache isolation for extended cache keys (e.g., FMI path).
# Entries with ext_cache_key must not match queries without one.
if (credential_type == self.CredentialType.ACCESS_TOKEN
and "ext_cache_key" in entry
and "ext_cache_key" not in (query or {})
):
continue
yield entry
for at in expired_access_tokens:
self.remove_at(at)
Expand Down Expand Up @@ -278,6 +378,12 @@ def __add(self, event, now=None):
# So that we won't accidentally store a user's password etc.
"key_id", # It happens in SSH-cert or POP scenario
}})
# Compute and store extended cache key for cache isolation
# (e.g., different FMI paths should have separate cache entries)
ext_cache_key = _compute_ext_cache_key(data)

if ext_cache_key:
at["ext_cache_key"] = ext_cache_key
if "refresh_in" in response:
refresh_in = response["refresh_in"] # It is an integer
at["refresh_on"] = str(now + refresh_in) # Schema wants a string
Expand Down
Loading
Loading