tlsfuzzer/tlslite-ng

View on GitHub
tlslite/keyexchange.py

Summary

Maintainability
F
5 days
Test Coverage
A
95%
# Authors:
#   Hubert Kario (2015)
#
# See the LICENSE file for legal information regarding use of this file.
"""Handling of cryptographic operations for key exchange"""

import ecdsa
from .mathtls import goodGroupParameters, makeK, makeU, makeX, \
        paramStrength, RFC7919_GROUPS, calc_key
from .errors import TLSInsufficientSecurity, TLSUnknownPSKIdentity, \
        TLSIllegalParameterException, TLSDecryptionFailed, TLSInternalError, \
        TLSDecodeError
from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify
from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \
        ExtensionType, GroupName, ECCurveType, SignatureScheme, ECPointFormat
from .utils.ecc import getCurveByName, getPointByteSize
from .utils.rsakey import RSAKey
from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \
        numBits, numberToByteArray, divceil, numBytes, secureHash
from .utils.lists import getFirstMatching
from .utils import tlshashlib as hashlib
from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \
        X448_ORDER_SIZE
from .utils.compat import int_types, ML_KEM_AVAILABLE
from .utils.codec import DecodeError

if ML_KEM_AVAILABLE:
    from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024


class KeyExchange(object):
    """
    Common API for calculating Premaster secret

    NOT stable, will get moved from this file
    """

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey=None):
        """Initialize KeyExchange. privateKey is the signing private key"""
        self.cipherSuite = cipherSuite
        self.clientHello = clientHello
        self.serverHello = serverHello
        self.privateKey = privateKey

    def makeServerKeyExchange(self, sigHash=None):
        """
        Create a ServerKeyExchange object

        Returns a ServerKeyExchange object for the server's initial leg in the
        handshake. If the key exchange method does not send ServerKeyExchange
        (e.g. RSA), it returns None.
        """
        raise NotImplementedError()

    def makeClientKeyExchange(self):
        """
        Create a ClientKeyExchange object

        Returns a ClientKeyExchange for the second flight from client in the
        handshake.
        """
        return ClientKeyExchange(self.cipherSuite,
                                 self.serverHello.server_version)

    def processClientKeyExchange(self, clientKeyExchange):
        """
        Process ClientKeyExchange and return premaster secret

        Processes the client's ClientKeyExchange message and returns the
        premaster secret. Raises TLSLocalAlert on error.
        """
        raise NotImplementedError()

    def processServerKeyExchange(self, srvPublicKey,
                                 serverKeyExchange):
        """Process the server KEX and return premaster secret"""
        raise NotImplementedError()

    def _tls12_sign_ecdsa_SKE(self, serverKeyExchange, sigHash=None):
        try:
            serverKeyExchange.hashAlg, serverKeyExchange.signAlg = \
                    getattr(SignatureScheme, sigHash)
            hashName = SignatureScheme.getHash(sigHash)
        except AttributeError:
            serverKeyExchange.hashAlg = getattr(HashAlgorithm, sigHash)
            serverKeyExchange.signAlg = SignatureAlgorithm.ecdsa
            hashName = sigHash

        hash_bytes = serverKeyExchange.hash(self.clientHello.random,
                                            self.serverHello.random)

        hash_bytes = hash_bytes[:self.privateKey.private_key.curve.baselen]

        serverKeyExchange.signature = \
            self.privateKey.sign(hash_bytes, hashAlg=hashName)

        if not serverKeyExchange.signature:
            raise TLSInternalError("Empty signature")

        if not self.privateKey.verify(serverKeyExchange.signature,
                                             hash_bytes,
                                             ecdsa.util.sigdecode_der):
            raise TLSInternalError("signature validation failure")

    def _tls12_sign_dsa_SKE(self, serverKeyExchange, sigHash=None):
        """Sign a TLSv1.2 SKE message."""
        try:
            serverKeyExchange.hashAlg, serverKeyExchange.signAlg = \
                getattr(SignatureScheme, sigHash)

        except AttributeError:
            serverKeyExchange.signAlg = SignatureAlgorithm.dsa
            serverKeyExchange.hashAlg = getattr(HashAlgorithm, sigHash)

        hashBytes = serverKeyExchange.hash(self.clientHello.random,
                                           self.serverHello.random)

        serverKeyExchange.signature = \
            self.privateKey.sign(hashBytes)

        if not serverKeyExchange.signature:
            raise TLSInternalError("Empty signature")

        if not self.privateKey.verify(serverKeyExchange.signature,
                                      hashBytes):
            raise TLSInternalError("Server Key Exchange signature invalid")

    def _tls12_sign_eddsa_ske(self, server_key_exchange, sig_hash):
        """Sign a TLSv1.2 SKE message."""
        server_key_exchange.hashAlg, server_key_exchange.signAlg = \
                getattr(SignatureScheme, sig_hash)
        pad_type = None
        hash_name = None
        salt_len = None

        hash_bytes = server_key_exchange.hash(self.clientHello.random,
                                              self.serverHello.random)

        server_key_exchange.signature = \
            self.privateKey.hashAndSign(hash_bytes,
                                        pad_type,
                                        hash_name,
                                        salt_len)

        if not server_key_exchange.signature:
            raise TLSInternalError("Empty signature")

        if not self.privateKey.hashAndVerify(
                server_key_exchange.signature,
                hash_bytes,
                pad_type,
                hash_name,
                salt_len):
            raise TLSInternalError("Server Key Exchange signature invalid")

    def _tls12_signSKE(self, serverKeyExchange, sigHash=None):
        """Sign a TLSv1.2 SKE message."""
        try:
            serverKeyExchange.hashAlg, serverKeyExchange.signAlg = \
                    getattr(SignatureScheme, sigHash)
            keyType = SignatureScheme.getKeyType(sigHash)
            padType = SignatureScheme.getPadding(sigHash)
            hashName = SignatureScheme.getHash(sigHash)
            saltLen = getattr(hashlib, hashName)().digest_size
        except AttributeError:
            serverKeyExchange.signAlg = SignatureAlgorithm.rsa
            serverKeyExchange.hashAlg = getattr(HashAlgorithm, sigHash)
            keyType = 'rsa'
            padType = 'pkcs1'
            hashName = sigHash
            saltLen = 0

        assert keyType == 'rsa'

        hashBytes = serverKeyExchange.hash(self.clientHello.random,
                                           self.serverHello.random)

        serverKeyExchange.signature = \
            self.privateKey.sign(hashBytes,
                                 padding=padType,
                                 hashAlg=hashName,
                                 saltLen=saltLen)

        if not serverKeyExchange.signature:
            raise TLSInternalError("Empty signature")

        if not self.privateKey.verify(serverKeyExchange.signature,
                                      hashBytes,
                                      padding=padType,
                                      hashAlg=hashName,
                                      saltLen=saltLen):
            raise TLSInternalError("Server Key Exchange signature invalid")

    def signServerKeyExchange(self, serverKeyExchange, sigHash=None):
        """
        Sign a server key exchange using default or specified algorithm

        :type sigHash: str
        :param sigHash: name of the signature hash to be used for signing
        """
        if self.serverHello.server_version < (3, 3):
            if self.privateKey.key_type == "ecdsa":
                serverKeyExchange.signAlg = SignatureAlgorithm.ecdsa
            if self.privateKey.key_type == "dsa":
                serverKeyExchange.signAlg = SignatureAlgorithm.dsa
            hashBytes = serverKeyExchange.hash(self.clientHello.random,
                                               self.serverHello.random)

            serverKeyExchange.signature = self.privateKey.sign(hashBytes)

            if not serverKeyExchange.signature:
                raise TLSInternalError("Empty signature")

            if not self.privateKey.verify(serverKeyExchange.signature,
                                          hashBytes):
                raise TLSInternalError("Server Key Exchange signature invalid")
        else:
            if self.privateKey.key_type == "ecdsa":
                self._tls12_sign_ecdsa_SKE(serverKeyExchange, sigHash)
            elif self.privateKey.key_type == "dsa":
                self._tls12_sign_dsa_SKE(serverKeyExchange, sigHash)
            elif self.privateKey.key_type in ("Ed25519", "Ed448"):
                self._tls12_sign_eddsa_ske(serverKeyExchange, sigHash)
            else:
                self._tls12_signSKE(serverKeyExchange, sigHash)

    @staticmethod
    def _tls12_verify_ecdsa_SKE(serverKeyExchange, publicKey, clientRandom,
                                serverRandom, validSigAlgs):
        hashName = HashAlgorithm.toRepr(serverKeyExchange.hashAlg)
        if not hashName:
            raise TLSIllegalParameterException("Unknown hash algorithm")

        hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)

        hashBytes = hashBytes[:publicKey.public_key.curve.baselen]

        if not publicKey.verify(serverKeyExchange.signature, hashBytes,
                                padding=None,
                                hashAlg=hashName,
                                saltLen=None):
            raise TLSDecryptionFailed("Server Key Exchange signature "
                                      "invalid")

    @staticmethod
    def _tls12_verify_eddsa_ske(server_key_exchange, public_key, client_random,
                                server_random, valid_sig_algs):
        """Verify SeverKeyExchange messages with EdDSA signatures."""
        del valid_sig_algs
        sig_bytes = server_key_exchange.signature
        if not sig_bytes:
            raise TLSIllegalParameterException("Empty signature")

        hash_bytes = server_key_exchange.hash(client_random, server_random)

        if not public_key.hashAndVerify(sig_bytes,
                                        hash_bytes):
            raise TLSDecryptionFailed("Server Key Exchange signature invalid")

    @staticmethod
    def _tls12_verify_dsa_SKE(serverKeyExchange, publicKey, clientRandom,
                              serverRandom, validSigAlgs):

        hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)

        if not publicKey.verify(serverKeyExchange.signature, hashBytes):
            raise TLSDecryptionFailed("Server Key Exchange signature "
                                      "invalid")

    @staticmethod
    def _tls12_verify_SKE(serverKeyExchange, publicKey, clientRandom,
                          serverRandom, validSigAlgs):
        """Verify TLSv1.2 version of SKE."""
        if (serverKeyExchange.hashAlg, serverKeyExchange.signAlg) not in \
                validSigAlgs:
            raise TLSIllegalParameterException("Server selected "
                                               "invalid signature "
                                               "algorithm")
        if (serverKeyExchange.hashAlg, serverKeyExchange.signAlg) in (
                SignatureScheme.ed25519, SignatureScheme.ed448):
            return KeyExchange._tls12_verify_eddsa_ske(serverKeyExchange,
                                                       publicKey,
                                                       clientRandom,
                                                       serverRandom,
                                                       validSigAlgs)
        if serverKeyExchange.signAlg == SignatureAlgorithm.ecdsa:
            return KeyExchange._tls12_verify_ecdsa_SKE(serverKeyExchange,
                                                       publicKey,
                                                       clientRandom,
                                                       serverRandom,
                                                       validSigAlgs)

        elif serverKeyExchange.signAlg == SignatureAlgorithm.dsa:
            return KeyExchange._tls12_verify_dsa_SKE(serverKeyExchange,
                                                     publicKey,
                                                     clientRandom,
                                                     serverRandom,
                                                     validSigAlgs)

        schemeID = (serverKeyExchange.hashAlg,
                    serverKeyExchange.signAlg)
        scheme = SignatureScheme.toRepr(schemeID)
        if scheme is not None:
            keyType = SignatureScheme.getKeyType(scheme)
            padType = SignatureScheme.getPadding(scheme)
            hashName = SignatureScheme.getHash(scheme)
            saltLen = getattr(hashlib, hashName)().digest_size
        else:
            if serverKeyExchange.signAlg != SignatureAlgorithm.rsa:
                raise TLSInternalError("non-RSA sigs are not supported")
            keyType = 'rsa'
            padType = 'pkcs1'
            saltLen = 0
            hashName = HashAlgorithm.toRepr(serverKeyExchange.hashAlg)
            if hashName is None:
                msg = "Unknown hash ID: {0}"\
                        .format(serverKeyExchange.hashAlg)
                raise TLSIllegalParameterException(msg)
        assert keyType == 'rsa'

        hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)

        sigBytes = serverKeyExchange.signature
        if not sigBytes:
            raise TLSIllegalParameterException("Empty signature")

        if not publicKey.verify(sigBytes, hashBytes,
                                padding=padType,
                                hashAlg=hashName,
                                saltLen=saltLen):
            raise TLSDecryptionFailed("Server Key Exchange signature "
                                      "invalid")

    @staticmethod
    def verifyServerKeyExchange(serverKeyExchange, publicKey, clientRandom,
                                serverRandom, validSigAlgs):
        """Verify signature on the Server Key Exchange message

        the only acceptable signature algorithms are specified by validSigAlgs
        """
        if serverKeyExchange.version < (3, 3):
            hashBytes = serverKeyExchange.hash(clientRandom, serverRandom)
            sigBytes = serverKeyExchange.signature

            if not sigBytes:
                raise TLSIllegalParameterException("Empty signature")

            if not publicKey.verify(sigBytes, hashBytes):
                raise TLSDecryptionFailed("Server Key Exchange signature "
                                          "invalid")
        else:
            KeyExchange._tls12_verify_SKE(serverKeyExchange, publicKey,
                                          clientRandom, serverRandom,
                                          validSigAlgs)

    @staticmethod
    def calcVerifyBytes(version, handshakeHashes, signatureAlg,
                        premasterSecret, clientRandom, serverRandom,
                        prf_name = None, peer_tag=b'client', key_type="rsa"):
        """Calculate signed bytes for Certificate Verify"""
        if version == (3, 0):
            masterSecret = calc_key(version, premasterSecret,
                                    0, b"master secret",
                                    client_random=clientRandom,
                                    server_random=serverRandom,
                                    output_length=48)
            verifyBytes = handshakeHashes.digestSSL(masterSecret, b"")
        elif version in ((3, 1), (3, 2)):
            if key_type != "ecdsa":
                verifyBytes = handshakeHashes.digest()
            else:
                verifyBytes = handshakeHashes.digest("sha1")
        elif version == (3, 3):
            if signatureAlg in (SignatureScheme.ed25519,
                    SignatureScheme.ed448):
                hashName = "intrinsic"
                padding = None
            elif signatureAlg[1] == SignatureAlgorithm.dsa:
                hashName = HashAlgorithm.toRepr(signatureAlg[0])
                padding = None
            elif signatureAlg[1] != SignatureAlgorithm.ecdsa:
                scheme = SignatureScheme.toRepr(signatureAlg)
                if scheme is None:
                    hashName = HashAlgorithm.toRepr(signatureAlg[0])
                    padding = 'pkcs1'
                else:
                    hashName = SignatureScheme.getHash(scheme)
                    padding = SignatureScheme.getPadding(scheme)
            else:
                padding = None
                hashName = HashAlgorithm.toRepr(signatureAlg[0])
            verifyBytes = handshakeHashes.digest(hashName)
            if padding == 'pkcs1':
                verifyBytes = RSAKey.addPKCS1Prefix(verifyBytes, hashName)
        elif version == (3, 4):
            scheme = SignatureScheme.toRepr(signatureAlg)
            if scheme:
                hash_name = SignatureScheme.getHash(scheme)
            else:
                # handles negative test cases when we try to pass in
                # schemes that are not supported in TLS1.3
                hash_name = HashAlgorithm.toRepr(signatureAlg[0])
            verifyBytes = bytearray(b'\x20' * 64 +
                                    b'TLS 1.3, ' + peer_tag +
                                    b' CertificateVerify' +
                                    b'\x00') + \
                          handshakeHashes.digest(prf_name)
            if hash_name != "intrinsic":
                verifyBytes = secureHash(verifyBytes, hash_name)
        else:
            raise ValueError("Unsupported TLS version {0}".format(version))
        return verifyBytes

    @staticmethod
    def makeCertificateVerify(version, handshakeHashes, validSigAlgs,
                              privateKey, certificateRequest, premasterSecret,
                              clientRandom, serverRandom):
        """Create a Certificate Verify message

        :param version: protocol version in use
        :param handshakeHashes: the running hash of all handshake messages
        :param validSigAlgs: acceptable signature algorithms for client side,
            applicable only to TLSv1.2 (or later)
        :param certificateRequest: the server provided Certificate Request
            message
        :param premasterSecret: the premaster secret, needed only for SSLv3
        :param clientRandom: client provided random value, needed only for
            SSLv3
        :param serverRandom: server provided random value, needed only for
            SSLv3
        """
        signatureAlgorithm = None
        if privateKey.key_type == "ecdsa" and version < (3, 3):
            signatureAlgorithm = (HashAlgorithm.sha1, SignatureAlgorithm.ecdsa)
        # in TLS 1.2 we must decide which algorithm to use for signing
        if version == (3, 3):
            serverSigAlgs = certificateRequest.supported_signature_algs
            signatureAlgorithm = getFirstMatching(validSigAlgs, serverSigAlgs)
            # if none acceptable, do a last resort:
            if signatureAlgorithm is None:
                signatureAlgorithm = validSigAlgs[0]
        verifyBytes = KeyExchange.calcVerifyBytes(version, handshakeHashes,
                                                  signatureAlgorithm,
                                                  premasterSecret,
                                                  clientRandom,
                                                  serverRandom,
                                                  key_type=privateKey.key_type)
        if signatureAlgorithm and signatureAlgorithm in (
                SignatureScheme.ed25519, SignatureScheme.ed448):
            padding = None
            hashName = "intrinsic"
            saltLen = None
            sig_func = privateKey.hashAndSign
            ver_func = privateKey.hashAndVerify
        elif signatureAlgorithm and \
                signatureAlgorithm[1] == SignatureAlgorithm.ecdsa:
            padding = None
            hashName = HashAlgorithm.toRepr(signatureAlgorithm[0])
            saltLen = None
            verifyBytes = verifyBytes[:privateKey.private_key.curve.baselen]
            sig_func = privateKey.sign
            ver_func = privateKey.verify
        elif signatureAlgorithm and \
                signatureAlgorithm[1] == SignatureAlgorithm.dsa:
            padding = None
            hashName = HashAlgorithm.toRepr(signatureAlgorithm[0])
            saltLen = None
            sig_func = privateKey.sign
            ver_func = privateKey.verify
        else:
            scheme = SignatureScheme.toRepr(signatureAlgorithm)
            # for pkcs1 signatures hash is used to add PKCS#1 prefix, but
            # that was already done by calcVerifyBytes
            hashName = None
            saltLen = 0
            if scheme is None:
                padding = 'pkcs1'
            else:
                padding = SignatureScheme.getPadding(scheme)
                if padding == 'pss':
                    hashName = SignatureScheme.getHash(scheme)
                    saltLen = getattr(hashlib, hashName)().digest_size
            sig_func = privateKey.sign
            ver_func = privateKey.verify

        signedBytes = sig_func(verifyBytes,
                               padding,
                               hashName,
                               saltLen)
        if not ver_func(signedBytes, verifyBytes, padding, hashName,
                        saltLen):
            raise TLSInternalError("Certificate Verify signature invalid")
        certificateVerify = CertificateVerify(version)
        certificateVerify.create(signedBytes, signatureAlgorithm)

        return certificateVerify


