# Source code for auth0.authentication.token_verifier

"""Token Verifier module"""
from __future__ import annotations

import json
import time
from typing import TYPE_CHECKING, Any, ClassVar

import jwt
import requests

from auth0.exceptions import TokenValidationError

if TYPE_CHECKING:
    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey


class SignatureVerifier:
    """Abstract class that will verify a given JSON web token's signature
    using the key fetched internally given its key id.

    Subclasses must implement ``_fetch_key``.

    Args:
        algorithm (str): The expected signing algorithm (e.g. RS256).
    """

    # Options for ``jwt.decode``: only the signature is checked here. The
    # time-based, audience and issuer checks are disabled, leaving claim
    # validation to the caller (e.g. ``TokenVerifier._verify_payload``).
    DISABLE_JWT_CHECKS: ClassVar[dict[str, bool]] = {
        "verify_signature": True,
        "verify_exp": False,
        "verify_nbf": False,
        "verify_iat": False,
        "verify_aud": False,
        "verify_iss": False,
        "require_exp": False,
        "require_iat": False,
        "require_nbf": False,
    }

    def __init__(self, algorithm: str) -> None:
        # isinstance (instead of an exact type comparison) also accepts
        # subclasses of str.
        if not algorithm or not isinstance(algorithm, str):
            raise ValueError("algorithm must be specified.")
        self._algorithm = algorithm

    def _fetch_key(self, key_id: str) -> str | RSAPublicKey:
        """Obtains the key associated to the given key id.
        Must be implemented by subclasses.

        Args:
            key_id (str): The id of the key to fetch.

        Returns:
            the key to use for verifying a cryptographic signature
        """
        raise NotImplementedError

    def _get_kid(self, token: str) -> str | None:
        """Gets the key id from the kid claim of the header of the token.

        Args:
            token (str): The JWT to get the header from.

        Raises:
            TokenValidationError: if the token header cannot be decoded or
                its "alg" does not match the expected algorithm.

        Returns:
            the key id or None
        """
        try:
            header = jwt.get_unverified_header(token)
        except jwt.exceptions.DecodeError as err:
            raise TokenValidationError("token could not be decoded.") from err

        alg = header.get("alg")
        # Reject any other algorithm before a key is fetched or used.
        if alg != self._algorithm:
            raise TokenValidationError(
                f'Signature algorithm of "{alg}" is not supported. Expected the token '
                f'to be signed with "{self._algorithm}"'
            )

        return header.get("kid")

    def _decode_jwt(
        self, token: str, secret_or_certificate: str | RSAPublicKey
    ) -> dict[str, Any]:
        """Verifies and decodes the given JSON web token with the given public
        key or shared secret.

        Args:
            token (str): The JWT to get its signature verified.
            secret_or_certificate (str | RSAPublicKey): The public key or
                shared secret.

        Raises:
            TokenValidationError: if the token cannot be decoded or the
                token's signature doesn't match the calculated one.

        Returns:
            the decoded payload
        """
        try:
            decoded = jwt.decode(
                jwt=token,
                key=secret_or_certificate,
                algorithms=[self._algorithm],
                options=self.DISABLE_JWT_CHECKS,
            )
        except jwt.exceptions.InvalidSignatureError as err:
            raise TokenValidationError("Invalid token signature.") from err
        except jwt.exceptions.DecodeError as err:
            # Any other decoding failure (e.g. a malformed payload segment)
            # is reported through the same exception type callers catch.
            raise TokenValidationError("token could not be decoded.") from err
        return decoded

    def verify_signature(self, token: str) -> dict[str, Any]:
        """Verifies the signature of the given JSON web token.

        Args:
            token (str): The JWT to get its signature verified.

        Raises:
            TokenValidationError: if the token cannot be decoded, the
                algorithm is invalid or the token's signature doesn't match
                the calculated one.

        Returns:
            the decoded payload
        """
        kid = self._get_kid(token)
        if kid is None:
            # Tokens without a "kid" header still get a (blank) key lookup.
            kid = ""

        secret_or_certificate = self._fetch_key(key_id=kid)

        return self._decode_jwt(token, secret_or_certificate)
class SymmetricSignatureVerifier(SignatureVerifier):
    """Signature verifier for HMAC-signed tokens, keyed by a shared secret.

    Args:
        shared_secret (str): The shared secret used to decode the token.
        algorithm (str, optional): The expected signing algorithm.
            Defaults to "HS256".
    """

    def __init__(self, shared_secret: str, algorithm: str = "HS256") -> None:
        super(SymmetricSignatureVerifier, self).__init__(algorithm)
        self._shared_secret = shared_secret

    def _fetch_key(self, key_id: str = "") -> str:
        # One secret serves every token, so the key id plays no role here.
        secret = self._shared_secret
        return secret
