# authlib/jose/rfc7517/asymmetric_key.py
from authlib.common.encoding import to_bytes
from cryptography.hazmat.primitives.serialization import (
Encoding, PrivateFormat, PublicFormat,
BestAvailableEncryption, NoEncryption,
)
from ._cryptography_key import load_pem_key
from .base_key import Key
class AsymmetricKey(Key):
    """Base class for asymmetric JSON Web Keys.

    An instance holds up to two raw key objects, ``private_key`` and
    ``public_key``. Either may be ``None``; the accessors below then try to
    derive it from the other raw key or from the JWK dict data exposed by
    ``self.tokens``.

    The ``load_*`` / ``dumps_*`` hooks and ``generate_key`` raise
    ``NotImplementedError`` here and have to be provided by subclasses.
    """

    # JWK member names kept by ``as_dict`` when private members are
    # filtered out of the exported dict.
    PUBLIC_KEY_FIELDS = []
    # Not read anywhere in this class; presumably the private-key
    # counterpart of PUBLIC_KEY_FIELDS for subclasses -- TODO confirm.
    PRIVATE_KEY_FIELDS = []
    # Raw key types recognised by ``import_key`` and ``validate_raw_key``.
    # ``bytes`` looks like a placeholder that subclasses override.
    PRIVATE_KEY_CLS = bytes
    PUBLIC_KEY_CLS = bytes
    # Passed as the second argument to ``load_pem_key`` in ``import_key``.
    SSH_PUBLIC_PREFIX = b''

    def __init__(self, private_key=None, public_key=None, options=None):
        """Create a key from optional raw key objects.

        :param private_key: raw private key object, or ``None``
        :param public_key: raw public key object, or ``None``
        :param options: passed unchanged to the parent ``Key`` initializer
        """
        super().__init__(options)
        self.private_key = private_key
        self.public_key = public_key

    @property
    def public_only(self):
        """``True`` when this key has no private part.

        A key counts as private when a raw private key is set or when the
        JWK tokens contain a ``d`` member.
        """
        if self.private_key:
            return False
        return 'd' not in self.tokens

    def get_op_key(self, operation):
        """Get the raw key for the given key_op. This method will also
        check if the given key_op is supported by this key.

        :param operation: key operation value, such as "sign", "encrypt".
        :return: raw key
        """
        self.check_key_op(operation)
        if operation in self.PUBLIC_KEY_OPS:
            return self.get_public_key()
        return self.get_private_key()

    def get_public_key(self):
        """Return the raw public key.

        Uses the stored public key when there is one; otherwise derives it
        from the private key via ``private_key.public_key()``.

        :return: raw public key, or whatever ``self.public_key`` holds when
            no private key is available either
        """
        if self.public_key:
            return self.public_key

        # get_private_key() may call load_raw_key(), which fills in
        # self.public_key when the tokens carry no ``d`` member; that is
        # why self.public_key is read again at the end.
        private_key = self.get_private_key()
        if private_key:
            return private_key.public_key()
        return self.public_key

    def get_private_key(self):
        """Return the raw private key, loading it from the tokens if needed.

        :return: raw private key, or ``None`` when this key has none
        """
        if self.private_key:
            return self.private_key

        if self.tokens:
            self.load_raw_key()
        return self.private_key

    def load_raw_key(self):
        """Populate the raw key from the JWK tokens.

        Loads the private key when the tokens contain ``d``; otherwise loads
        the public key.
        """
        if 'd' in self.tokens:
            self.private_key = self.load_private_key()
        else:
            self.public_key = self.load_public_key()

    def load_dict_key(self):
        """Update ``self._dict_data`` from the raw key.

        Uses the private key dump when a raw private key is set, otherwise
        the public key dump.
        """
        if self.private_key:
            self._dict_data.update(self.dumps_private_key())
        else:
            self._dict_data.update(self.dumps_public_key())

    def dumps_private_key(self):
        """Hook: return the dict members for the raw private key."""
        raise NotImplementedError()

    def dumps_public_key(self):
        """Hook: return the dict members for the raw public key."""
        raise NotImplementedError()

    def load_private_key(self):
        """Hook: build the raw private key from the JWK tokens."""
        raise NotImplementedError()

    def load_public_key(self):
        """Hook: build the raw public key from the JWK tokens."""
        raise NotImplementedError()

    def as_dict(self, is_private=False, **params):
        """Represent this key as a dict of the JSON Web Key.

        :param is_private: export the private members as well
        :param params: extra members merged into the result last, so they
            override any member of the same name
        :return: dict
        :raises ValueError: if ``is_private`` is set but the tokens have no
            ``d`` member
        """
        tokens = self.tokens
        if is_private and 'd' not in tokens:
            raise ValueError('This is a public key')

        kid = tokens.get('kid')
        if 'd' in tokens and not is_private:
            # filter out private fields
            tokens = {k: tokens[k] for k in tokens if k in self.PUBLIC_KEY_FIELDS}
            tokens['kty'] = self.kty
            if kid:
                tokens['kid'] = kid

        # Fall back to the thumbprint when the key has no (non-empty) kid.
        if not kid:
            tokens['kid'] = self.thumbprint()

        tokens.update(params)
        return tokens

    def as_key(self, is_private=False):
        """Represent this key as raw key.

        :param is_private: return the private key instead of the public one
        :return: raw key
        """
        if is_private:
            return self.get_private_key()
        return self.get_public_key()

    def as_bytes(self, encoding=None, is_private=False, password=None):
        """Export key into PEM/DER format bytes.

        Private keys are serialized as PKCS8, public keys as
        SubjectPublicKeyInfo.

        :param encoding: "PEM" or "DER"; ``None`` means "PEM"
        :param is_private: export private key or public key
        :param password: encrypt private key with password; not used when
            exporting a public key
        :return: bytes
        :raises ValueError: if ``encoding`` is not recognised, or if a
            private export is requested but there is no private key
        """
        if encoding is None or encoding == 'PEM':
            encoding = Encoding.PEM
        elif encoding == 'DER':
            encoding = Encoding.DER
        else:
            raise ValueError(f'Invalid encoding: {encoding!r}')

        raw_key = self.as_key(is_private)
        if is_private:
            if not raw_key:
                raise ValueError('This is a public key')
            if password is None:
                encryption_algorithm = NoEncryption()
            else:
                encryption_algorithm = BestAvailableEncryption(to_bytes(password))
            return raw_key.private_bytes(
                encoding=encoding,
                format=PrivateFormat.PKCS8,
                encryption_algorithm=encryption_algorithm,
            )
        return raw_key.public_bytes(
            encoding=encoding,
            format=PublicFormat.SubjectPublicKeyInfo,
        )

    def as_pem(self, is_private=False, password=None):
        """Export key as PEM bytes; shortcut for ``as_bytes``.

        :param is_private: export private key or public key
        :param password: encrypt private key with password
        :return: bytes
        """
        return self.as_bytes(is_private=is_private, password=password)

    def as_der(self, is_private=False, password=None):
        """Export key as DER bytes; shortcut for ``as_bytes``.

        :param is_private: export private key or public key
        :param password: encrypt private key with password
        :return: bytes
        """
        return self.as_bytes(encoding='DER', is_private=is_private, password=password)

    @classmethod
    def import_dict_key(cls, raw, options=None):
        """Create a key from a JWK dict.

        The dict is checked with ``cls.check_required_fields`` and then
        stored on the key by reference (it is not copied).

        :param raw: JWK dict
        :param options: options for the new key
        :return: key instance
        """
        cls.check_required_fields(raw)
        key = cls(options=options)
        key._dict_data = raw
        return key

    @classmethod
    def import_key(cls, raw, options=None):
        """Create a key from any supported input.

        Accepted inputs, checked in this order: an instance of this class
        (returned as is, with ``options`` merged into its options), a raw
        public key, a raw private key, a JWK dict, or anything else, which
        is handed to ``load_pem_key``.

        :param raw: key instance, raw key, JWK dict, or serialized key data
        :param options: options for the key. A ``password`` entry is used
            to load serialized key data and is not stored on the new key.
            The dict passed in is not modified.
        :return: key instance
        :raises ValueError: if the loaded data is neither a public nor a
            private key of this class's raw key types
        """
        if isinstance(raw, cls):
            if options is not None:
                raw.options.update(options)
            return raw

        if isinstance(raw, cls.PUBLIC_KEY_CLS):
            key = cls(public_key=raw, options=options)
        elif isinstance(raw, cls.PRIVATE_KEY_CLS):
            key = cls(private_key=raw, options=options)
        elif isinstance(raw, dict):
            key = cls.import_dict_key(raw, options)
        else:
            password = None
            if options is not None and 'password' in options:
                # Take the password out of a copy: it must not end up in
                # the key's options, but removing it from the caller's own
                # dict would break reuse of that dict for another import.
                options = dict(options)
                password = options.pop('password')

            raw_key = load_pem_key(raw, cls.SSH_PUBLIC_PREFIX, password=password)
            if isinstance(raw_key, cls.PUBLIC_KEY_CLS):
                key = cls(public_key=raw_key, options=options)
            elif isinstance(raw_key, cls.PRIVATE_KEY_CLS):
                key = cls(private_key=raw_key, options=options)
            else:
                raise ValueError('Invalid data for importing key')
        return key

    @classmethod
    def validate_raw_key(cls, key):
        """Tell whether ``key`` is a raw public or private key of this class.

        :param key: object to check
        :return: bool
        """
        return isinstance(key, (cls.PUBLIC_KEY_CLS, cls.PRIVATE_KEY_CLS))

    @classmethod
    def generate_key(cls, crv_or_size, options=None, is_private=False):
        """Hook: generate a new key; not implemented in this base class.

        :param crv_or_size: curve name or key size, depending on the
            subclass (presumably -- the base class does not use it)
        :param options: options for the new key
        :param is_private: whether the generated key keeps its private part
        """
        raise NotImplementedError()