class AuthenticatedKeyExchange(KeyExchange):
    """
    Common methods for key exchanges that authenticate Server Key Exchange

    Methods for signing Server Key Exchange message
    """

    def makeServerKeyExchange(self, sigHash=None):
        """Prepare server side of key exchange with selected parameters"""
        ske = super(AuthenticatedKeyExchange, self).makeServerKeyExchange()
        self.signServerKeyExchange(ske, sigHash)
        return ske


class RSAKeyExchange(KeyExchange):
    """
    Handling of RSA key exchange

    NOT stable API, do NOT use
    """

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey):
        super(RSAKeyExchange, self).__init__(cipherSuite, clientHello,
                                             serverHello, privateKey)
        self.encPremasterSecret = None

    def makeServerKeyExchange(self, sigHash=None):
        """Don't create a server key exchange for RSA key exchange"""
        return None

    def processClientKeyExchange(self, clientKeyExchange):
        """Decrypt client key exchange, return premaster secret"""
        premasterSecret = self.privateKey.decrypt(\
            clientKeyExchange.encryptedPreMasterSecret)

        # On decryption failure randomize premaster secret to avoid
        # Bleichenbacher's "million message" attack
        randomPreMasterSecret = getRandomBytes(48)
        if not premasterSecret:
            premasterSecret = randomPreMasterSecret
        elif len(premasterSecret) != 48:
            premasterSecret = randomPreMasterSecret
        else:
            versionCheck = (premasterSecret[0], premasterSecret[1])
            if versionCheck != self.clientHello.client_version:
                #Tolerate buggy IE clients
                if versionCheck != self.serverHello.server_version:
                    premasterSecret = randomPreMasterSecret
        return premasterSecret

    def processServerKeyExchange(self, srvPublicKey,
                                 serverKeyExchange):
        """Generate premaster secret for server"""
        del serverKeyExchange # not present in RSA key exchange
        premasterSecret = getRandomBytes(48)
        premasterSecret[0] = self.clientHello.client_version[0]
        premasterSecret[1] = self.clientHello.client_version[1]

        self.encPremasterSecret = srvPublicKey.encrypt(premasterSecret)
        return premasterSecret

    def makeClientKeyExchange(self):
        """Return a client key exchange with clients key share"""
        clientKeyExchange = super(RSAKeyExchange, self).makeClientKeyExchange()
        clientKeyExchange.createRSA(self.encPremasterSecret)
        return clientKeyExchange


