tomato42/tlsfuzzer

View on GitHub
tlsfuzzer/expect.py

Summary

Maintainability
F
6 days
Test Coverage
A
98%
# Author: Hubert Kario, (c) 2015
# Released under Gnu GPL v2.0, see LICENSE file for details

"""Parsing and processing of received TLS messages"""
from __future__ import print_function

import itertools
from functools import partial
import sys
import time

import tlslite.utils.tlshashlib as hashlib
from tlslite.constants import ContentType, HandshakeType, CertificateType,\
        HashAlgorithm, SignatureAlgorithm, ExtensionType,\
        SSL2HandshakeType, CipherSuite, GroupName, AlertDescription, \
        SignatureScheme, TLS_1_3_HRR, HeartbeatMode, \
        TLS_1_1_DOWNGRADE_SENTINEL, TLS_1_2_DOWNGRADE_SENTINEL, \
        HeartbeatMessageType, ClientCertificateType, CertificateStatusType
from tlslite.messages import ServerHello, Certificate, ServerHelloDone,\
        ChangeCipherSpec, Finished, Alert, CertificateRequest, ServerHello2,\
        ServerKeyExchange, ClientHello, ServerFinished, CertificateStatus, \
        CertificateVerify, EncryptedExtensions, NewSessionTicket, Heartbeat,\
        KeyUpdate, HelloRequest, NewSessionTicket1_0
from tlslite.extensions import TLSExtension, ALPNExtension
from tlslite.utils.codec import Parser, Writer
from tlslite.utils.compat import b2a_hex
from tlslite.utils.cryptomath import secureHMAC, derive_secret, \
        HKDF_expand_label
from tlslite.mathtls import RFC7919_GROUPS, FFDHE_PARAMETERS, calc_key
from tlslite.keyexchange import KeyExchange, DHE_RSAKeyExchange, \
        ECDHE_RSAKeyExchange
from tlslite.x509 import X509
from tlslite.x509certchain import X509CertChain
from tlslite.errors import TLSDecryptionFailed
from tlslite.handshakehashes import HandshakeHashes
from tlslite.handshakehelpers import HandshakeHelpers
from .handshake_helpers import calc_pending_states, kex_for_group, \
        curve_name_to_hash_tls13
from .helpers import ECDSA_SIG_TLS1_3_ALL
from .tree import TreeNode

# pylint: disable=import-error,no-name-in-module
# pylint: disable=bad-option-value,deprecated-class
if sys.version_info >= (3, 3):
    from collections.abc import Iterable
else:
    from collections import Iterable
# pylint: enable=bad-option-value,deprecated-class
# pylint: enable=import-error,no-name-in-module


class Expect(TreeNode):
    """Root of the hierarchy of nodes that consume received messages."""

    def __init__(self, content_type):
        """
        Initialise the tree node and remember the accepted content type.

        :param content_type: value compared against ``msg.contentType`` in
            :py:meth:`is_match`
        """
        super(Expect, self).__init__()
        self.content_type = content_type

    def is_expect(self):
        """Report that this node is a message processor."""
        return True

    def is_command(self):
        """Report that this node is not a command."""
        return False

    def is_generator(self):
        """Report that this node is not a message generator."""
        return False

    def is_match(self, msg):
        """
        Tell whether this node is able to handle the given message.

        The msg is a raw, unparsed message of the indicated type; calling
        its write() method gives the raw bytearray() representation.

        :type msg: tlslite.messages.Message
        :param msg: raw message to check
        :return: True when the content type of the message is the one this
            node was created for, False otherwise
        """
        return bool(msg.contentType == self.content_type)

    def process(self, state, msg):
        """
        Handle the message and update the state; must be overridden.

        :type state: tlsfuzzer.runner.ConnectionState
        :param state: current connection state, inheriting classes need to
            update it after parsing the message
        :type msg: tlslite.messages.Message
        :param msg: raw message to parse
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Subclasses need to implement this!")


class ExpectMessage(Expect):
    """Shared comparison helpers for nodes that verify TLS messages."""

    @staticmethod
    def _cmp_eq(our, recv, field_type=None, f_str=None):
        """
        Verify that the received value equals the expected one.

        Nothing is checked when ``our`` is None. On a mismatch both values
        are translated with ``field_type.toStr()`` (when ``field_type`` is
        provided) and an AssertionError is raised. Its text is created from
        ``f_str``: ``{0}`` is substituted with the expected value and
        ``{1}`` with the received one.
        """
        if our is None:
            return
        if our == recv:
            return

        expected, received = our, recv
        if field_type:
            expected = field_type.toStr(our)
            received = field_type.toStr(recv)

        template = f_str or "Expected: {0}, received: {1}"
        raise AssertionError(template.format(expected, received))

    @classmethod
    def _cmp_eq_or_in(cls, our, recv, field_type=None, f_str=None):
        """
        Verify that the received value is the expected one or one of them.

        Nothing is checked when ``our`` is None. When ``our`` supports the
        ``in`` operator (a list or a set), ``recv`` must be its member.
        When it does not (``in`` raises TypeError), the check falls back to
        a plain equality comparison done by :py:meth:`_cmp_eq`.

        On failure an AssertionError is raised, with values translated by
        ``field_type.toStr()`` when ``field_type`` is provided. Its text is
        created from ``f_str``: ``{0}`` is substituted with the expected
        values and ``{1}`` with the received one.
        """
        if our is None:
            return

        try:
            found = recv in our
        except TypeError:
            # a scalar was provided as the expected value
            return cls._cmp_eq(our, recv, field_type, f_str)
        if found:
            return

        # no match, build the error message
        expected, received = our, recv
        if field_type:
            names = ", ".join(field_type.toStr(i) for i in our)
            expected = "({0})".format(names)
            received = field_type.toStr(recv)

        template = f_str or "Received value ({1}) not in expected list: {0}"
        raise AssertionError(template.format(expected, received))

    @staticmethod
    def _cmp_eq_list(our, recv, field_type=None, f_str=None):
        """
        Verify that the received list of values equals the expected one.

        Nothing is checked when ``our`` is None. On a mismatch the items of
        both lists are translated with ``field_type.toStr()`` when
        ``field_type`` is provided, otherwise ``repr()`` of the lists is
        used, and an AssertionError is raised. Its text is created from
        ``f_str``: ``{0}`` is substituted with the expected values and
        ``{1}`` with the received ones.
        """
        if our is None or our == recv:
            return

        def _render(values):
            """Create the human readable form of a list of values."""
            if not field_type:
                return repr(values)
            return "({0})".format(
                ", ".join(field_type.toStr(i) for i in values))

        expected = _render(our)
        received = _render(recv)

        template = f_str or "Expected: {0}, received: {1}"
        raise AssertionError(template.format(expected, received))


class ExpectHandshake(ExpectMessage):
    """Common methods for handling TLS Handshake protocol messages"""

    def __init__(self, content_type, handshake_type):
        """
        Set the type of message

        :type content_type: int
        :type handshake_type: int
        """
        super(ExpectHandshake, self).__init__(content_type)
        self.handshake_type = handshake_type

    def is_match(self, msg):
        """
        Check if message is a given type of handshake protocol message

        :param msg: raw message; its ``write()`` output is inspected
        :return: True when the content type matches and the first byte of
            the payload equals ``handshake_type``, False otherwise (also
            for a message with an empty payload)
        """
        if not super(ExpectHandshake, self).is_match(msg):
            return False

        # serialise the message only once and reuse the result
        payload = msg.write()
        if not payload:  # if message is empty
            return False

        # the first byte of the payload is the handshake message type
        hs_type = Parser(payload).get(1)
        return hs_type == self.handshake_type

    def process(self, state, msg):
        """
        Process the message; needs to be provided by inheriting classes.

        :raises NotImplementedError: always, in this class
        """
        raise NotImplementedError("Subclass need to implement this!")


def srv_ext_handler_ems(state, extension):
    """
    Handle the Extended Master Secret extension sent by the server.

    :param state: connection state, ``extended_master_secret`` gets set to
        True on it
    :param extension: received extension, its ``extData`` must be empty
    :raises AssertionError: when the extension carries any payload
    """
    payload = extension.extData
    if payload:
        raise AssertionError("Malformed EMS extension, data in payload")

    state.extended_master_secret = True


def srv_ext_handler_etm(state, extension):
    """
    Handle the Encrypt then MAC extension sent by the server.

    :param state: connection state, ``encrypt_then_mac`` gets set to True
        on it
    :param extension: received extension, its ``extData`` must be empty
    :raises AssertionError: when the extension carries any payload
    """
    payload = extension.extData
    if payload:
        raise AssertionError("Malformed EtM extension, data in payload")

    state.encrypt_then_mac = True


def srv_ext_handler_sni(state, extension):
    """
    Handle the server_name extension sent by the server.

    :param state: unused, present only to keep the handler signature
    :param extension: received extension, its ``extData`` must be empty
    :raises AssertionError: when the extension carries any payload
    """
    del state  # kept for compatibility
    payload = extension.extData
    if payload:
        raise AssertionError("Malformed SNI extenion, data in payload")


def srv_ext_handler_renego(state, extension):
    """
    Handle the renegotiation_info extension sent by the server.

    The payload must be the concatenation of ``client_verify_data`` and
    ``server_verify_data`` stored in ``state.key``.

    :raises AssertionError: when the received payload is different
    """
    received = extension.renegotiated_connection
    expected = state.key['client_verify_data'] + \
        state.key['server_verify_data']
    if received != expected:
        raise AssertionError("Invalid data in renegotiation_info")


def srv_ext_handler_alpn(state, extension):
    """
    Handle the ALPN extension sent by the server.

    The server must select exactly one protocol and it has to be one of
    the protocols listed in the ALPN extension of the last ClientHello.

    :raises AssertionError: when the server did not send exactly one
        protocol name or when the name was not advertised by us
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    sent_ext = client_hello.getExtension(ExtensionType.alpn)
    # the sent extension might have been provided with explicit encoding
    advertised = ALPNExtension().parse(Parser(sent_ext.extData))

    selected = extension.protocol_names
    if not selected or len(selected) != 1:
        raise AssertionError("Malformed ALPN extension")
    if selected[0] not in advertised.protocol_names:
        raise AssertionError("Server selected ALPN protocol we did not "
                             "advertise")


def srv_ext_handler_ec_point(state, extension):
    """
    Handle the ec_point_formats extension sent by the server.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``extension.formats`` is None or empty
    """
    del state
    if not extension.formats:
        raise AssertionError("Malformed ec_point_formats extension")


def srv_ext_handler_npn(state, extension):
    """
    Handle the NPN extension sent by the server.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``extension.protocols`` is None or empty
    """
    del state
    if not extension.protocols:
        raise AssertionError("Malformed NPN extension")


def srv_ext_handler_session_ticket(state, extension):
    """
    Handle the session_ticket extension sent by the server.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``extension.ticket`` is not an empty
        byte string
    """
    del state
    ticket = extension.ticket
    if ticket != b"":
        raise AssertionError("Malformed session_ticket extension")