class JwksFetcher:
    """Class that fetches and holds a JSON web key set.
    This class makes use of an in-memory cache. For it to work properly,
    define this instance once and re-use it.

    Args:
        jwks_url (str): The url where the JWK set is located.
        cache_ttl (int, optional): The lifetime of the JWK set cache in
            seconds. Defaults to 600 seconds.
        timeout (float | tuple[float, float] | None, optional): Keyword-only.
            The timeout in seconds passed to ``requests.get`` when
            downloading the JWK set (a ``(connect, read)`` tuple is also
            accepted). Defaults to 5.0 seconds; ``None`` waits indefinitely.
    """

    CACHE_TTL: ClassVar[int] = 600  # 10 min cache lifetime
    # Without a timeout, requests can block forever on an unresponsive server.
    DEFAULT_TIMEOUT: ClassVar[float] = 5.0

    def __init__(
        self,
        jwks_url: str,
        cache_ttl: int = CACHE_TTL,
        *,
        timeout: float | tuple[float, float] | None = DEFAULT_TIMEOUT,
    ) -> None:
        self._jwks_url = jwks_url
        self._timeout = timeout
        self._init_cache(cache_ttl)

    def _init_cache(self, cache_ttl: int) -> None:
        # An epoch-0 timestamp makes the empty cache count as expired.
        self._cache_value: dict[str, RSAPublicKey] = {}
        self._cache_date = 0.0
        self._cache_ttl = cache_ttl
        # True only right after _fetch_jwks downloaded the set.
        self._cache_is_fresh = False

    def _cache_expired(self) -> bool:
        """Checks if the cache is expired.

        Returns:
            True if the cached JWK set is older than the TTL and must be
            downloaded again, False if the cache can still be used.
        """
        return self._cache_date + self._cache_ttl < time.time()

    def _cache_jwks(self, jwks: dict[str, Any]) -> None:
        """Cache the response of the JWKS request.

        Args:
            jwks (dict): The JWKS
        """
        self._cache_value = self._parse_jwks(jwks)
        self._cache_is_fresh = True
        self._cache_date = time.time()

    def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]:
        """Attempts to obtain the JWK set from the cache, as long as it's still valid.
        When not, it will perform a network request to the jwks_url to obtain a fresh result
        and update the cache value with it.

        The cached keys are discarded before the request is made, so a
        non-OK response leaves the cache empty.

        Args:
            force (bool, optional): whether to ignore the cache and force a
                network request or not. Defaults to False.

        Returns:
            a dict mapping key ids to public keys (possibly empty)
        """
        if force or self._cache_expired():
            self._cache_value = {}
            response = requests.get(self._jwks_url, timeout=self._timeout)

            if response.ok:
                jwks: dict[str, Any] = response.json()
                self._cache_jwks(jwks)
                return self._cache_value

        # Served from cache or download failed: the value is not fresh.
        self._cache_is_fresh = False
        return self._cache_value

    @staticmethod
    def _parse_jwks(jwks: dict[str, Any]) -> dict[str, RSAPublicKey]:
        """Converts each JWK in a JWK set into an RSA public key object.

        Args:
            jwks (dict): The JWK set; its "keys" entry is iterated and each
                key's "kid" is used as the lookup id.

        Returns:
            a dict mapping each key id to its RSA public key
        """
        keys: dict[str, RSAPublicKey] = {}
        for key in jwks["keys"]:
            # noinspection PyUnresolvedReferences
            # requirement already includes cryptography -> pyjwt[crypto]
            rsa_key: RSAPublicKey = jwt.algorithms.RSAAlgorithm.from_jwk(
                json.dumps(key)
            )
            keys[key["kid"]] = rsa_key
        return keys

    def get_key(self, key_id: str) -> RSAPublicKey:
        """Obtains the JWK associated with the given key id.

        The (possibly cached) set is checked first. If the key is missing
        and the set was not just downloaded, one forced download is made so
        keys added since the last download can be found.

        Args:
            key_id (str): The id of the key to fetch.

        Returns:
            the JWK associated with the given key id.

        Raises:
            TokenValidationError: when a key with that id cannot be found
        """
        keys = self._fetch_jwks()

        if keys and key_id in keys:
            return keys[key_id]

        if not self._cache_is_fresh:
            keys = self._fetch_jwks(force=True)
            if keys and key_id in keys:
                return keys[key_id]
        raise TokenValidationError(f'RSA Public Key with ID "{key_id}" was not found.')