class ADHKeyExchange(KeyExchange):
    """
    Handling of anonymous Diffie-Hellman Key exchange

    FFDHE without signing serverKeyExchange useful for anonymous DH
    """

    def __init__(self, cipherSuite, clientHello, serverHello,
                 dhParams=None, dhGroups=None):
        super(ADHKeyExchange, self).__init__(cipherSuite, clientHello,
                                             serverHello)
#pylint: enable = invalid-name
        self.dh_Xs = None
        self.dh_Yc = None
        if dhParams:
            self.dh_g, self.dh_p = dhParams
        else:
            # 2048-bit MODP Group (RFC 5054, group 3)
            self.dh_g, self.dh_p = goodGroupParameters[2]
        self.dhGroups = dhGroups

    def makeServerKeyExchange(self):
        """
        Prepare server side of anonymous key exchange with selected parameters
        """
        # Check for RFC 7919 support
        ext = self.clientHello.getExtension(ExtensionType.supported_groups)
        if ext and self.dhGroups:
            commonGroup = getFirstMatching(ext.groups, self.dhGroups)
            if commonGroup:
                self.dh_g, self.dh_p = RFC7919_GROUPS[commonGroup - 256]
            elif getFirstMatching(ext.groups, range(256, 512)):
                raise TLSInternalError("DHE key exchange attempted despite no "
                                       "overlap between supported groups")

        # for TLS < 1.3 we need special algorithm to select params (see above)
        # so do not pass in the group, if we selected one
        kex = FFDHKeyExchange(None, self.serverHello.server_version,
                              self.dh_g, self.dh_p)
        self.dh_Xs = kex.get_random_private_key()
        dh_Ys = kex.calc_public_value(self.dh_Xs)

        version = self.serverHello.server_version
        serverKeyExchange = ServerKeyExchange(self.cipherSuite, version)
        serverKeyExchange.createDH(self.dh_p, self.dh_g, dh_Ys)
        # No sign for anonymous ServerKeyExchange.
        return serverKeyExchange

    def processClientKeyExchange(self, clientKeyExchange):
        """Use client provided parameters to establish premaster secret"""
        dh_Yc = clientKeyExchange.dh_Yc

        kex = FFDHKeyExchange(None, self.serverHello.server_version,
                              self.dh_g, self.dh_p)
        return kex.calc_shared_key(self.dh_Xs, dh_Yc)

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """Process the server key exchange, return premaster secret."""
        del srvPublicKey
        dh_p = serverKeyExchange.dh_p
        # TODO make the minimum changeable
        if dh_p < 2**1023:
            raise TLSInsufficientSecurity("DH prime too small")
        dh_g = serverKeyExchange.dh_g
        dh_Ys = serverKeyExchange.dh_Ys

        kex = FFDHKeyExchange(None, self.serverHello.server_version,
                              dh_g, dh_p)

        dh_Xc = kex.get_random_private_key()
        self.dh_Yc = kex.calc_public_value(dh_Xc)
        return kex.calc_shared_key(dh_Xc, dh_Ys)

    def makeClientKeyExchange(self):
        """Create client key share for the key exchange"""
        cke = super(ADHKeyExchange, self).makeClientKeyExchange()
        cke.createDH(self.dh_Yc)
        return cke