def srv_ext_handler_key_share(state, extension):
    """
    Handle the key_share extension sent by the server.

    Finds the share for the server selected group among the shares sent in
    the last ClientHello and computes the shared secret from it. The server
    key exchange value is saved in ``state.key`` under
    ``'ServerHello.extensions.key_share.key_exchange'`` and the computed
    secret under ``'DH shared secret'``.

    :raises AssertionError: when the server selected a group for which we
        did not send a share
    :raises ValueError: when the matching client share has no private value
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    client_key_share = client_hello.getExtension(ExtensionType.key_share)

    srv_share = extension.server_share
    selected = srv_share.group

    # first client share that uses the group picked by the server
    our_share = None
    for share in client_key_share.client_shares:
        if share.group == selected:
            our_share = share
            break
    if our_share is None:
        raise AssertionError("Server selected group we didn't advertise: {0}"
                             .format(GroupName.toStr(selected)))

    kex = kex_for_group(selected, state.version)

    state.key['ServerHello.extensions.key_share.key_exchange'] = \
        srv_share.key_exchange

    if not our_share.private:
        raise ValueError("private value for key share of group {0} missing"
                         .format(GroupName.toStr(selected)))

    state.key['DH shared secret'] = kex.calc_shared_key(
        our_share.private, srv_share.key_exchange)


def hrr_ext_handler_key_share(state, extension):
    """
    Handle the key_share extension from a HelloRetryRequest message.

    The group selected by the server must be listed in the supported_groups
    extension of the last ClientHello.

    :raises AssertionError: when the selected group was not advertised
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    advertised = client_hello.getExtension(ExtensionType.supported_groups)

    selected = extension.selected_group

    if selected in advertised.groups:
        return

    raise AssertionError("Server selected group we didn't advertise: {0}"
                         .format(GroupName.toStr(selected)))


def hrr_ext_handler_cookie(state, extension):
    """
    Handle the cookie extension from a HelloRetryRequest message.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``extension.cookie`` is empty
    """
    del state
    cookie = extension.cookie
    if not cookie:
        raise AssertionError("Server sent empty cookie extension")