class AsymmetricSignatureVerifier(SignatureVerifier):
    """Signature verifier for RSA-signed tokens, whose public keys are
    published as a JSON web key set.

    Args:
        jwks_url (str): The url where the JWK set is located.
        algorithm (str, optional): The expected signing algorithm.
            Defaults to "RS256".
        cache_ttl (int, optional): The lifetime of the JWK set cache in
            seconds. Defaults to 600 seconds.
    """

    def __init__(
        self,
        jwks_url: str,
        algorithm: str = "RS256",
        cache_ttl: int = JwksFetcher.CACHE_TTL,
    ) -> None:
        # Validate the algorithm before building the (cache-holding) fetcher.
        super(AsymmetricSignatureVerifier, self).__init__(algorithm)
        fetcher = JwksFetcher(jwks_url, cache_ttl)
        self._fetcher = fetcher

    def _fetch_key(self, key_id: str) -> RSAPublicKey:
        # Key lookup, caching and refresh are delegated to the fetcher.
        fetcher = self._fetcher
        return fetcher.get_key(key_id)
class TokenVerifier:
    """Class that verifies ID tokens following the steps defined in the OpenID Connect spec.
    An OpenID Connect ID token is not meant to be consumed until it's verified.

    Args:
        signature_verifier (SignatureVerifier): The instance that knows how to verify the
            signature.
        issuer (str): The expected issuer claim value.
        audience (str): The expected audience claim value.
        leeway (int, optional): The clock skew to accept when verifying date related claims
            in seconds. Defaults to 0 seconds.
    """

    def __init__(
        self,
        signature_verifier: SignatureVerifier,
        issuer: str,
        audience: str,
        leeway: int = 0,
    ) -> None:
        if not signature_verifier or not isinstance(
            signature_verifier, SignatureVerifier
        ):
            raise TypeError(
                "signature_verifier must be an instance of SignatureVerifier."
            )

        self.iss = issuer
        self.aud = audience
        self.leeway = leeway
        self._sv = signature_verifier
        # Fixed "current time" (epoch seconds); when unset, time.time() is used.
        self._clock: float | None = None  # visible for testing

    def verify(
        self,
        token: str,
        nonce: str | None = None,
        max_age: int | None = None,
        organization: str | None = None,
    ) -> dict[str, Any]:
        """Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.

        Args:
            token (str): The JWT to verify.
            nonce (str, optional): The nonce value sent during authentication.
            max_age (int, optional): The max_age value sent during authentication.
            organization (str, optional): The expected organization ID (org_id) or
                organization name (org_name) claim value. This should be specified
                when logging in to an organization.

        Returns:
            the decoded payload from the token

        Raises:
            TokenValidationError: when the token cannot be decoded, the token signing algorithm is not the expected one,
            the token signature is invalid or the token has a claim missing or with unexpected value.
        """

        # Verify token presence
        if not token or not isinstance(token, str):
            raise TokenValidationError("ID token is required but missing.")

        # Verify algorithm and signature
        payload = self._sv.verify_signature(token)

        # Verify claims
        self._verify_payload(payload, nonce, max_age, organization)

        return payload

    def _verify_payload(
        self,
        payload: dict[str, Any],
        nonce: str | None = None,
        max_age: int | None = None,
        organization: str | None = None,
    ) -> None:
        """Validates the claims of a payload whose signature was already verified.

        Checks run in a fixed order and the first failing one raises.

        Raises:
            TokenValidationError: when a claim is missing, has the wrong type
                or holds an unexpected value.
        """
        self._verify_issuer(payload)
        self._verify_subject(payload)
        self._verify_audience(payload)

        # --Time validation (epoch)--
        # One "now" is shared by the exp and auth_time checks.
        now = self._clock or time.time()
        self._verify_expiration(payload, now)
        self._verify_issued_at(payload)

        if nonce:
            self._verify_nonce(payload, nonce)
        if organization:
            self._verify_organization(payload, organization)
        self._verify_authorized_party(payload)
        if max_age:
            self._verify_auth_time(payload, max_age, now)

    def _verify_issuer(self, payload: dict[str, Any]) -> None:
        # Requires a string "iss" equal to the configured issuer.
        if "iss" not in payload or not isinstance(payload["iss"], str):
            raise TokenValidationError(
                "Issuer (iss) claim must be a string present in the ID token"
            )
        if payload["iss"] != self.iss:
            raise TokenValidationError(
                'Issuer (iss) claim mismatch in the ID token; expected "{}", '
                'found "{}"'.format(self.iss, payload["iss"])
            )

    def _verify_subject(self, payload: dict[str, Any]) -> None:
        # Requires a string "sub"; its value is not compared.
        if "sub" not in payload or not isinstance(payload["sub"], str):
            raise TokenValidationError(
                "Subject (sub) claim must be a string present in the ID token"
            )

    def _verify_audience(self, payload: dict[str, Any]) -> None:
        # Requires "aud" to be a string, or a list of strings, that matches
        # or contains the configured audience.
        aud = payload.get("aud")
        # Element types are checked too: a list with non-string entries would
        # otherwise break the ", ".join() used in the mismatch message.
        is_str_list = isinstance(aud, list) and all(isinstance(a, str) for a in aud)
        if not isinstance(aud, str) and not is_str_list:
            raise TokenValidationError(
                "Audience (aud) claim must be a string or array of strings present in"
                " the ID token"
            )

        if isinstance(aud, list):
            if self.aud not in aud:
                payload_audiences = ", ".join(aud)
                raise TokenValidationError(
                    'Audience (aud) claim mismatch in the ID token; expected "{}" but was '
                    'not one of "{}"'.format(self.aud, payload_audiences)
                )
        elif aud != self.aud:
            raise TokenValidationError(
                'Audience (aud) claim mismatch in the ID token; expected "{}" '
                'but found "{}"'.format(self.aud, aud)
            )

    def _verify_expiration(self, payload: dict[str, Any], now: float) -> None:
        # Requires an int "exp"; fails when now is past exp + leeway.
        if "exp" not in payload or not isinstance(payload["exp"], int):
            raise TokenValidationError(
                "Expiration Time (exp) claim must be a number present in the ID token"
            )
        exp_time = payload["exp"] + self.leeway
        if now > exp_time:
            raise TokenValidationError(
                "Expiration Time (exp) claim error in the ID token; current time ({})"
                " is after expiration time ({})".format(now, exp_time)
            )

    def _verify_issued_at(self, payload: dict[str, Any]) -> None:
        # Requires an int "iat"; its value is not compared with the clock.
        if "iat" not in payload or not isinstance(payload["iat"], int):
            raise TokenValidationError(
                "Issued At (iat) claim must be a number present in the ID token"
            )

    def _verify_nonce(self, payload: dict[str, Any], nonce: str) -> None:
        # Requires a string "nonce" equal to the one sent at authentication.
        if "nonce" not in payload or not isinstance(payload["nonce"], str):
            raise TokenValidationError(
                "Nonce (nonce) claim must be a string present in the ID token"
            )
        if payload["nonce"] != nonce:
            raise TokenValidationError(
                'Nonce (nonce) claim mismatch in the ID token; expected "{}", '
                'found "{}"'.format(nonce, payload["nonce"])
            )

    def _verify_organization(self, payload: dict[str, Any], organization: str) -> None:
        # Values starting with "org_" are matched against "org_id"; anything
        # else is lower-cased and matched against "org_name".
        if organization.startswith("org_"):
            if "org_id" not in payload or not isinstance(payload["org_id"], str):
                raise TokenValidationError(
                    "Organization (org_id) claim must be a string present in the ID"
                    " token"
                )
            if payload["org_id"] != organization:
                raise TokenValidationError(
                    "Organization (org_id) claim mismatch in the ID token; expected"
                    ' "{}", found "{}"'.format(organization, payload["org_id"])
                )
        else:
            if "org_name" not in payload or not isinstance(payload["org_name"], str):
                raise TokenValidationError(
                    "Organization (org_name) claim must be a string present in the ID"
                    " token"
                )
            if payload["org_name"] != organization.lower():
                raise TokenValidationError(
                    "Organization (org_name) claim mismatch in the ID token; expected"
                    ' "{}", found "{}"'.format(organization, payload["org_name"])
                )

    def _verify_authorized_party(self, payload: dict[str, Any]) -> None:
        # "azp" is only required (and must equal the audience) when "aud"
        # lists more than one value.
        if isinstance(payload["aud"], list) and len(payload["aud"]) > 1:
            if "azp" not in payload or not isinstance(payload["azp"], str):
                raise TokenValidationError(
                    "Authorized Party (azp) claim must be a string present in the ID"
                    " token when Audience (aud) claim has multiple values"
                )
            if payload["azp"] != self.aud:
                raise TokenValidationError(
                    "Authorized Party (azp) claim mismatch in the ID token; expected"
                    ' "{}", found "{}"'.format(self.aud, payload["azp"])
                )

    def _verify_auth_time(self, payload: dict[str, Any], max_age: int, now: float) -> None:
        # Requires an int "auth_time"; fails when now is past
        # auth_time + max_age + leeway.
        if "auth_time" not in payload or not isinstance(payload["auth_time"], int):
            raise TokenValidationError(
                "Authentication Time (auth_time) claim must be a number present in"
                " the ID token when Max Age (max_age) is specified"
            )

        auth_valid_until = payload["auth_time"] + max_age + self.leeway
        if now > auth_valid_until:
            raise TokenValidationError(
                "Authentication Time (auth_time) claim in the ID token indicates"
                " that too much time has passed since the last end-user"
                " authentication. Current time ({}) is after last auth at ({})".format(
                    now, auth_valid_until
                )
            )