# the DHE_RSA part comes from IETF ciphersuite names, we want to keep it
#pylint: disable = invalid-name
class DHE_RSAKeyExchange(AuthenticatedKeyExchange, ADHKeyExchange):
    """
    Handling of authenticated ephemeral Diffe-Hellman Key exchange.
    """

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey,
                 dhParams=None, dhGroups=None):
        """
        Create helper object for Diffie-Hellamn key exchange.

        :param dhParams: Diffie-Hellman parameters that will be used by
            server. First element of the tuple is the generator, the second
            is the prime. If not specified it will use a secure set (currently
            a 2048-bit safe prime).
        :type dhParams: 2-element tuple of int
        """
        super(DHE_RSAKeyExchange, self).__init__(cipherSuite, clientHello,
                                                 serverHello, dhParams,
                                                 dhGroups)
#pylint: enable = invalid-name
        self.privateKey = privateKey


class AECDHKeyExchange(KeyExchange):
    """
    Handling of anonymous Eliptic curve Diffie-Hellman Key exchange

    ECDHE without signing serverKeyExchange useful for anonymous ECDH
    """

    def __init__(self, cipherSuite, clientHello, serverHello, acceptedCurves,
                 defaultCurve=GroupName.secp256r1):
        super(AECDHKeyExchange, self).__init__(cipherSuite, clientHello,
                                               serverHello)
        self.ecdhXs = None
        self.acceptedCurves = acceptedCurves
        self.group_id = None
        self.ecdhYc = None
        self.defaultCurve = defaultCurve

    def makeServerKeyExchange(self, sigHash=None):
        """Create AECDHE version of Server Key Exchange"""
        #Get client supported groups
        client_curves = self.clientHello.getExtension(
                ExtensionType.supported_groups)
        if client_curves is None:
            # in case there is no extension, we can pick any curve,
            # use the configured one
            client_curves = [self.defaultCurve]
        elif not client_curves.groups:
            # extension should have been validated before
            raise TLSInternalError("Can't do ECDHE with no client curves")
        else:
            client_curves = client_curves.groups

        #Pick first client preferred group we support
        self.group_id = getFirstMatching(client_curves, self.acceptedCurves)
        if self.group_id is None:
            raise TLSInsufficientSecurity("No mutual groups")

        kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version)
        self.ecdhXs = kex.get_random_private_key()

        ext_negotiated = ECPointFormat.uncompressed
        ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
        ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
        if ext_c and ext_s:
            try:
                ext_negotiated = next((i for i in ext_c.formats \
                                       if i in ext_s.formats))
            except StopIteration:
                raise TLSIllegalParameterException("No common EC point format")

        ext_negotiated = 'uncompressed' if \
            ext_negotiated == ECPointFormat.uncompressed else 'compressed'

        ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated)

        version = self.serverHello.server_version
        serverKeyExchange = ServerKeyExchange(self.cipherSuite, version)
        serverKeyExchange.createECDH(ECCurveType.named_curve,
                                     named_curve=self.group_id,
                                     point=ecdhYs)
        # No sign for anonymous ServerKeyExchange
        return serverKeyExchange

    def processClientKeyExchange(self, clientKeyExchange):
        """Calculate premaster secret from previously generated SKE and CKE"""
        ecdhYc = clientKeyExchange.ecdh_Yc

        if not ecdhYc:
            raise TLSDecodeError("No key share")

        kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version)
        ext_supported = [ECPointFormat.uncompressed]
        ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
        ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
        if ext_c and ext_s:
            ext_supported = [
                ext for ext in ext_c.formats if ext in ext_s.formats
                ]
            if not ext_supported:
                raise TLSIllegalParameterException("No common EC point format")
        ext_supported = map(
            lambda x: 'uncompressed' if
            x == ECPointFormat.uncompressed else
            'compressed', ext_supported
            )
        return kex.calc_shared_key(self.ecdhXs, ecdhYc, set(ext_supported))

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """Process the server key exchange, return premaster secret"""
        del srvPublicKey

        if serverKeyExchange.curve_type != ECCurveType.named_curve \
            or serverKeyExchange.named_curve not in self.acceptedCurves:
            raise TLSIllegalParameterException("Server picked curve we "
                                               "didn't advertise")

        ecdh_Ys = serverKeyExchange.ecdh_Ys
        if not ecdh_Ys:
            raise TLSDecodeError("Empty server key share")

        kex = ECDHKeyExchange(serverKeyExchange.named_curve,
                              self.serverHello.server_version)
        ecdhXc = kex.get_random_private_key()
        ext_negotiated = ECPointFormat.uncompressed
        ext_supported = [ECPointFormat.uncompressed]

        if self.clientHello:
            ext_c = self.clientHello.getExtension(
                ExtensionType.ec_point_formats)
            ext_s = self.serverHello.getExtension(
                ExtensionType.ec_point_formats)
            if ext_c and ext_s:
                try:
                    ext_supported = [
                        i for i in ext_c.formats if i in ext_s.formats
                        ]
                    ext_negotiated = ext_supported[0]
                except IndexError:
                    raise TLSIllegalParameterException(
                        "No common EC point format")

        ext_negotiated = 'uncompressed' if \
            ext_negotiated == ECPointFormat.uncompressed else 'compressed'
        ext_supported = map(
            lambda x: 'uncompressed' if
            x == ECPointFormat.uncompressed else
            'compressed', ext_supported
            )
        self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated)
        return kex.calc_shared_key(ecdhXc, ecdh_Ys, set(ext_supported))

    def makeClientKeyExchange(self):
        """Make client key exchange for ECDHE"""
        cke = super(AECDHKeyExchange, self).makeClientKeyExchange()
        cke.createECDH(self.ecdhYc)
        return cke