def srv_ext_handler_supp_vers(state, extension):
    """
    Handle the supported_versions extension sent by the server.

    The selected version must be one of the versions listed in the
    supported_versions extension of the last ClientHello; when it is,
    it is saved as ``state.version``.

    :raises AssertionError: when the selected version was not advertised
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    advertised = client_hello.getExtension(ExtensionType.supported_versions)

    selected = extension.version

    if selected in advertised.versions:
        state.version = selected
        return

    raise AssertionError("Server selected version we didn't advertise: {0}"
                         .format(selected))


def srv_ext_handler_supp_groups(state, extension):
    """
    Handle the supported_groups extension sent by the server.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``extension.groups`` is empty
    """
    del state
    groups = extension.groups
    if not groups:
        raise AssertionError("Server did not send any supported_groups")


def srv_ext_handler_status_request(state, extension):
    """
    Handle the status_request extension sent by the server.

    TLS 1.2 ServerHello specific, in TLS 1.3 the extension resides in
    Certificate message.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: unless ``status_type`` is None,
        ``responder_id_list`` is an empty list and ``request_extensions``
        is an empty bytearray
    """
    del state
    is_empty = extension.status_type is None and \
        extension.responder_id_list == [] and \
        extension.request_extensions == bytearray()
    if not is_empty:
        raise AssertionError("Server did send non empty status_request "
                             "extension")


def srv_ext_handler_heartbeat(state, extension):
    """
    Handle the heartbeat extension sent by the server.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``extension.mode`` is empty or is neither
        ``PEER_ALLOWED_TO_SEND`` nor ``PEER_NOT_ALLOWED_TO_SEND``
    """
    del state
    mode = extension.mode
    if not mode:
        raise AssertionError("Empty mode in heartbeat extension.")

    valid_modes = (HeartbeatMode.PEER_ALLOWED_TO_SEND,
                   HeartbeatMode.PEER_NOT_ALLOWED_TO_SEND)
    if mode not in valid_modes:
        raise AssertionError("Invalid mode in heartbeat extension.")


def _srv_ext_handler_psk(state, extension, psk_configs):
    """
    Handle the pre_shared_key extension sent by the server.

    It requires the psk_configs so it can't be selected automatically and
    thus shouldn't be part of _srv_ext_handler.

    The secret of the selected identity is saved in ``state.key`` under
    ``'PSK secret'``. When the identity is the ticket of the last entry in
    ``state.session_tickets``, the secret is calculated from that ticket
    and the ``'resumption master secret'``; otherwise it is looked up in
    ``psk_configs``.

    :param psk_configs: iterable of sequences in which the first element
        is the identity and the second one is the associated secret
    :raises AssertionError: when the server selected an index beyond the
        list of identities we sent
    :raises ValueError: when there is no secret for the selected identity
    """
    client_hello = state.get_last_message_of_type(ClientHello)
    offered = client_hello.getExtension(ExtensionType.pre_shared_key)

    # the selection is 0-based
    index = extension.selected
    if index >= len(offered.identities):
        raise AssertionError("Server selected PSK we didn't send")

    chosen = offered.identities[index]
    ident = chosen.identity

    if state.session_tickets:
        nst = state.session_tickets[-1]
        if nst.ticket == ident:
            state.key['PSK secret'] = HandshakeHelpers.calc_res_binder_psk(
                chosen,
                state.key['resumption master secret'],
                [nst])
            return

    # first configuration with matching identity wins
    secret = None
    for config in psk_configs:
        if config[0] == ident:
            secret = config[1]
            break
    if not secret:
        raise ValueError("psk_configs are missing identity")

    state.key['PSK secret'] = secret


def gen_srv_ext_handler_psk(psk_configs=tuple()):
    """
    Create a handler for the pre_shared_key extension from the server.

    :param psk_configs: PSK configurations that get bound to the returned
        handler
    :return: callable accepting ``(state, extension)``
    """
    handler = partial(_srv_ext_handler_psk, psk_configs=psk_configs)
    return handler


def _srv_ext_handler_record_limit(state, extension, size=None):
    """
    Process record_size_limit extension from server.

    Validates the limit sent by the server and configures the record size
    limits for the connection. In TLS 1.2 and earlier the values are only
    saved in the state (``_peer_record_size_limit`` and
    ``_our_record_size_limit``), in TLS 1.3 they are applied to
    ``state.msg_sock`` right away.

    The checks use explicit ``raise`` instead of ``assert`` statements so
    that they are performed also when Python runs with the ``-O`` option.

    :param state: connection state
    :param extension: the record_size_limit extension sent by server
    :param int size: expected value from server, None for any valid
    :raises AssertionError: when the limit is missing, is outside of the
        valid range (64 to 2**14, or to 2**14 + 1 when the negotiated
        version is higher than (3, 3)) or is different from ``size``
    """
    cln_hello = state.get_last_message_of_type(ClientHello)
    cln_ext = cln_hello.getExtension(ExtensionType.record_size_limit)

    received = extension.record_size_limit
    if received is None:
        raise AssertionError("Server sent record_size_limit extension "
                             "without a limit")

    # for versions above (3, 3) one more byte is allowed
    max_limit = 2**14 + int(state.version > (3, 3))
    if not 64 <= received <= max_limit:
        raise AssertionError("Server sent invalid record_size_limit: {0}, "
                             "valid range is 64 to {1}"
                             .format(received, max_limit))

    if size and received != size:
        raise AssertionError("Server sent unexpected size in extension, "
                             "expected size: {0}, received size: {1}"
                             .format(size, received))

    if state.version <= (3, 3):
        # in TLS 1.2 and earlier we need to delay that to processing of
        # server CCS
        state._peer_record_size_limit = received
        state._our_record_size_limit = min(2**14, cln_ext.record_size_limit)
    else:
        # in TLS 1.3 we need to implement it right away (as the extension
        # applies only to encrypted messages)
        # the RecordLayer expects value that excludes content type
        state.msg_sock.recv_record_limit = min(
            2**14,
            cln_ext.record_size_limit-1)
        # this is just hint for padding callback
        state.msg_sock.send_record_limit = min(
            2**14,
            received-1)
        # this guides fragmentation
        state.msg_sock.recordSize = state.msg_sock.send_record_limit


def gen_srv_ext_handler_record_limit(size=None):
    """
    Create a handler for the record_size_limit extension from the server.

    Note that if the extension is actually negotiated, it will override
    any `~SetMaxRecordSize()` before EncryptedExtensions in TLS 1.3 and
    before ChangeCipherSpec in TLS 1.2 and earlier.

    :param int size: expected value from server, None for any valid
    :return: callable accepting ``(state, extension)``
    """
    handler = partial(_srv_ext_handler_record_limit, size=size)
    return handler


def clnt_ext_handler_status_request(state, extension):
    """
    Verify the status_request extension from the initiating side.

    To be used in ClientHello and CertificateRequest

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``status_type`` is not ``ocsp`` or when
        either ``responder_id_list`` or ``request_extensions`` is None
    """
    del state  # kept for compatibility
    status_type = extension.status_type
    if status_type != CertificateStatusType.ocsp:
        raise AssertionError(
            "Unexpected status_type in status_request extension: {0}"
            .format(CertificateStatusType.toStr(status_type)))

    payload_missing = extension.responder_id_list is None \
        or extension.request_extensions is None
    if payload_missing:
        raise AssertionError(
            "Malformed status_request extension")


def clnt_ext_handler_sig_algs(state, extension):
    """
    Verify signature_algorithms or signature_algorithms_cert extension.

    To be used in ClientHello and CertificateRequest.

    :param state: unused, present only to keep the handler signature
    :raises AssertionError: when ``extension.sigalgs`` is empty
    """
    del state  # kept for API compatibility
    if extension.sigalgs:
        return

    ext_name = ExtensionType.toStr(extension.extType)
    raise AssertionError(
        "Empty or malformed {0} extension"
        .format(ext_name))


# Automatic handlers, selected by extension type, used for extensions
# received in ServerHello (see ExpectServerHello._get_autohandler)
_srv_ext_handler = {
    ExtensionType.extended_master_secret: srv_ext_handler_ems,
    ExtensionType.encrypt_then_mac: srv_ext_handler_etm,
    ExtensionType.server_name: srv_ext_handler_sni,
    ExtensionType.renegotiation_info: srv_ext_handler_renego,
    ExtensionType.alpn: srv_ext_handler_alpn,
    ExtensionType.session_ticket: srv_ext_handler_session_ticket,
    ExtensionType.ec_point_formats: srv_ext_handler_ec_point,
    ExtensionType.supports_npn: srv_ext_handler_npn,
    ExtensionType.key_share: srv_ext_handler_key_share,
    ExtensionType.supported_versions: srv_ext_handler_supp_vers,
    ExtensionType.heartbeat: srv_ext_handler_heartbeat,
    ExtensionType.record_size_limit: _srv_ext_handler_record_limit,
    ExtensionType.status_request: srv_ext_handler_status_request,
}


# NOTE(review): presumably the automatic handlers for extensions received
# in HelloRetryRequest - the user is outside of this part of the file
_HRR_EXT_HANDLER = {
    ExtensionType.key_share: hrr_ext_handler_key_share,
    ExtensionType.cookie: hrr_ext_handler_cookie,
}


# NOTE(review): presumably the automatic handlers for extensions received
# in EncryptedExtensions - the user is outside of this part of the file
_EE_EXT_HANDLER = {
    ExtensionType.server_name: srv_ext_handler_sni,
    ExtensionType.alpn: srv_ext_handler_alpn,
    ExtensionType.supported_groups: srv_ext_handler_supp_groups,
    ExtensionType.heartbeat: srv_ext_handler_heartbeat,
    ExtensionType.record_size_limit: _srv_ext_handler_record_limit,
}


# NOTE(review): presumably the automatic handlers for extensions received
# in CertificateRequest - the user is outside of this part of the file
_CR_EXT_HANDLER = {
    ExtensionType.status_request: clnt_ext_handler_status_request,
    ExtensionType.signature_algorithms: clnt_ext_handler_sig_algs,
    ExtensionType.signature_algorithms_cert: clnt_ext_handler_sig_algs,
}


class _ExpectExtensionsMessage(ExpectHandshake):
    """
    Shared code of handlers for messages that carry a list of extensions.

    Used in ServerHello, EncryptedExtensions and CertificateRequest (in
    TLS 1.3)
    """
    def __init__(self, content_type, msg_type, extensions):
        """
        Save the expected extensions next to the message identification.

        :param content_type: content type of the expected message
        :param msg_type: handshake type of the expected message
        :param extensions: mapping in which keys are the expected
            extension types, or None to skip comparing the types
        """
        super(_ExpectExtensionsMessage, self).__init__(
            content_type, msg_type)
        self.extensions = extensions

    def _compare_extensions(self, message):
        """
        Check that the set of received extension types is the expected one.

        Nothing is compared when ``self.extensions`` is None or when the
        message has no extensions and none were expected.

        :raises AssertionError: when some extensions were expected but the
            message has none, when an expected extension is missing or
            when an extension that was not expected is present
        """
        wanted_exts = self.extensions
        # if the list of extensions is present, make sure it matches exactly
        # with what the server sent
        if wanted_exts and not message.extensions:
            raise AssertionError("Server did not send any extensions")
        if wanted_exts is None or not message.extensions:
            return

        expected = set(wanted_exts.keys())
        got = set(ext.extType for ext in message.extensions)
        if got == expected:
            return

        missing = expected.difference(got)
        if missing:
            names = ", ".join(ExtensionType.toStr(i) for i in missing)
            raise AssertionError("Server did not send extension(s): "
                                 "{0}".format(names))

        surplus = got.difference(expected)
        # the sets differ and nothing is missing, so something
        # unexpected must have been received
        assert surplus
        names = ", ".join(ExtensionType.toStr(i) for i in surplus)
        raise AssertionError("Server sent unexpected extension(s):"
                             " {0}".format(names))


class ExpectServerHello(_ExpectExtensionsMessage):
    """
    Parsing TLS Handshake protocol Server Hello messages.

    Processing of the ServerHello message updates the record layer
    to the version advertised by the server.
    Use :py:class:`~tlsfuzzer.messages.SetRecordVersion` to change it earlier
    to send records with different versions.

    .. note::
      Receiving of the ServerHello in TLS 1.3 influences record layer
      encryption. After the message is received, the
      ``client_handshake_traffic_secret`` and
      ``server_handshake_traffic_secret``
      is derived and record layer is configured to expect encrypted records
      on the *receiving* side.

    :ivar str ~.description: identifier to print when processing of the
        node fails
    """

    def __init__(self, extensions=None, version=None, resume=False,
                 cipher=None, server_max_protocol=None, force_resume=False,
                 description=None):
        """
        Set up expectations about the ServerHello message.

        :param dict extensions: maps extension type either to an extension
            object the server sent one must be equal to, or to a callback
            that will process and verify it. A None value for a given type
            selects the automatic handler of that extension type.
            None in place of the whole dict (the default) means that the
            automatic handlers, which verify the response against the
            extensions sent in ClientHello, are used for everything.
            Empty dict means that the server is expected to send no
            extensions. Order does not matter, but the server must send
            all of the listed extensions and no other ones.

        :param tuple version: the literal version in the Server Hello
            message (needs to be (3, 3) for TLS 1.3, use extensions to
            expect TLS 1.3 negotiation)

        :type resume: boolean
        :param resume: whether the session id should match the one from
            current state - IOW, if the server hello should belong to
            a resumed session. TLS 1.2 and earlier only. In TLS 1.3
            resumption is handled by providing handler for
            ``pre_shared_key`` extension.

        :type cipher: int or set-like
        :param int cipher: the id of the cipher that is expected to be
            negotiated by server. Can also be a list or set (needs to
            support ``in``) for a set of allowed ciphers.
            None (the default) means any valid cipher
            (i.e. not SCSV or GREASE) sent in ClientHello can be selected
            by server.

        :param tuple server_max_protocol: the highest protocol version
            supported by server. Used for testing downgrade signaling of
            servers.

        :param boolean force_resume: assume that the session is getting
            resumed, even if the sessionID is empty. Applicable to TLS 1.2
            and earlier only when using session tickets and not sending
            a sessionID.

        :param str description: identifier to print when processing of the
            node fails
        """
        super(ExpectServerHello, self).__init__(ContentType.handshake,
                                                HandshakeType.server_hello,
                                                extensions)
        self.description = description
        self.version = version
        self.srv_max_prot = server_max_protocol
        self.cipher = cipher
        self.resume = resume
        self.force_resume = force_resume

    def __str__(self):
        """Return human readable representation of the object."""
        if not self.description:
            return "ExpectServerHello()"
        template = "ExpectServerHello(description={0!r})"
        return template.format(self.description)

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return the automatic handler for given extension type.

        :param ext_id: extension type to look up in ``_srv_ext_handler``
        :raises AssertionError: when there is no handler for the type
        """
        try:
            handler = _srv_ext_handler[ext_id]
        except KeyError:
            ext_name = ExtensionType.toStr(ext_id)
            raise AssertionError("No autohandler for {0}".format(ext_name))
        return handler

    def _process_extensions(self, state, cln_hello, srv_hello):
        """
        Verify every extension of the server message and run its handler.

        :param state: connection state, passed to the handlers
        :param cln_hello: the ClientHello the server is replying to
        :param srv_hello: parsed message from server
        :raises AssertionError: when an extension is not allowed in the
            message, was not advertised by the client, has no automatic
            handler or is not equal to the expected extension object
        :raises ValueError: when the handler is neither a callable nor
            a TLSExtension
        """
        # extensions allowed in TLS 1.3 ServerHello and HelloRetryRequest
        # messages (as some need to be echoed by server in EncryptedExtensions
        # and some in Certificate)
        allowed_in_sh = [ExtensionType.pre_shared_key,
                         ExtensionType.supported_versions,
                         ExtensionType.key_share]
        allowed_in_hrr = [ExtensionType.cookie,
                          ExtensionType.supported_versions,
                          ExtensionType.key_share]
        for srv_ext in srv_hello.extensions:
            ext_type = srv_ext.extType

            if state.version > (3, 3):
                if srv_hello.random != TLS_1_3_HRR:
                    allowed = allowed_in_sh
                else:
                    allowed = allowed_in_hrr
                if ext_type not in allowed:
                    raise AssertionError("Server sent unallowed "
                                         "extension of type {0}"
                                         .format(ExtensionType
                                                 .toStr(ext_type)))

            # in TLS 1.2 generally the server can reply to any client sent
            # extension, and all of them end in ClientHello
            advertised = cln_hello.getExtension(ext_type)
            # the SCSV is an alternative way to advertise renegotiation_info
            scsv_sent = ext_type == ExtensionType.renegotiation_info and \
                CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV \
                in cln_hello.cipher_suites
            # cookie is accepted in HelloRetryRequest without advertisement
            hrr_cookie = isinstance(self, ExpectHelloRetryRequest) and \
                ext_type == ExtensionType.cookie
            if advertised is None and not (scsv_sent or hrr_cookie):
                raise AssertionError("Server sent unadvertised "
                                     "extension of type {0}"
                                     .format(ExtensionType
                                             .toStr(ext_type)))

            handler = self.extensions[ext_type] if self.extensions else None

            # use automatic handlers for some extensions
            if handler is None:
                handler = self._get_autohandler(ext_type)

            if callable(handler):
                handler(state, srv_ext)
                continue

            if not isinstance(handler, TLSExtension):
                raise ValueError("Bad extension handler for id {0}"
                                 .format(ExtensionType.toStr(ext_type)))

            if not handler == srv_ext:
                raise AssertionError("Expected extension not "
                                     "matched for type {0}, "
                                     "received: {1}"
                                     .format(ExtensionType
                                             .toStr(ext_type),
                                             srv_ext))

    @staticmethod
    def _extract_version(msg):
        """
        Return the version selected by the server message.

        That is the version from the supported_versions extension when
        the extension is present and the legacy version field is (3, 3),
        the legacy version field otherwise.

        :raises ValueError: when the legacy version field is higher than
            (3, 3)
        """
        supp_versions = msg.getExtension(ExtensionType.supported_versions)
        legacy_version = msg.server_version

        # RFC 8446 "legacy_version field MUST be set to 0x0303"
        if legacy_version > (3, 3):
            raise ValueError("Server sent invalid version in legacy_version "
                             "field")

        if supp_versions and legacy_version == (3, 3):
            return supp_versions.version

        return legacy_version

    def process(self, state, msg):
        """
        Process the message and update state accordingly

        Parses the ServerHello, verifies it against the expectations set
        in the constructor and against the last ClientHello, saves the
        negotiated parameters in the state, runs the extension handlers and,
        when the negotiated version is higher than (3, 3), sets up the
        TLS 1.3 handshake keys.

        The checks of the server provided data use explicit ``raise``
        instead of ``assert`` statements so that they are performed also
        when Python runs with the ``-O`` option.

        :type state: ConnectionState
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket

        :return: the parsed ServerHello message
        :raises AssertionError: when the message does not match the
            expectations
        """
        # sanity checks of the caller
        # NOTE(review): presumably guaranteed by is_match() - confirm
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello

        srv_hello = ServerHello()
        srv_hello.parse(parser)

        # extract important info
        state.server_random = srv_hello.random

        cln_hello = state.get_last_message_of_type(ClientHello)

        # check for session_id based session resumption
        if self.resume and not state.session_id == srv_hello.session_id:
            raise AssertionError("Server did not resume the session: "
                                 "unexpected session ID in ServerHello")
        if self.force_resume or ((state.session_id == srv_hello.session_id
                or cln_hello.session_id == srv_hello.session_id) and
                srv_hello.session_id != bytearray(0) and
                self._extract_version(srv_hello) < (3, 4)):
            # TLS 1.2 resumption, TLS 1.3 is based on PSKs
            state.resuming = True
            if not state.cipher == srv_hello.cipher_suite:
                raise AssertionError("Server selected different cipher "
                                     "suite in resumed session")
            if not state.version == self._extract_version(srv_hello):
                raise AssertionError("Server selected different protocol "
                                     "version in resumed session")
        state.session_id = srv_hello.session_id

        self._cmp_eq(self.version, srv_hello.server_version,
                     f_str="Server selected unexpected protocol version. "
                           "Expected: {0}, received: {1}.")

        self._cmp_eq_or_in(
            self.cipher, srv_hello.cipher_suite,
            f_str="Server selected unexpected ciphersuite. "
                  "Expected: {0}, received: {1}.")

        # check if server sent cipher matches what we advertised in CH
        if srv_hello.cipher_suite not in cln_hello.cipher_suites:
            cipher = srv_hello.cipher_suite
            if cipher in CipherSuite.ietfNames:
                name = "{0} ({1:#06x})".format(CipherSuite.ietfNames[cipher],
                                               cipher)
            else:
                name = "{0:#06x}".format(cipher)
            raise AssertionError("Server responded with cipher we did"
                                 " not advertise: {0}".format(name))

        state.cipher = srv_hello.cipher_suite
        state.version = self._extract_version(srv_hello)

        # update the state of connection
        state.msg_sock.version = state.version
        state.msg_sock.tls13record = state.version > (3, 3)

        # needs to happen before this message is saved in the state, as
        # the check looks up the last ServerHello type message stored there
        self._check_against_hrr(state, srv_hello)

        state.handshake_messages.append(srv_hello)
        state.handshake_hashes.update(msg.write())

        # Reset value of the session-wide settings
        state.extended_master_secret = False
        state.encrypt_then_mac = False

        self._check_downgrade_protection(srv_hello)

        self._compare_extensions(srv_hello)

        if srv_hello.extensions:
            self._process_extensions(state, cln_hello, srv_hello)

        if state.version > (3, 3):
            self._setup_tls13_handshake_keys(state)
        return srv_hello

    @staticmethod
    def _check_against_hrr(state, srv_hello):
        """
        Verify that ServerHello agrees with a preceding HelloRetryRequest.

        Applies only when the negotiated version is (3, 4) or higher and
        the last ServerHello type message saved in the state has the
        ``TLS_1_3_HRR`` value in the random field.

        :raises AssertionError: when the cipher suite or the version from
            supported_versions extension differ between the two messages
        """
        if state.version < (3, 4):
            return

        previous = state.get_last_message_of_type(ServerHello)
        if not previous or previous.random != TLS_1_3_HRR:
            # not an HRR, so HRR tests don't apply to it
            return

        if previous.cipher_suite != srv_hello.cipher_suite:
            raise AssertionError("Server picked different cipher suite than "
                                 "it advertised in HelloRetryRequest")

        ext_in_hrr = previous.getExtension(ExtensionType.supported_versions)
        ext_in_sh = srv_hello.getExtension(ExtensionType.supported_versions)

        if ext_in_hrr.version != ext_in_sh.version:
            raise AssertionError("Server picked different protocol version "
                                 "than it advertised in HelloRetryRequest")

    def _setup_tls13_handshake_keys(self, state):
        """
        Derive the TLS 1.3 handshake traffic secrets and enable decryption.

        The early secret, the handshake secret and both handshake traffic
        secrets are stored in ``state.key``. The traffic secrets are then
        installed as the pending state of the record layer and the read
        side of the record layer is switched over to it.

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection
        """
        del self  # only the connection state is used
        hash_name = state.prf_name
        hash_len = state.prf_size
        keys = state.key

        # PSK secret: a zero-filled value is used when none was set before
        psk = keys.setdefault('PSK secret', bytearray(hash_len))

        # TLS 1.3 early secret
        early_secret = secureHMAC(bytearray(hash_len), psk, hash_name)
        keys['early secret'] = early_secret

        # TLS 1.3 handshake secret; a zero-filled value is used for the
        # DH shared secret when none was set before
        derived = derive_secret(early_secret, b'derived', None, hash_name)
        dh_secret = keys.setdefault('DH shared secret',
                                    bytearray(hash_len))
        handshake_secret = secureHMAC(derived, dh_secret, hash_name)
        keys['handshake secret'] = handshake_secret

        # TLS 1.3 handshake traffic secrets, server one first
        server_secret = derive_secret(handshake_secret, b's hs traffic',
                                      state.handshake_hashes,
                                      hash_name)
        keys['server handshake traffic secret'] = server_secret
        client_secret = derive_secret(handshake_secret, b'c hs traffic',
                                      state.handshake_hashes,
                                      hash_name)
        keys['client handshake traffic secret'] = client_secret

        state.msg_sock.calcTLS1_3PendingState(
            state.cipher, client_secret, server_secret, None)

        state.msg_sock.changeReadState()

    def _check_downgrade_protection(self, srv_hello):
        """
        Check the downgrade protection sentinel in ``ServerHello.random``.

        Verify that server provided downgrade protection as specified in
        RFC 8446, Section 4.1.3. The sentinel is looked for in the tail of
        the random value (bytes from offset 24 on).

        :param srv_hello: parsed ServerHello message
        :raises AssertionError: when a sentinel is present although it
            must not be, or when ``srv_max_prot`` is set and the sentinel
            expected for the negotiated version is missing
        """
        unexpected = ("Server set downgrade protection sentinel but "
                      "shouldn't have done that")
        negotiated = self._extract_version(srv_hello)

        # even if we don't know which version server supports, some values
        # are obviously incorrect:
        if negotiated > (3, 3) and \
                srv_hello.random[24:] == TLS_1_2_DOWNGRADE_SENTINEL:
            raise AssertionError(
                "Server set downgrade protection sentinel but shouldn't "
                "have done that")
        if negotiated > (3, 2) and \
                srv_hello.random[24:] == TLS_1_1_DOWNGRADE_SENTINEL:
            raise AssertionError(
                "Server set downgrade protection sentinel but shouldn't "
                "have done that")

        # as we're doing both TLS 1.2 tests and TLS 1.3 tests with `scripts/`
        # we don't know when setting the sentinel is expected and when
        # it is not as the negotiation might have ended up with TLS 1.2
        # because that was the highest version we advertised
        if self.srv_max_prot is None:
            return

        if self.srv_max_prot > (3, 3) and negotiated == (3, 3):
            required = TLS_1_2_DOWNGRADE_SENTINEL
        elif self.srv_max_prot > (3, 2) and negotiated < (3, 3):
            required = TLS_1_1_DOWNGRADE_SENTINEL
        else:
            # no downgrade happened, so neither sentinel may be present
            if srv_hello.random[24:] in (TLS_1_1_DOWNGRADE_SENTINEL,
                                         TLS_1_2_DOWNGRADE_SENTINEL):
                raise AssertionError(unexpected)
            return

        if srv_hello.random[24:] != required:
            raise AssertionError(
                "Server failed to set downgrade protection sentinel in "
                "ServerHello.random value")


