spid_cie_oidc/entity/jwtse.py
import base64
import binascii
import json
import logging
import cryptojwt
from cryptojwt.exception import UnsupportedAlgorithm, VerificationError
from cryptojwt.jwe.jwe import factory
from cryptojwt.jwe.jwe_ec import JWE_EC
from cryptojwt.jwe.jwe_rsa import JWE_RSA
from cryptojwt.jwk.jwk import key_from_jwk_dict
from cryptojwt.jws.jws import JWS
from cryptojwt.jws.utils import left_hash
from typing import Union
from .settings import (
DEFAULT_JWE_ALG,
DEFAULT_JWE_ENC,
ENCRYPTION_ALG_VALUES_SUPPORTED,
SIGNING_ALG_VALUES_SUPPORTED,
)
logger = logging.getLogger(__name__)
def unpad_jwt_element(jwt: str, position: int) -> dict:
b = jwt.split(".")[position]
padded = f"{b}{'=' * divmod(len(b), 4)[1]}"
data = json.loads(base64.urlsafe_b64decode(padded))
return data
def unpad_jwt_head(jwt: str) -> dict:
return unpad_jwt_element(jwt, position=0)
def unpad_jwt_payload(jwt: str) -> dict:
return unpad_jwt_element(jwt, position=1)
def create_jwe(plain_dict: Union[dict, str, int, None], jwk_dict: dict, **kwargs) -> str:
logger.debug(f"Encrypting dict as JWE: " f"{plain_dict}")
_key = key_from_jwk_dict(jwk_dict)
if isinstance(_key, cryptojwt.jwk.rsa.RSAKey):
JWE_CLASS = JWE_RSA
elif isinstance(_key, cryptojwt.jwk.ec.ECKey):
JWE_CLASS = JWE_EC
if isinstance(plain_dict, dict):
_payload = json.dumps(plain_dict).encode()
elif not plain_dict:
logger.warning("create_jwe with null payload!")
_payload = ""
elif isinstance(plain_dict, (str, int)):
_payload = plain_dict
else:
logger.error("create_jwe with unsupported payload type!")
_payload = ""
_keyobj = JWE_CLASS(
_payload,
alg=DEFAULT_JWE_ALG,
enc=DEFAULT_JWE_ENC,
kid=_key.kid,
**kwargs
)
jwe = _keyobj.encrypt(_key.public_key())
logger.debug(f"Encrypted dict as JWE: {jwe}")
return jwe
def decrypt_jwe(jwe: str, jwk_dict: dict) -> dict:
# get header
try:
jwe_header = unpad_jwt_head(jwe)
except (binascii.Error, Exception) as e: # pragma: no cover
logger.error(f"Failed to extract JWT header: {e}")
raise VerificationError("The JWT is not valid")
_alg = jwe_header.get("alg", DEFAULT_JWE_ALG)
_enc = jwe_header.get("enc", DEFAULT_JWE_ENC)
jwe_header.get("kid")
if _alg not in ENCRYPTION_ALG_VALUES_SUPPORTED: # pragma: no cover
raise UnsupportedAlgorithm(f"{_alg} has beed disabled for security reason")
_decryptor = factory(jwe, alg=_alg, enc=_enc)
# _dkey = RSAKey(priv_key=PRIV_KEY)
_dkey = key_from_jwk_dict(jwk_dict)
msg = _decryptor.decrypt(jwe, [_dkey])
try:
msg_dict = json.loads(msg)
logger.debug(f"Decrypted JWT as: {json.dumps(msg_dict, indent=2)}")
except json.decoder.JSONDecodeError:
msg_dict = msg
logger.debug(f"Decrypted JWT as: {msg_dict}")
return msg_dict
def create_jws(payload: dict, jwk_dict: dict, alg: str = "RS256", protected:dict = {}, **kwargs) -> str:
_key = key_from_jwk_dict(jwk_dict)
_signer = JWS(payload, alg=alg, **kwargs)
signature = _signer.sign_compact([_key], protected=protected, **kwargs)
return signature
def verify_jws(jws: str, pub_jwk: dict, **kwargs) -> str:
_key = key_from_jwk_dict(pub_jwk)
_head = unpad_jwt_head(jws)
if _head.get("kid") != pub_jwk["kid"]: # pragma: no cover
raise Exception(
f"kid error: {_head.get('kid')} != {pub_jwk['kid']}"
)
_alg = _head["alg"]
if _alg not in SIGNING_ALG_VALUES_SUPPORTED or not _alg: # pragma: no cover
raise UnsupportedAlgorithm(f"{_alg} has beed disabled for security reason")
verifier = JWS(alg=_head["alg"], **kwargs)
msg = verifier.verify_compact(jws, [_key])
return msg
def verify_at_hash(id_token, access_token) -> bool:
id_token_at_hash = id_token['at_hash']
at_hash = left_hash(access_token, "HS256")
if at_hash != id_token_at_hash:
raise Exception(
f"at_hash error: {at_hash} != {id_token_at_hash}"
)
return True