# The ECDHE_RSA part comes from the IETF names of ciphersuites, so we want to
# keep it
#pylint: disable = invalid-name
class ECDHE_RSAKeyExchange(AuthenticatedKeyExchange, AECDHKeyExchange):
    """Helper class for conducting ECDHE key exchange"""

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey,
                 acceptedCurves, defaultCurve=GroupName.secp256r1):
        super(ECDHE_RSAKeyExchange, self).__init__(cipherSuite, clientHello,
                                                   serverHello,
                                                   acceptedCurves,
                                                   defaultCurve)
#pylint: enable = invalid-name
        self.privateKey = privateKey


class SRPKeyExchange(KeyExchange):
    """Helper class for conducting SRP key exchange"""

    def __init__(self, cipherSuite, clientHello, serverHello, privateKey,
                 verifierDB, srpUsername=None, password=None, settings=None):
        """Link Key Exchange options with verifierDB for SRP"""
        super(SRPKeyExchange, self).__init__(cipherSuite, clientHello,
                                             serverHello, privateKey)
        self.N = None
        self.v = None
        self.b = None
        self.B = None
        self.verifierDB = verifierDB
        self.A = None
        self.srpUsername = srpUsername
        self.password = password
        self.settings = settings
        if srpUsername is not None and not isinstance(srpUsername, bytearray):
            raise TypeError("srpUsername must be a bytearray object")
        if password is not None and not isinstance(password, bytearray):
            raise TypeError("password must be a bytearray object")

    def makeServerKeyExchange(self, sigHash=None):
        """Create SRP version of Server Key Exchange"""
        srpUsername = bytes(self.clientHello.srp_username)
        #Get parameters from username
        try:
            entry = self.verifierDB[srpUsername]
        except KeyError:
            raise TLSUnknownPSKIdentity("Unknown identity")
        (self.N, g, s, self.v) = entry

        #Calculate server's ephemeral DH values (b, B)
        self.b = bytesToNumber(getRandomBytes(32))
        k = makeK(self.N, g)
        self.B = (powMod(g, self.b, self.N) + (k * self.v)) % self.N

        #Create ServerKeyExchange, signing it if necessary
        serverKeyExchange = ServerKeyExchange(self.cipherSuite,
                                              self.serverHello.server_version)
        serverKeyExchange.createSRP(self.N, g, s, self.B)
        if self.cipherSuite in CipherSuite.srpCertSuites:
            self.signServerKeyExchange(serverKeyExchange, sigHash)
        return serverKeyExchange

    def processClientKeyExchange(self, clientKeyExchange):
        """Calculate premaster secret from Client Key Exchange and sent SKE"""
        A = clientKeyExchange.srp_A
        if A % self.N == 0:
            raise TLSIllegalParameterException("Invalid SRP A value")

        #Calculate u
        u = makeU(self.N, A, self.B)

        #Calculate premaster secret
        S = powMod((A * powMod(self.v, u, self.N)) % self.N, self.b, self.N)
        return numberToByteArray(S)

    def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
        """Calculate premaster secret from ServerKeyExchange"""
        del srvPublicKey # irrelevant for SRP
        N = serverKeyExchange.srp_N
        g = serverKeyExchange.srp_g
        s = serverKeyExchange.srp_s
        B = serverKeyExchange.srp_B

        if (g, N) not in goodGroupParameters:
            raise TLSInsufficientSecurity("Unknown group parameters")
        if numBits(N) < self.settings.minKeySize:
            raise TLSInsufficientSecurity("N value is too small: {0}".\
                                          format(numBits(N)))
        if numBits(N) > self.settings.maxKeySize:
            raise TLSInsufficientSecurity("N value is too large: {0}".\
                                          format(numBits(N)))
        if B % N == 0:
            raise TLSIllegalParameterException("Suspicious B value")

        #Client ephemeral value
        a = bytesToNumber(getRandomBytes(32))
        self.A = powMod(g, a, N)

        #Calculate client's static DH values (x, v)
        x = makeX(s, self.srpUsername, self.password)
        v = powMod(g, x, N)

        #Calculate u
        u = makeU(N, self.A, B)

        #Calculate premaster secret
        k = makeK(N, g)
        S = powMod((B - (k*v)) % N, a+(u*x), N)
        return numberToByteArray(S)

    def makeClientKeyExchange(self):
        """Create ClientKeyExchange"""
        cke = super(SRPKeyExchange, self).makeClientKeyExchange()
        cke.createSRP(self.A)
        return cke