class ExpectHelloRetryRequest(ExpectServerHello):
    """
    Processing of the TLS 1.3 HelloRetryRequest message.

    The message is handled by the ServerHello processing of the parent
    class. Instead of setting up handshake keys, the node replaces the
    handshake hashes in the state with ones that start with a
    ``message_hash`` handshake message followed by the received message.
    """

    def __init__(self, extensions=None, version=None, cipher=None):
        super(ExpectHelloRetryRequest, self).__init__(
            extensions, version, cipher)
        # copy of the handshake hashes taken before the message is processed
        self._ch_hh = None
        # the message as it was passed to process()
        self._msg = None

    def process(self, state, msg):
        """
        Process the message as a ServerHello and check that it is an HRR.

        Unlike the parent class method, nothing is returned.

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        """
        # both values are needed later by _setup_tls13_handshake_keys(),
        # which is called from the parent class process()
        self._ch_hh = state.handshake_hashes.copy()
        self._msg = msg
        parsed = super(ExpectHelloRetryRequest, self).process(state, msg)
        assert parsed.random == TLS_1_3_HRR

    @staticmethod
    def _get_autohandler(ext_id):
        """
        Return the automatic handler for the extension ``ext_id``.

        Handlers specific to HelloRetryRequest take precedence over the
        general server extension handlers.

        :raises AssertionError: when no handler is registered for the ID
        """
        try:
            return _HRR_EXT_HANDLER[ext_id]
        except KeyError:
            pass
        try:
            return _srv_ext_handler[ext_id]
        except KeyError:
            pass
        raise AssertionError("No autohandler for {0}".format(
            ExtensionType.toStr(ext_id)))

    def _setup_tls13_handshake_keys(self, state):
        """Prepare handshake ciphers for the HRR handling"""
        digest = self._ch_hh.digest(state.prf_name)

        # message_hash handshake message: type, 3 byte length, then digest
        # of the handshake hashes saved before this message was processed
        synthetic = Writer()
        synthetic.add(HandshakeType.message_hash, 1)
        synthetic.addVarSeq(digest, 1, 3)

        transcript = HandshakeHashes()
        transcript.update(synthetic.bytes)
        transcript.update(self._msg.write())

        state.handshake_hashes = transcript


class ExpectServerHello2(ExpectHandshake):
    """Processing of SSLv2 Handshake Protocol SERVER-HELLO message"""

    def __init__(self, version=None):
        """
        Set the expected parameters of the SERVER-HELLO message.

        :param version: protocol version the server is expected to select;
            it is handed to ``_cmp_eq`` together with the received value
            (``None`` presumably disables the check -- verify against
            ``_cmp_eq``)
        """
        c_type = ContentType.handshake
        h_type = SSL2HandshakeType.server_hello
        super(ExpectServerHello2, self).__init__(c_type,
                                                 h_type)
        self.version = version

    def process(self, state, msg):
        """
        Process the message and update state accordingly

        Besides the parsed SERVER-HELLO, a Certificate message built from
        the certificate carried in it is appended to the handshake messages.

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        """
        # the value is faked for SSLv2 protocol, but let's just check sanity
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == SSL2HandshakeType.server_hello

        server_hello = ServerHello2().parse(parser)

        state.handshake_messages.append(server_hello)
        state.handshake_hashes.update(msg.write())

        # the two sentences need a space between them, same as in the
        # equivalent message of the TLS ServerHello processing
        self._cmp_eq(self.version, server_hello.server_version,
                     f_str="Server picked unexpected protocol version. "
                           "Expected: {0}, received: {1}.")

        if server_hello.session_id_hit:
            state.resuming = True
        state.session_id = server_hello.session_id
        # NOTE(review): server_random is taken from the session_id field,
        # presumably because it holds the SSLv2 CONNECTION-ID -- confirm
        # against tlslite ServerHello2
        state.server_random = server_hello.session_id
        state.version = server_hello.server_version
        state.msg_sock.version = server_hello.server_version

        # fake a certificate message so finding the server public key works
        x509 = X509()
        x509.parseBinary(server_hello.certificate)
        cert_chain = X509CertChain([x509])
        certificate = Certificate(CertificateType.x509)
        certificate.create(cert_chain)
        state.handshake_messages.append(certificate)
        # fake message so don't update handshake hashes