class RawDHKeyExchange(object):
    """
    Abstract class for performing Diffe-Hellman key exchange.

    Provides a shared API for X25519, ECDHE and FFDHE key exchange.
    """

    def __init__(self, group, version):
        """
        Set the parameters of the key exchange

        Sets group on which the KEX will take part and protocol version used.
        """
        self.group = group
        self.version = version

    def get_random_private_key(self):
        """
        Generate a random value suitable for use as the private value of KEX.
        """
        raise NotImplementedError("Abstract class")

    def calc_public_value(self, private, point_format=None):
        """Calculate the public value from the provided private value."""
        raise NotImplementedError("Abstract class")

    def calc_shared_key(self, private, peer_share, valid_point_formats=None):
        """Calcualte the shared key given our private and remote share value"""
        raise NotImplementedError("Abstract class")


class FFDHKeyExchange(RawDHKeyExchange):
    """Implemenation of the Finite Field Diffie-Hellman key exchange."""

    def __init__(self, group, version, generator=None, prime=None):
        super(FFDHKeyExchange, self).__init__(group, version)
        if prime and group:
            raise ValueError("Can't set the RFC7919 group and custom params"
                             " at the same time")
        if group:
            self.generator, self.prime = RFC7919_GROUPS[group-256]
        else:
            self.prime = prime
            self.generator = generator

        if not 1 < self.generator < self.prime:
            raise TLSIllegalParameterException("Invalid DH generator")

    def get_random_private_key(self):
        """
        Return a random private value for the prime used.

        :rtype: int
        """
        # Per RFC 3526, Section 1, the exponent should have double the entropy
        # of the strength of the group.
        needed_bytes = divceil(paramStrength(self.prime) * 2, 8)
        return bytesToNumber(getRandomBytes(needed_bytes))

    def calc_public_value(self, private, point_format=None):
        """
        Calculate the public value for given private value.

        :param point_format: ignored, used for compatibility with ECDH groups

        :rtype: int
        """
        dh_Y = powMod(self.generator, private, self.prime)
        if dh_Y in (1, self.prime - 1):
            raise TLSIllegalParameterException("Small subgroup capture")
        if self.version < (3, 4):
            return dh_Y
        else:
            return numberToByteArray(dh_Y, numBytes(self.prime))

    def _normalise_peer_share(self, peer_share):
        """Convert the peer_share to number if necessary."""
        if isinstance(peer_share, (int_types)):
            return peer_share

        if numBytes(self.prime) != len(peer_share):
            raise TLSIllegalParameterException(
                "Key share does not match FFDH prime")
        return bytesToNumber(peer_share)

    def calc_shared_key(self, private, peer_share, valid_point_formats=None):
        """Calculate the shared key.

        :param valid_point_formats: ignored, used for compatibility with ECDH groups

        :rtype: bytearray"""
        peer_share = self._normalise_peer_share(peer_share)
        # First half of RFC 2631, Section 2.1.5. Validate the client's public
        # key.
        # use of safe primes also means that the p-1 is invalid
        if not 2 <= peer_share < self.prime - 1:
            raise TLSIllegalParameterException("Invalid peer key share")

        S = powMod(peer_share, private, self.prime)
        if S in (1, self.prime - 1):
            raise TLSIllegalParameterException("Small subgroup capture")
        if self.version < (3, 4):
            return numberToByteArray(S)
        else:
            return numberToByteArray(S, numBytes(self.prime))


class ECDHKeyExchange(RawDHKeyExchange):
    """Implementation of the Elliptic Curve Diffie-Hellman key exchange."""
    _x_groups = set((GroupName.x25519, GroupName.x448))

    @staticmethod
    def _non_zero_check(value):
        """
        Verify using constant time operation that the bytearray is not zero

        :raises TLSIllegalParameterException: if the value is all zero
        """
        summa = 0
        for i in value:
            summa |= i
        if summa == 0:
            raise TLSIllegalParameterException("Invalid key share")

    def __init__(self, group, version):
        super(ECDHKeyExchange, self).__init__(group, version)

    def get_random_private_key(self):
        """Return random private key value for the selected curve."""
        if self.group in self._x_groups:
            if self.group == GroupName.x25519:
                return getRandomBytes(X25519_ORDER_SIZE)
            else:
                return getRandomBytes(X448_ORDER_SIZE)
        else:
            curve = getCurveByName(GroupName.toStr(self.group))
            return ecdsa.keys.SigningKey.generate(curve)

    def _get_fun_gen_size(self):
        """Return the function and generator for X25519/X448 KEX."""
        if self.group == GroupName.x25519:
            return x25519, bytearray(X25519_G), X25519_ORDER_SIZE
        else:
            return x448, bytearray(X448_G), X448_ORDER_SIZE

    def calc_public_value(self, private, point_format='uncompressed'):
        """
        Calculate public value for given private key.

        :param private: Private key for the selected key exchange group.
        :param str point_format: The point format to use for the
            ECDH public key. Applies only to NIST curves.
        """
        if isinstance(private, ecdsa.keys.SigningKey):
            return private.verifying_key.to_string(point_format)
        if self.group in self._x_groups:
            fun, generator, _ = self._get_fun_gen_size()
            return fun(private, generator)
        else:
            curve = getCurveByName(GroupName.toStr(self.group))
            point = curve.generator * private
            return bytearray(point.to_bytes(point_format))

    def calc_shared_key(self, private, peer_share,
            valid_point_formats=('uncompressed',)):
        """
        Calculate the shared key.

        :param bytearray | SigningKey private: private value

        :param bytearray peer_share: public value

        :param set(str) valid_point_formats: list of point formats that
            the peer share can be in; ["uncompressed"] by default.

        :rtype: bytearray
        :returns: shared key

        :raises TLSIllegalParameterException
            when the paramentrs for point are invalid.
        :raises TLSDecodeError
            when the the valid_point_formats is empty.

        """
        if self.group in self._x_groups:
            fun, _, size = self._get_fun_gen_size()
            if len(peer_share) != size:
                raise TLSIllegalParameterException("Invalid key share")
            if isinstance(private, ecdsa.keys.SigningKey):
                private = bytesToNumber(private.to_string())
            S = fun(private, peer_share)
            self._non_zero_check(S)
            return S

        curve = getCurveByName(GroupName.toRepr(self.group))
        try:
            abstractPoint = ecdsa.ellipticcurve.AbstractPoint()
            point = abstractPoint.from_bytes(
                curve.curve, peer_share, valid_encodings=valid_point_formats)
            ecdhYc = ecdsa.ellipticcurve.Point(
                curve.curve, point[0], point[1])

        except AssertionError:
            raise TLSIllegalParameterException("Invalid ECC point")
        except DecodeError:
            raise TLSDecodeError("Empty point formats extension")
        if isinstance(private, ecdsa.keys.SigningKey):
            ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private)
            ecdh.load_received_public_key_bytes(peer_share,
                                                valid_encodings=
                                                valid_point_formats)
            return bytearray(ecdh.generate_sharedsecret_bytes())
        S = ecdhYc * private

        return numberToByteArray(S.x(), getPointByteSize(ecdhYc))