class ExpectCertificate(ExpectHandshake):
    """
    Processing TLS Handshake protocol Certificate messages

    The last parsed message is cached together with its encoding, so
    a message with the same bytes is not parsed again.
    """

    def __init__(self, cert_type=CertificateType.x509):
        super(ExpectCertificate, self).__init__(ContentType.handshake,
                                                HandshakeType.certificate)
        self.cert_type = cert_type
        # cache: the last parsed Certificate and the bytes it came from
        self._old_cert = None
        self._old_cert_bytes = None

    def _parse_certificate(self, data, version):
        """Parse ``data`` as a Certificate message of protocol ``version``."""
        parser = Parser(data)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate

        parsed = Certificate(self.cert_type, version)
        parsed.parse(parser)
        return parsed

    def process(self, state, msg):
        """
        Parse the message and add it to the handshake transcript.

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        """
        assert msg.contentType == ContentType.handshake

        raw = msg.write()
        if self._old_cert_bytes is None or raw != self._old_cert_bytes:
            # not seen before: parse, and update the cache only on success
            parsed = self._parse_certificate(raw, state.version)
            self._old_cert_bytes = raw
            self._old_cert = parsed
        cert = self._old_cert

        state.handshake_messages.append(cert)
        state.handshake_hashes.update(raw)


class ExpectCertificateVerify(ExpectHandshake):
    """
    Processing TLS Handshake protocol Certificate Verify messages.

    :param tuple(int,int) version: Expected TLS version of the message. If not
        provided will be taken from the state.
    :param tuple(int,int) sig_alg: Expected value of the signature scheme
        created by the server. If not provided it will be compared with
        signature algorithm extension from client hello and with the type
        of the server key.
    :param hash_file: Object with a ``write`` method to which the hash of
        every signature context will be written
    :param sig_file: Object with a ``write`` method to which the signatures
        themselves will be written
    """
    def __init__(
        self, version=None, sig_alg=None, hash_file=None, sig_file=None
    ):
        super(ExpectCertificateVerify, self).__init__(
            ContentType.handshake,
            HandshakeType.certificate_verify)
        self.version = version
        self.sig_alg = sig_alg
        self.hash_file = hash_file
        self.sig_file = sig_file

    @staticmethod
    def _check_sig_alg_against_key(state, sig_alg):
        """
        Check that ``sig_alg`` was advertised and fits the server key.

        :type state: `~ConnectionState`
        :param sig_alg: signature scheme from the CertificateVerify message
        :raises AssertionError: when the scheme is not in the
            ``signature_algorithms`` extension of the last ClientHello or is
            not acceptable for the type of the server public key
        """
        client_hello = state.get_last_message_of_type(ClientHello)
        advertised = client_hello.getExtension(
            ExtensionType.signature_algorithms)
        assert sig_alg in advertised.sigalgs
        key_type = state.get_server_public_key().key_type

        if key_type == "rsa-pss":
            # in TLS 1.3 only RSA-PSS signatures are allowed
            assert sig_alg in (
                SignatureScheme.rsa_pss_pss_sha256,
                SignatureScheme.rsa_pss_pss_sha384,
                SignatureScheme.rsa_pss_pss_sha512)
            return

        if key_type == "rsa":
            # in TLS 1.3 only RSA-PSS signatures are allowed
            assert sig_alg in (
                SignatureScheme.rsa_pss_rsae_sha256,
                SignatureScheme.rsa_pss_rsae_sha384,
                SignatureScheme.rsa_pss_rsae_sha512)
            return

        if key_type in ("Ed25519", "Ed448"):
            assert sig_alg in (
                SignatureScheme.ed25519,
                SignatureScheme.ed448)
            # the scheme must be the one named after the key type
            if getattr(SignatureScheme, key_type.lower()) != sig_alg:
                raise AssertionError(
                    "Mismatched signature ({0}) for used key ({1})"
                    .format(
                        SignatureScheme.toStr(sig_alg),
                        key_type))
            return

        assert key_type == "ecdsa"
        curve_name = state.get_server_public_key().curve_name
        assert curve_name in ("NIST256p", "NIST384p", "NIST521p")
        assert sig_alg in ECDSA_SIG_TLS1_3_ALL
        hash_name = curve_name_to_hash_tls13(curve_name)
        # in TLS 1.3 the hash is bound to key curve
        if sig_alg != (getattr(HashAlgorithm, hash_name),
                       SignatureAlgorithm.ecdsa):
            raise AssertionError(
                "Invalid signature type for {1} key, "
                "received: {0}"
                .format(SignatureScheme.toStr(sig_alg), curve_name))

    def process(self, state, msg):
        """
        Parse the message and verify the signature it carries.

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket

        :raises AssertionError: when the signature algorithm is unexpected
            or the signature does not verify
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate_verify

        if self.version is None:
            self.version = state.version

        cert_v = CertificateVerify(self.version)
        cert_v.parse(parser)
        salg = cert_v.signatureAlgorithm

        if self.sig_alg:
            assert self.sig_alg == salg
        else:
            self._check_sig_alg_against_key(state, salg)

        # parameters for the signature verification; padding and salt
        # length are set only for schemes that are neither EdDSA nor ECDSA
        padding = None
        salt_len = None
        if salg in (SignatureScheme.ed25519, SignatureScheme.ed448):
            hash_name = "intrinsic"
        elif salg[1] == SignatureAlgorithm.ecdsa:
            hash_name = HashAlgorithm.toStr(salg[0])
        else:
            scheme = SignatureScheme.toRepr(salg)
            hash_name = SignatureScheme.getHash(scheme)
            padding = SignatureScheme.getPadding(scheme)
            salt_len = getattr(hashlib, hash_name)().digest_size

        # the signed data: 64 spaces, context string, zero byte separator
        # and the hash of handshake messages up to this point
        transcript_hash = state.handshake_hashes.digest(state.prf_name)
        context_prefix = bytearray(b'\x20' * 64 +
                                   b'TLS 1.3, server CertificateVerify' +
                                   b'\x00')
        sig_context = context_prefix + transcript_hash

        verified = state.get_server_public_key().hashAndVerify(
            cert_v.signature,
            sig_context,
            padding,
            hash_name,
            salt_len)
        if not verified:
            raise AssertionError("Signature verification failed")

        if self.hash_file:
            # NOTE(review): for EdDSA hash_name is "intrinsic"; this lookup
            # looks like it would fail then -- confirm intended behaviour
            data = getattr(hashlib, hash_name)(sig_context).digest()
            self.hash_file.write(data)

        if self.sig_file:
            self.sig_file.write(cert_v.signature)

        state.handshake_messages.append(cert_v)
        state.handshake_hashes.update(msg.write())


class ExpectServerKeyExchange(ExpectHandshake):
    """Processing TLS Handshake protocol Server Key Exchange message"""

    def __init__(self, version=None, cipher_suite=None, valid_sig_algs=None,
                 valid_groups=None, valid_params=None):
        """
        Expect ServerKeyExchange message from server.

        :param version: protocol version used for parsing the message,
            if ``None`` it is taken from the connection state when the
            message is processed
        :param cipher_suite: cipher suite used for parsing the message,
            if ``None`` it is taken from the connection state when the
            message is processed
        :param valid_sig_algs: signature algorithms the server may use for
            signing the message. If ``None``, the list from the
            ``signature_algorithms`` extension of the last ClientHello is
            used; when that is missing too, only SHA-1 with RSA (or with
            ECDSA, for ECDHE_ECDSA cipher suites) is accepted.

        :param list(int) valid_groups: TLS group identifiers for groups that
            server can use. In case the groups include identifiers between 256
            and 512 (see RFC 7919), the node will also check that the server
            selected FFDH parameters match the parameters specified in the RFC.
            For ECDH cipher suites ``None`` means that groups from the
            ``supported_groups`` extension of the last ClientHello are used,
            or all EC groups if it is missing.

        :param set(tuple(int,int)) valid_params: set of explicit expected
            parameters used by the server, the first element of the tuple
            is the expected generator and the second is the prime used for the
            DH calculation. Applicable only to ciphersuites that use FFDHE
            key exchange.

        :raises ValueError: when both ``valid_groups`` and ``valid_params``
            are specified
        """
        msg_type = HandshakeType.server_key_exchange
        super(ExpectServerKeyExchange, self).__init__(ContentType.handshake,
                                                      msg_type)
        self.version = version
        self.cipher_suite = cipher_suite
        self.valid_sig_algs = valid_sig_algs
        self.valid_groups = valid_groups
        self.valid_params = valid_params
        if self.valid_groups and self.valid_params:
            raise ValueError("valid_groups and valid_params are exclusive")

    def _checkParams(self, server_key_exchange):
        """
        Check that the server picked one of the expected FFDH parameters.

        No check is made when neither ``valid_params`` nor any group
        identifier in the 256..511 range in ``valid_groups`` is set.

        :param server_key_exchange: parsed ServerKeyExchange message
        :raises AssertionError: when the generator and prime sent by the
            server are not among the expected ones
        """
        groups = []
        if self.valid_groups:
            # only the RFC 7919 identifiers map to well known parameters
            groups = [RFC7919_GROUPS[i - 256] for i in self.valid_groups
                      if 256 <= i < 512]
        if self.valid_params:
            groups = self.valid_params
        server_params = (server_key_exchange.dh_g, server_key_exchange.dh_p)
        if groups and server_params not in groups:
            # report the name of the group if it is a well known one
            for name, params in FFDHE_PARAMETERS.items():
                if server_params == params:
                    raise AssertionError(
                        "DH parameters not from valid set, "
                        "received: {0}".format(name))
            raise AssertionError(
                "DH parameters not from valid set, "
                "received: g:{0}, p:{1}".format(
                    hex(server_params[0]),
                    hex(server_params[1])))

    def process(self, state, msg):
        """
        Process the Server Key Exchange message

        Verifies the signature on the message, checks the key exchange
        parameters and saves the key exchange object, the server key share
        and the calculated premaster secret in the connection state.

        :type state: `~ConnectionState`
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket

        :raises AssertionError: when the negotiated cipher suite does not
            use (EC)DHE key exchange or when the FFDH parameters are not
            the expected ones
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_key_exchange

        if self.version is None:
            self.version = state.version
        if self.cipher_suite is None:
            self.cipher_suite = state.cipher
        valid_sig_algs = self.valid_sig_algs
        valid_groups = self.valid_groups

        server_key_exchange = ServerKeyExchange(self.cipher_suite,
                                                self.version)
        server_key_exchange.parse(parser)

        client_random = state.client_random
        server_random = state.server_random
        public_key = state.get_server_public_key()
        server_hello = state.get_last_message_of_type(ServerHello)
        if server_hello is None:
            # use a stand-in *instance*; setting the version on the class
            # itself would leak it into every other ServerHello object
            server_hello = ServerHello()
            server_hello.server_version = state.version
        if valid_sig_algs is None:
            # if the value was unset in script, get the advertised value from
            # Client Hello
            client_hello = state.get_last_message_of_type(ClientHello)
            if client_hello is not None:
                sig_algs_ext = client_hello.getExtension(ExtensionType.
                                                         signature_algorithms)
                if sig_algs_ext is not None:
                    valid_sig_algs = sig_algs_ext.sigalgs
            if valid_sig_algs is None:
                # no advertised means support for sha1 only
                valid_sig_algs = [(HashAlgorithm.sha1, SignatureAlgorithm.rsa)]
                if self.cipher_suite in CipherSuite.ecdheEcdsaSuites:
                    valid_sig_algs = [(HashAlgorithm.sha1,
                                       SignatureAlgorithm.ecdsa)]

        try:
            KeyExchange.verifyServerKeyExchange(server_key_exchange,
                                                public_key,
                                                client_random,
                                                server_random,
                                                valid_sig_algs)
        except TLSDecryptionFailed:
            # very rarely validation of signature fails, print it so that
            # we have a chance in debugging it
            print("Bad signature: {0}"
                  .format(b2a_hex(server_key_exchange.signature)),
                  file=sys.stderr)
            raise

        if self.cipher_suite in CipherSuite.dhAllSuites:
            self._checkParams(server_key_exchange)
            state.key_exchange = DHE_RSAKeyExchange(self.cipher_suite,
                                                    clientHello=None,
                                                    serverHello=server_hello,
                                                    privateKey=None)
            state.key['ServerKeyExchange.key_share'] = \
                server_key_exchange.dh_Ys
            state.key['ServerKeyExchange.dh_p'] = server_key_exchange.dh_p
        elif self.cipher_suite in CipherSuite.ecdhAllSuites:
            # extract valid groups from Client Hello
            if valid_groups is None:
                client_hello = state.get_last_message_of_type(ClientHello)
                if client_hello is not None:
                    groups_ext = client_hello.getExtension(ExtensionType.
                                                           supported_groups)
                    if groups_ext is not None:
                        valid_groups = groups_ext.groups
                if valid_groups is None:
                    # no advertised means support for all
                    valid_groups = GroupName.allEC
            state.key_exchange = \
                ECDHE_RSAKeyExchange(self.cipher_suite,
                                     clientHello=None,
                                     serverHello=server_hello,
                                     privateKey=None,
                                     acceptedCurves=valid_groups)
            state.key['ServerKeyExchange.key_share'] = \
                server_key_exchange.ecdh_Ys
        else:
            raise AssertionError("Unsupported cipher selected")
        state.key['premaster_secret'] = state.key_exchange.\
            processServerKeyExchange(public_key,
                                     server_key_exchange)

        state.handshake_messages.append(server_key_exchange)
        state.handshake_hashes.update(msg.write())


# RFC 8446, Section 4.2 says that an implementation MUST reject extensions
# that it recognises but which are not allowed in the message in which they
# appear. This set holds the IDs of extensions that must not be present in
# a TLS 1.3 CertificateRequest; it is meant to cover all the extensions
# defined in RFC 8446.
# Extensions listed by number are annotated with their name (presumably
# because ExtensionType has no constant for them -- verify against tlslite).
TLS_1_3_CR_FORBIDDEN = set((
    ExtensionType.server_name,
    1,  # ExtensionType.max_fragment_length
    ExtensionType.supported_groups,
    14,  # ExtensionType.use_srtp
    ExtensionType.heartbeat,
    ExtensionType.alpn,
    19,  # ExtensionType.client_certificate_type
    20,  # ExtensionType.server_certificate_type
    21,  # ExtensionType.padding
    ExtensionType.key_share,
    ExtensionType.pre_shared_key,
    ExtensionType.psk_key_exchange_modes,
    ExtensionType.early_data,
    ExtensionType.cookie,
    ExtensionType.supported_versions,
    49  # ExtensionType.post_handshake_auth
    ))


class ExpectCertificateRequest(_ExpectExtensionsMessage):
    """Processing TLS Handshake protocol Certificate Request message."""

    def __init__(self, sig_algs=None, cert_types=None,
                 sanity_check_cert_types=True, extensions=None, context=None):
        """
        Set expected parameters for the CertificateRequest message.

        :param sig_algs: a list of signature algorithms that we are expecting
            from server. Needs to be in-order and complete. ``None`` to accept
            any list from server. Applicable to TLS 1.2 and later only.
            Do not use together with non-default ``extensions``.
        :param cert_types: a list of client certificate types that we are
            expecting from server. Needs to be in-order and complete.
            ``None`` to accept any list from server. Applicable to TLS 1.2 and
            earlier only.
        :param sanity_check_cert_types: set to ``False`` to disable
            verification checking if every signature algorithm has a
            corresponding client certificate type.
        :param extensions: dictionary with extensions that need to be included
            in the message. Set to ``None`` to accept any, set to empty dict to
            expect no extensions. Usable in TLS 1.3 only.
        :param context: object with an ``append`` method (e.g. a list) to
            which the parsed message is added when the connection uses
            TLS 1.3 or later. ``None`` to not save the message.
        :raises ValueError: when both ``sig_algs`` and ``extensions`` are
            specified
        """
        super(ExpectCertificateRequest, self).__init__(
            ContentType.handshake,
            HandshakeType.certificate_request,
            extensions)
        self.sig_algs = sig_algs
        self.cert_types = cert_types
        self.context = context
        self.sanity_check_cert_types = sanity_check_cert_types
        if sig_algs is not None and extensions is not None:
            raise ValueError("Can't set sig_algs and extensions at the same "
                             "time")

    @staticmethod
    def _sanity_check_cert_types(cert_request):
        """
        Verify that the CertificateRequest is self-consistent.

        :raises AssertionError: when a signature algorithm is listed but
            the client certificate type that goes with it is not
        """
        # signature algorithm ID -> (key type name, client cert type name)
        known = {
            SignatureAlgorithm.ecdsa: ("ECDSA", "ecdsa_sign"),
            SignatureAlgorithm.ed25519: ("ECDSA", "ecdsa_sign"),
            SignatureAlgorithm.ed448: ("ECDSA", "ecdsa_sign"),
            SignatureAlgorithm.rsa: ("RSA", "rsa_sign"),
            SignatureAlgorithm.dsa: ("DSA", "dss_sign"),
        }
        for sig_alg in cert_request.supported_signature_algs:
            names = known.get(sig_alg[1])
            if names is not None:
                key_type, cert_type = names
            else:
                # anything else has to be a signature scheme with RSA key
                sig_scheme = SignatureScheme.toRepr(sig_alg)
                key_type = SignatureScheme.getKeyType(sig_scheme)
                assert key_type == "rsa", \
                    "Unsupported signature algorithm: {0}".format(sig_alg)
                cert_type = "rsa_sign"

            required = getattr(ClientCertificateType, cert_type)
            if required not in cert_request.certificate_types:
                raise AssertionError(
                    "CertificateRequest includes {1} signature algorithms "
                    "({0}) but does not include {2} client "
                    "certificate type".format(sig_alg, key_type, cert_type))

    @staticmethod
    def _get_autohandler(ext_id):
        """Return the automatic handler for ``ext_id``, ``None`` if unset."""
        try:
            handler = _CR_EXT_HANDLER[ext_id]
        except KeyError:
            # handle future/GREASE extensions
            handler = None
        return handler

    def _process_extensions(self, state, msg):
        """
        Run the handler of every extension present in the message.

        :type state: ConnectionState
        :param msg: parsed CertificateRequest message
        :raises AssertionError: when a forbidden extension is present or
            an extension does not match the expected one
        :raises ValueError: when a handler is neither a callable nor
            a TLSExtension object
        """
        for ext in msg.extensions:
            ext_id = ext.extType
            if ext_id in TLS_1_3_CR_FORBIDDEN:
                raise AssertionError(
                    "Server sent extension that is explicitly forbidden in "
                    "CertificateRequest messages: {0}".format(
                        ExtensionType.toStr(ext_id)))

            handler = self.extensions[ext_id] if self.extensions else None
            if handler is None:
                handler = self._get_autohandler(ext_id)

            if callable(handler):
                handler(state, ext)
                continue

            if isinstance(handler, TLSExtension):
                if not handler == ext:
                    raise AssertionError(
                        "Expected extension not matched for type {0}, "
                        "received: {1}".format(ExtensionType.toStr(ext_id),
                                               ext))
                continue

            # since server can send arbitrary extensions, we need to
            # be able to process them, so if the self.extensions is unset
            # we can just do nothing
            if handler is not None:
                raise ValueError("Bad extension handler for id {0}".format(
                    ExtensionType.toStr(ext_id)))

    def process(self, state, msg):
        """
        Check received Certificate Request

        :type state: ConnectionState
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        """
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.certificate_request

        request = CertificateRequest(state.version)
        request.parse(parser)

        self._cmp_eq_list(self.sig_algs, request.supported_signature_algs,
                          SignatureScheme,
                          f_str="Unexpected signature algorithms. Got: {1}, "
                                "expected: {0}")

        self._cmp_eq_list(self.cert_types, request.certificate_types,
                          ClientCertificateType,
                          f_str="Unexpected client certificate types. Got: "
                                "{1}, expected: {0}")

        # only in TLS 1.2 do the sig algs coexist with cert types
        if state.version == (3, 3) and self.sanity_check_cert_types:
            self._sanity_check_cert_types(request)

        if state.version >= (3, 4):
            self._compare_extensions(request)
            self._process_extensions(state, request)
            if self.context is not None:
                self.context.append(request)

        state.handshake_messages.append(request)
        state.handshake_hashes.update(msg.write())


class ExpectServerHelloDone(ExpectHandshake):
    """Processing TLS Handshake protocol ServerHelloDone messages"""

    def __init__(self):
        super(ExpectServerHelloDone, self).__init__(
            ContentType.handshake,
            HandshakeType.server_hello_done)

    def process(self, state, msg):
        """
        Parse the message and add it to the handshake transcript.

        :type state: ConnectionState
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        """
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello_done

        hello_done = ServerHelloDone()
        hello_done.parse(parser)

        state.handshake_messages.append(hello_done)
        state.handshake_hashes.update(msg.write())