class KEMKeyExchange(object):
    """
    Implementation of the Hybrid KEM key exchange groups.

    Caution, KEMs are not symmetric! While they client calls the
    same get_random_private_key(), calc_public_value(), and calc_shared_key()
    as in FFDH or ECDH, the server calls just the encapsulate_key() method.
    """

    def __init__(self, group, version):
        if not ML_KEM_AVAILABLE:
            raise TLSInternalError("kyber-py library not installed!")
        self.group = group
        assert version == (3, 4)
        del version

        if self.group not in GroupName.allKEM:
            raise TLSInternalError("called with wrong group")

        if self.group == GroupName.secp256r1mlkem768:
            self._classic_group = GroupName.secp256r1
        elif self.group == GroupName.x25519mlkem768:
            self._classic_group = GroupName.x25519
        else:
            assert self.group == GroupName.secp384r1mlkem1024
            self._classic_group = GroupName.secp384r1

    def get_random_private_key(self):
        """
        Generates a random value to be used as the private key in KEM.

        To be used only to generate the KeyShare in ClientHello.
        """

        if self.group not in GroupName.allKEM:
            raise TLSInternalError("called with wrong group")
        if self.group in (GroupName.secp256r1mlkem768,
                          GroupName.x25519mlkem768):
            pqc_pub_key, pqc_priv_key = ML_KEM_768.keygen()
        else:
            pqc_pub_key, pqc_priv_key = ML_KEM_1024.keygen()

        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
        classic_key = classic_kex.get_random_private_key()

        return ((pqc_pub_key, pqc_priv_key), classic_key)

    def calc_public_value(self, private, point_format='uncompressed'):
        """
        Extract public values for the private key.

        To be used only to generate the KeyShare in ClientHello.

        :param str point_format: Point format of the ECDH portion of the
            key exchange (effective only for NIST curves, valid is
            'uncompressed' only)
        """
        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))

        classic_pub_key_share = classic_kex.calc_public_value(
            private[1], point_format=point_format)

        if self.group == GroupName.x25519mlkem768:
            return private[0][0] + classic_pub_key_share
        return classic_pub_key_share + private[0][0]

    @staticmethod
    def _split_key_shares(public, pqc_first, pqc_key_len, classic_key_len):
        if len(public) != classic_key_len + pqc_key_len:
            raise TLSIllegalParameterException(
                "Invalid key size for the selected group. "
                "Expected: {0}, received: {1}".format(
                    classic_key_len + pqc_key_len,
                    len(public)))

        if pqc_first:
            pqc_key = public[:pqc_key_len]
            classic_key_share = bytearray(public[pqc_key_len:])
        else:
            classic_key_share = bytearray(public[:classic_key_len])
            pqc_key = public[classic_key_len:]

        return pqc_key, classic_key_share

    def _group_to_params(self):
        """Returns a tuple:
        classic_key_len, pqc_ek_key_len, pqc_ciphertext_len, pqc_first, ML_KEM
        """
        if self.group == GroupName.secp256r1mlkem768:
            classic_key_len = 65
            pqc_key_len = 1184
            pqc_ciphertext_len = 1088
            pqc_first = False
            ml_kem = ML_KEM_768
        elif self.group == GroupName.x25519mlkem768:
            classic_key_len = 32
            pqc_key_len = 1184
            pqc_ciphertext_len = 1088
            pqc_first = True
            ml_kem = ML_KEM_768
        else:
            assert self.group == GroupName.secp384r1mlkem1024
            classic_key_len = 97
            pqc_key_len = 1568
            pqc_ciphertext_len = 1568
            pqc_first = False
            ml_kem = ML_KEM_1024

        return classic_key_len, pqc_key_len, pqc_ciphertext_len, pqc_first, \
            ml_kem

    def encapsulate_key(self, public):
        """
        Generate a random secret, encapsulate it given the public key,
        and return both the random secret and encapsulation of it.

        To be used for generation of KeyShare in ServerHello.
        """
        classic_key_len, pqc_key_len, _, pqc_first, ml_kem = \
            self._group_to_params()

        pqc_key, classic_key_share = self._split_key_shares(
            public, pqc_first, pqc_key_len, classic_key_len)

        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
        classic_key = classic_kex.get_random_private_key()
        classic_my_key_share = classic_kex.calc_public_value(classic_key)
        classic_shared_secret = classic_kex.calc_shared_key(
            classic_key, classic_key_share)

        try:
            pqc_shared_secret, pqc_encaps = ml_kem.encaps(pqc_key)
        except ValueError:
            raise TLSIllegalParameterException(
                "Invalid PQC key from peer")

        if pqc_first:
            shared_secret = pqc_shared_secret + classic_shared_secret
            key_encapsulation = pqc_encaps + classic_my_key_share
        else:
            shared_secret = classic_shared_secret + pqc_shared_secret
            key_encapsulation = classic_my_key_share + pqc_encaps

        return shared_secret, key_encapsulation

    def calc_shared_key(self, private, key_encaps):
        """
        Decapsulate the key share received from server.
        """
        classic_key_len, _, pqc_key_len, pqc_first, ml_kem = \
            self._group_to_params()

        pqc_key, classic_key_share = self._split_key_shares(
            key_encaps, pqc_first, pqc_key_len, classic_key_len)

        classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
        classic_shared_secret = classic_kex.calc_shared_key(
                private[1], classic_key_share)

        try:
            pqc_shared_secret = ml_kem.decaps(private[0][1], pqc_key)
        except ValueError:
            raise TLSIllegalParameterException(
                "Error in KEM decapsulation")

        if pqc_first:
            shared_secret = pqc_shared_secret + classic_shared_secret
        else:
            shared_secret = classic_shared_secret + pqc_shared_secret

        return shared_secret