class ExpectChangeCipherSpec(Expect):
    """
    Processing TLS Change Cipher Spec messages.

    .. note::
      In SSLv3 up to TLS 1.2, the message modifies the state of record layer
      to expect encrypted records *after* receiving this message.
      In case of renegotiation, record layer will expect records encrypted
      with the newly negotiated keys. In TLS 1.3 it has no effect on record
      layer encryption.
    """

    def __init__(self):
        super(ExpectChangeCipherSpec, self).__init__(
            ContentType.change_cipher_spec)

    def process(self, state, msg):
        """
        Check the message and switch the read state of the record layer.

        :type state: ConnectionState
        :param state: overall state of TLS connection

        :type msg: Message
        :param msg: TLS Message read from socket
        """
        assert msg.contentType == ContentType.change_cipher_spec
        parser = Parser(msg.write())
        ccs = ChangeCipherSpec().parse(parser)

        assert ccs.type == 1

        if not state.version < (3, 4):
            # in TLS 1.3 the CCS does not have any effect on encryption
            return

        if state.resuming:
            # the pending states are calculated here only when resuming
            state.msg_sock.encryptThenMAC = state.encrypt_then_mac
            calc_pending_states(state)

        state.msg_sock.changeReadState()

        record_limit = state._our_record_size_limit
        if record_limit:
            state.msg_sock.recv_record_limit = record_limit


class ExpectVerify(ExpectHandshake):
    """Processing of SSLv2 SERVER-VERIFY message"""

    def __init__(self):
        super(ExpectVerify, self).__init__(
            ContentType.handshake,
            SSL2HandshakeType.server_verify)

    def process(self, state, msg):
        """
        Check that the message is an SSLv2 SERVER-VERIFY message.

        Only the content type and the message type byte are checked,
        the rest of the message is not inspected and ``state`` is not
        modified.

        :type state: ConnectionState
        :type msg: Message
        """
        assert msg.contentType == ContentType.handshake

        msg_type = Parser(msg.write()).get(1)
        assert msg_type == SSL2HandshakeType.server_verify


class ExpectFinished(ExpectHandshake):
    """
    Processing TLS handshake protocol Finished message.

    .. note::
      In TLS 1.3 the message will modify record layer to start *sending*
      records with encryption using the ``client_handshake_traffic_secret``
      keys.
      It will also modify the record layer to start expecting the records
      to be encrypted with ``server_application_traffic_secret`` keys.
    """

    def __init__(self, version=None, description=None):
        """
        Initialize object.

        .. note::
            The ``description`` parameter MUST be specified
            as a keyword argument, i.e. read the definition as
            ``(self, *, description=None)`` (see PEP 3102).
            Otherwise the behaviour of this node is not guaranteed if new
            arguments are added to it (as they will be added *before*
            the ``description`` argument).

        :param tuple version: protocol version to process the message as;
            ``(0, 2)`` and ``(2, 0)`` select the SSLv2 SERVER-FINISHED
            message type, any other value selects the TLS Finished message
            type. If ``None``, the version from the connection state is
            used when the message is processed.
        :param str description: name or comment attached to the node,
            it will be printed when :py:func:`str` or :py:func:`repr` is
            called on the node.
        """
        # the expected handshake type differs between SSLv2 and SSLv3/TLS
        if version in ((0, 2), (2, 0)):
            super(ExpectFinished, self).__init__(ContentType.handshake,
                                                 SSL2HandshakeType.
                                                 server_finished)
        else:
            super(ExpectFinished, self).__init__(ContentType.handshake,
                                                 HandshakeType.finished)
        self.version = version
        self.description = description

    def process(self, state, msg):
        """
        Parse the Finished message, verify it and update connection state.

        For SSLv2 the received ``verify_data`` is stored as the session ID.
        For later versions the ``verify_data`` is compared against a locally
        calculated value. For versions newer than TLS 1.2 the application
        traffic secrets and the exporter master secret are derived too and
        both the write and the read state of the record layer are switched.

        .. note::
            If ``version`` was not set for the node, it is set to
            ``state.version`` on the first call and stays set afterwards.

        :type state: ConnectionState
        :type msg: Message
        :raises AssertionError: when the content type, handshake type or
            ``verify_data`` do not match the expected values
        """
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type
        # fall back to the version stored in the connection state
        if self.version is None:
            self.version = state.version

        if self.version in ((0, 2), (2, 0)):
            finished = ServerFinished()
        else:
            finished = Finished(self.version, state.prf_size)

        finished.parse(parser)

        if self.version in ((0, 2), (2, 0)):
            # in SSLv2 the payload of the message is used as the session ID
            state.session_id = finished.verify_data
        elif self.version <= (3, 3):
            # NOTE(review): ``state.version``, not ``self.version``, is passed
            # to calc_key - confirm this is intended when the node version
            # is overridden by the caller
            verify_expected = calc_key(state.version,
                                       state.key['master_secret'],
                                       state.cipher,
                                       b'client finished' if not state.client
                                       else b'server finished',
                                       state.handshake_hashes,
                                       output_length=12)

            assert finished.verify_data == verify_expected
        else:  # TLS 1.3
            finished_key = HKDF_expand_label(
                state.key['server handshake traffic secret'],
                b'finished',
                b'',
                state.prf_size,
                state.prf_name)
            # the digest is taken before this message is added to the hashes
            transcript_hash = state.handshake_hashes.digest(state.prf_name)
            verify_expected = secureHMAC(finished_key,
                                         transcript_hash,
                                         state.prf_name)
            assert finished.verify_data == verify_expected

        # record the message only after the verify_data was checked, as the
        # expected value is calculated over the hashes without this message
        state.handshake_messages.append(finished)
        state.key['server_verify_data'] = finished.verify_data
        state.handshake_hashes.update(msg.write())

        if self.version in ((0, 2), (2, 0)):
            state.msg_sock.handshake_finished = True

        if self.version > (3, 3):
            # in TLS 1.3 ChangeCipherSpec is a no-op, so we need to attach
            # the changes of the record layer state to some message that is
            # always sent; first switch the state used for writing (see the
            # note in the class docstring)
            state.msg_sock.changeWriteState()

            # we now need to calculate application traffic keys to allow
            # correct interpretation of the alerts regarding Certificate,
            # CertificateVerify and Finished

            # derive the master secret
            secret = derive_secret(
                state.key['handshake secret'], b'derived', None,
                state.prf_name)
            secret = secureHMAC(
                secret, bytearray(state.prf_size), state.prf_name)
            state.key['master secret'] = secret

            # derive encryption keys; the handshake hashes used here already
            # include the Finished message processed above
            c_traff_sec = derive_secret(
                secret, b'c ap traffic', state.handshake_hashes,
                state.prf_name)
            state.key['client application traffic secret'] = c_traff_sec
            s_traff_sec = derive_secret(
                secret, b's ap traffic', state.handshake_hashes,
                state.prf_name)
            state.key['server application traffic secret'] = s_traff_sec

            # derive TLS exporter key
            exp_ms = derive_secret(secret, b'exp master',
                                   state.handshake_hashes,
                                   state.prf_name)
            state.key['exporter master secret'] = exp_ms

            # set up the encryption keys for application data; only the
            # read state is switched here
            state.msg_sock.calcTLS1_3PendingState(
                state.cipher, c_traff_sec, s_traff_sec, None)
            state.msg_sock.changeReadState()

    def __repr__(self):
        """Return human readable representation of the object."""
        return self._repr(['description'])


class ExpectEncryptedExtensions(_ExpectExtensionsMessage):
    """Processing of the TLS handshake protocol Encrypted Extensions message"""

    def __init__(self, extensions=None):
        super(ExpectEncryptedExtensions, self).__init__(
            ContentType.handshake,
            HandshakeType.encrypted_extensions,
            extensions)

    def _compare_extensions_in_ee(self, srv_exts, cln_hello):
        """
        Check the received extensions against the expected ones.

        When no expected extensions were configured, every extension
        received must also be present in the Client Hello message.
        """
        # check if received extensions match the set extensions
        self._compare_extensions(srv_exts)
        if self.extensions is not None or not srv_exts.extensions:
            return

        offered = set(ext.extType for ext in cln_hello.extensions)
        received = set(ext.extType for ext in srv_exts.extensions)
        unsolicited = received.difference(offered)
        if unsolicited:
            names = ", ".join(ExtensionType.toStr(ext_type)
                              for ext_type in unsolicited)
            raise AssertionError("Server sent unexpected extension(s):"
                                 " {0}".format(names))

    @staticmethod
    def _get_autohandler(ext_id):
        """Return the automatic handler registered for ``ext_id``."""
        try:
            handler = _EE_EXT_HANDLER[ext_id]
        except KeyError:
            raise ValueError("No autohandler for "
                             "{0}"
                             .format(ExtensionType.toStr(ext_id)))
        return handler

    def _process_extensions(self, state, srv_exts):
        """Check if extensions are correct."""
        # fix these constants, when the extensions are implemented
        ee_supported = (
            ExtensionType.server_name,
            1,  # max_fragment_length - RFC 6066
            ExtensionType.supported_groups,
            14,  # use_srtp - RFC 5764
            ExtensionType.heartbeat,  # RFC 6520
            ExtensionType.alpn,
            19,  # client_certificate_type
                 # draft-ietf-tls-tls13-28 / RFC 7250
            20,  # server_certificate_type
                 # draft-ietf-tls-tls13-28 / RFC 7250
            ExtensionType.record_size_limit,  # RFC 8449
            ExtensionType.early_data,
        )

        for ext in srv_exts.extensions:
            ext_id = ext.extType
            if ext_id not in ee_supported:
                raise AssertionError("Server sent unsupported "
                                     "extension of type {0}"
                                     .format(ExtensionType.toStr(ext_id)))

            handler = self.extensions[ext_id] if self.extensions else None

            # use automatic handlers for some extensions
            if handler is None:
                handler = self._get_autohandler(ext_id)

            if callable(handler):
                handler(state, ext)
                continue

            if not isinstance(handler, TLSExtension):
                raise ValueError("Bad extension handler for id {0}"
                                 .format(ExtensionType.toStr(ext_id)))

            # a literal extension object has to be equal to the received one
            if not handler == ext:
                raise AssertionError("Expected extension not "
                                     "matched for type {0}, "
                                     "received: {1}"
                                     .format(ExtensionType.toStr(ext_id),
                                             ext))

    def process(self, state, msg):
        """Parse the message, verify its extensions and record it."""
        assert msg.contentType == ContentType.handshake
        parser = Parser(msg.write())
        assert parser.get(1) == self.handshake_type

        enc_exts = EncryptedExtensions().parse(parser)

        # get client_hello message with CH extensions
        client_hello = state.get_last_message_of_type(ClientHello)

        self._compare_extensions_in_ee(enc_exts, client_hello)

        if enc_exts.extensions:
            self._process_extensions(state, enc_exts)

        state.handshake_messages.append(enc_exts)
        state.handshake_hashes.update(msg.write())


class ExpectNewSessionTicket(ExpectHandshake):
    """Processing TLS handshake protocol new session ticket message."""

    def __init__(self, version=None, description=None):
        """
        Initialise object.

        .. note::
            The ``description`` parameter MUST be specified
            as a keyword argument, i.e. read the definition as
            ``(self, *, description=None)`` (see PEP 3102).
            Otherwise the behaviour of this node is not guaranteed if new
            arguments are added to it (as they will be added *before*
            the ``description`` argument).

        :param tuple version: parse the message as in the specified TLS
            version, use negotiated version by default
        :param str description: name or comment attached to the node,
            it will be printed when :py:func:`str` or :py:func:`repr` is
            called on the node.
        """
        super(ExpectNewSessionTicket, self).__init__(
            ContentType.handshake,
            HandshakeType.new_session_ticket)
        self.description = description
        self.version = version

    def process(self, state, msg):
        """
        Parse the ticket and store it in the connection state.

        If ``version`` was not set for the node, it is set to
        ``state.version`` here and stays set afterwards.
        """
        assert msg.contentType == ContentType.handshake
        raw = msg.write()
        parser = Parser(raw)
        assert parser.get(1) == HandshakeType.new_session_ticket
        if self.version is None:
            self.version = state.version

        pre_tls13 = self.version < (3, 4)
        ticket_cls = NewSessionTicket1_0 if pre_tls13 else NewSessionTicket
        ticket = ticket_cls().parse(parser)
        # remember when the ticket was received
        ticket.time = time.time()

        state.session_tickets.append(ticket)

        if pre_tls13:
            # in TLS 1.2 and earlier tickets are part of the Handshake, so
            # they need to be hashed
            state.handshake_messages.append(ticket)
            state.handshake_hashes.update(raw)

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(['description'])


class ExpectHelloRequest(ExpectHandshake):
    """Processing of TLS handshake protocol hello request message."""

    def __init__(self, description=None):
        """
        Initialise object.

        .. note::
            The ``description`` parameter MUST be specified
            as a keyword argument, i.e. read the definition as
            ``(self, *, description=None)`` (see PEP 3102).
            Otherwise the behaviour of this node is not guaranteed if new
            arguments are added to it (as they will be added *before*
            the ``description`` argument).

        :param str description: name or comment attached to the node,
            it will be printed when :py:func:`str` or :py:func:`repr` is
            called on the node.
        """
        super(ExpectHelloRequest, self).__init__(
            ContentType.handshake, HandshakeType.hello_request)
        self.description = description

    def process(self, state, msg):
        """Check that ``msg`` is a well-formed hello request message."""
        assert msg.contentType == ContentType.handshake
        body = Parser(msg.write())
        assert body.get(1) == HandshakeType.hello_request

        # the parsed message is not stored, parsing is the only check made
        HelloRequest().parse(body)

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(['description'])


class ExpectAlert(Expect):
    """Processing TLS Alert message"""

    def __init__(self, level=None, description=None):
        super(ExpectAlert, self).__init__(ContentType.alert)
        self.level = level
        self.description = description

    def process(self, state, msg):
        """
        Compare the received alert with the expected level and description.

        A ``level`` or ``description`` of ``None`` is not checked.
        A ``description`` that is not iterable is replaced on the node by
        a one element tuple.

        :raises AssertionError: when the level or description do not match
        """
        assert msg.contentType == ContentType.alert

        alert = Alert()
        alert.parse(Parser(msg.write()))

        problems = []
        if self.level is not None and alert.level != self.level:
            problems.append("Alert level {0} != {1}".format(alert.level,
                                                            self.level))

        if self.description is not None:
            # allow for multiple choice for description
            if not isinstance(self.description, Iterable):
                self.description = (self.description,)

            if alert.description not in self.description:
                quoted = ["\"{0}\"".format(AlertDescription.toStr(desc))
                          for desc in self.description]
                # list the names as: "a", "b" or "c"
                expected = ", ".join(
                    quoted[:-2] + [" or ".join(quoted[-2:])])
                received = AlertDescription.toStr(alert.description)
                problems.append("Expected alert description {0} does not "
                                "match received \"{1}\""
                                .format(expected, received))

        if problems:
            raise AssertionError(", ".join(problems))

    def __repr__(self):
        """Return human readable representation of object."""
        return self._repr(["level", "description"])


class ExpectSSL2Alert(ExpectHandshake):
    """Processing of SSLv2 Handshake protocol alert messages"""

    def __init__(self, error=None):
        super(ExpectSSL2Alert, self).__init__(
            ContentType.handshake, SSL2HandshakeType.error)
        self.error = error

    def process(self, state, msg):
        """Check the message type and, if set, the expected error code."""
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        assert parser.get(1) == SSL2HandshakeType.error

        if self.error is None:
            # any error code is accepted
            return

        # the error code is read as a two byte value
        received_error = parser.get(2)
        assert self.error == received_error


class ExpectApplicationData(Expect):
    """
    Processing Application Data message.

    :ivar data: exact payload to expect, ``None`` to accept any payload
    :ivar size: exact length of payload to expect, ``None`` to accept
        any length
    :ivar output: object with a ``write()`` method that will receive the
        representation of the received payload, not used if unset
    :ivar description: name or comment attached to the node
    """

    def __init__(self, data=None, size=None, output=None, description=None):
        super(ExpectApplicationData, self).__init__(
            ContentType.application_data)
        self.data = data
        self.size = size
        self.output = output
        self.description = description

    def __str__(self):
        """Return human readable representation of the object."""
        return self._repr(['data', 'size', 'description'])

    def process(self, state, msg):
        """
        Check if the payload of ``msg`` meets the set expectations.

        :raises AssertionError: when the payload or its length differs from
            the expected ``data`` or ``size``
        """
        assert msg.contentType == ContentType.application_data
        data = msg.write()

        # compare with None instead of testing truthiness, so that an
        # expected empty payload or an expected size of zero is verified too
        if self.data is not None and not self.data == data:
            raise AssertionError("ApplicationData with unexpected payload: "
                                 "{0!r}, expected: {1!r}"
                                 .format(data, self.data))
        if self.size is not None and len(data) != self.size:
            raise AssertionError("ApplicationData of unexpected size: {0}, "
                                 "expected: {1}".format(len(data), self.size))
        if self.output:
            self.output.write("ExpectApplicationData received payload:\n")
            self.output.write(repr(data))
            self.output.write("ExpectApplicationData end of payload.\n")


class ExpectHeartbeat(ExpectMessage):
    """Processing of heartbeat messages."""

    def __init__(self, message_type=HeartbeatMessageType.heartbeat_response,
                 payload=None, padding_size=None):
        """
        Set up waiting for a heartbeat message.

        :type message_type: int
        :param message_type: Type of heartbeat messages to wait for, see
            `~tlslite.constants.HeartbeatMessageType` for defined types
        :type payload: bytes-like
        :param payload: literal value of payload to expect
        :type padding_size: int
        :param padding_size: exact length of padding that will be expected,
            if set to ``None``, padding of at least 16 bytes is required
        """
        super(ExpectHeartbeat, self).__init__(ContentType.heartbeat)
        self.message_type = message_type
        self.payload = payload
        self.padding_size = padding_size

    def process(self, state, msg):
        """Check if the ``msg`` meets the requirements for the message."""
        assert msg.contentType == ContentType.heartbeat

        heartbeat = Heartbeat().parse(Parser(msg.write()))

        self._cmp_eq(self.message_type, heartbeat.message_type,
                     HeartbeatMessageType,
                     "Unexpected heartbeat message type. Expected: {0}, "
                     "received: {1}.")

        self._cmp_eq(self.payload, heartbeat.payload,
                     f_str="Unexpected payload in Heartbeat message "
                           "received. Expected: {0!r}, received: {1!r}")

        padding_len = len(heartbeat.padding)
        if self.padding_size is None:
            # without explicit size, only the lower bound is checked
            assert padding_len >= 16
            return

        if padding_len != self.padding_size:
            raise AssertionError(
                "Server sent unexpected size of padding "
                "in heartbeat message. Expected: {0}, "
                "received: {1}".format(self.padding_size, padding_len))


class ExpectNoMessage(Expect):
    """
    Virtual message signifying timeout on message listen.

    :ivar timeout: how long to wait for message before giving up, in seconds,
        can be float
    :vartype timeout: int or float
    """

    def __init__(self, timeout=0.1):
        super(ExpectNoMessage, self).__init__(None)
        self.timeout = timeout

    def process(self, state, msg):
        """Do nothing, there is no message to process."""


class ExpectClose(Expect):
    """Virtual message signifying closing of TCP connection"""

    def __init__(self):
        super(ExpectClose, self).__init__(None)

    def process(self, state, msg):
        """Close the socket used by our side of the connection."""
        sock = state.msg_sock.sock
        sock.close()


class ExpectCertificateStatus(ExpectHandshake):
    """Processing of CertificateStatus message from RFC 6066."""

    def __init__(self):
        super(ExpectCertificateStatus, self).__init__(
            ContentType.handshake,
            HandshakeType.certificate_status)

    def process(self, state, msg):
        """Parse the message and add it to the handshake transcript."""
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        assert parser.get(1) == HandshakeType.certificate_status

        status = CertificateStatus().parse(parser)

        # the message is part of the handshake, so it has to be hashed
        state.handshake_messages.append(status)
        state.handshake_hashes.update(msg.write())


class ExpectKeyUpdate(ExpectHandshake):
    """Processing of post-handshake KeyUpdate message from RFC 8446"""

    def __init__(self, message_type=None):
        """
        Initialize object.

        :type message_type: int
        :param message_type: type of KeyUpdate msg, either
            update_not_requested or update_requested, if set to ``None``,
            any type will be accepted
        """
        super(ExpectKeyUpdate, self).__init__(
            ContentType.handshake,
            HandshakeType.key_update)
        self.message_type = message_type

    def process(self, state, msg):
        """
        Parse, verify and process the message.

        :type state: ConnectionState
        :type msg: Message
        :raises AssertionError: when the content type, the handshake type
            or the type of the KeyUpdate message is different than expected
        """
        assert msg.contentType == self.content_type
        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == self.handshake_type

        keyupdate = KeyUpdate().parse(parser)
        # a parsed message type is never None, so comparing it with the
        # default value would reject every message; treat None as "any"
        if self.message_type is not None and \
                keyupdate.message_type != self.message_type:
            raise AssertionError(
                "Unexpected KeyUpdate message type. Expected: {0}, "
                "received: {1}".format(self.message_type,
                                       keyupdate.message_type))

        # only the server application traffic secret is replaced, the first
        # returned value is discarded
        _, sr_app_secret = state.msg_sock.\
            calcTLS1_3KeyUpdate_sender(
                state.cipher,
                state.key['client application traffic secret'],
                state.key['server application traffic secret'])
        state.key['server application traffic secret'] = sr_app_secret