matheuswhite/bluebees

View on GitHub
bluebees/client/mesh_layers/transport_layer.py

Summary

Maintainability
B
5 hrs
Test Coverage
from bluebees.client.mesh_layers.network_layer import NetworkLayer
from bluebees.client.mesh_layers.mesh_context import SoftContext
from bluebees.client.network.network_data import NetworkData
from bluebees.client.application.application_data import ApplicationData
from bluebees.client.node.node_data import NodeData
from bluebees.client.data_paths import base_dir, net_dir, app_dir, node_dir
from bluebees.common.logging import log_sys, INFO, DEBUG
from bluebees.client.mesh_layers.address import address_type, UNICAST_ADDRESS, \
                                       GROUP_ADDRESS
from bluebees.client.node.group_data import find_group_by_addr
from bluebees.common.crypto import crypto
from bluebees.common.file import file_helper
from typing import List
import asyncio

# Byte budget of the lower transport layer: an encrypted upper transport PDU
# of at most this many bytes is sent unsegmented, and a longer one is cut
# into segments carrying at most this many bytes each.
LT_MTU = 12


# ! The Segment Acknowledgment message is a control message: its CTL value is
# !   1 and its MIC is 64 bits long.
# ! Control messages use a 64-bit MIC, whereas access messages use a 32-bit
# !   MIC.
class AckTimeout(Exception):
    """Signals that a segmented transmission was not acknowledged in time.

    Not raised by ``TransportLayer`` in this file; its ``send_pdu`` reports
    an acknowledgment timeout by returning ``False``.
    """


class TransportLayer:
    """Mesh transport layer built on top of ``NetworkLayer``.

    Send path: encrypt the access PDU, send it unsegmented when it fits in
    ``LT_MTU`` bytes, otherwise split it into segments and wait for Segment
    Acknowledgments, resending whatever the peer reports as missing.

    Receive path: skip control PDUs, then either decrypt an unsegmented PDU
    directly or collect, acknowledge, reassemble and decrypt the segments of
    a segmented one.

    Per-message state (``seq``, ``seq_zero``, ``seg_n``, ``seg_o``,
    ``szmic``, ``is_ctrl_msg``) lives in ``self.net_layer.hard_ctx``.
    ``is_ctrl_msg`` is set right before every ``NetworkLayer.send_pdu`` call
    and read right after every ``transport_pdus.get()``; presumably the
    network layer writes it for each received PDU -- TODO confirm against
    ``NetworkLayer``.
    """

    def __init__(self, send_queue, recv_queue):
        self.net_layer = NetworkLayer(send_queue=send_queue,
                                      recv_queue=recv_queue)

        self.log = log_sys.get_logger('transport_layer')
        self.log.set_level(INFO)

    # * Helpers
    @staticmethod
    def _load_network(name: str) -> NetworkData:
        """Load the stored network data called *name*."""
        return NetworkData.load(base_dir + net_dir + name + '.yml')

    @staticmethod
    def _load_node(name: str) -> NodeData:
        """Load the stored node data called *name*."""
        return NodeData.load(base_dir + node_dir + name + '.yml')

    @staticmethod
    def _load_application(name: str) -> ApplicationData:
        """Load the stored application data called *name*."""
        return ApplicationData.load(base_dir + app_dir + name + '.yml')

    @staticmethod
    def _key_aid(key) -> int:
        """Return the 6-bit AID of *key*: the low 6 bits of ``k4(key)``.

        Used for both building and matching AIDs, so the send and receive
        paths always agree on the value.
        """
        return int.from_bytes(crypto.k4(n=key), 'big') & 0x3f

    @staticmethod
    def _pdu_seq_zero(pdu: bytes) -> int:
        """Extract the 13-bit SeqZero field from bytes 1-2 of *pdu*."""
        return (int.from_bytes(pdu[1:3], 'big') & 0x7ffc) >> 2

    @staticmethod
    def _pdu_seg_o(pdu: bytes) -> int:
        """Extract the 5-bit SegO field from bytes 2-3 of *pdu*."""
        return (int.from_bytes(pdu[2:4], 'big') & 0x03e0) >> 5

    # * Send Methods
    def _encrypt_access_pdu(self, pdu: bytes, soft_ctx: SoftContext) -> bytes:
        """Encrypt an access PDU for the upper transport layer.

        The node's stored sequence number is copied into ``hard_ctx.seq``
        and used in the nonce. The application key of
        ``soft_ctx.application_name`` is used (nonce type 0x01) unless
        ``soft_ctx.is_devkey`` is set. In that case the node's device key is
        used (nonce type 0x02).

        Returns:
            The ciphertext followed by a 4-byte (32-bit) MIC.
        """
        net_data = self._load_network(soft_ctx.network_name)
        node_data = self._load_node(soft_ctx.node_name)
        self.net_layer.hard_ctx.seq = node_data.seq

        if soft_ctx.is_devkey:
            key = node_data.devkey
            nonce_type = b'\x02'
        else:
            key = self._load_application(soft_ctx.application_name).key
            nonce_type = b'\x01'

        # type | ASZMIC + pad (0: 32-bit MIC) | SEQ | SRC | DST | IV index
        nonce = (nonce_type + b'\x00' +
                 self.net_layer.hard_ctx.seq.to_bytes(3, 'big') +
                 soft_ctx.src_addr + soft_ctx.dst_addr + net_data.iv_index)

        result, mic = crypto.aes_ccm_complete(key=key, nonce=nonce, text=pdu,
                                              adata=b'', mic_size=4)
        return result + mic

    def _unsegmented_transport_pdu(self, pdu: bytes,
                                   soft_ctx: SoftContext) -> bytes:
        """Prefix *pdu* with the one-byte unsegmented access header.

        Header layout: SEG (0) | AKF | AID (6 bits). AKF is 1 for an
        application key and 0 for the device key.
        """
        if soft_ctx.is_devkey:
            # NOTE(review): the AID is derived from the device key here; the
            # mesh spec may require AID = 0 when AKF = 0 -- confirm.
            header = self._key_aid(self._load_node(soft_ctx.node_name).devkey)
        else:
            app_key = self._load_application(soft_ctx.application_name).key
            header = 0x40 | self._key_aid(app_key)

        return header.to_bytes(1, 'big') + pdu

    def __header_segmented_transport_pdu(self, soft_ctx: SoftContext,
                                         seg_o: int) -> bytes:
        """Build the 4-byte header of segment *seg_o*.

        Also stores SeqZero in ``hard_ctx.seq_zero``. SeqZero is the low 13
        bits of the node's stored sequence number, which are also the low
        13 bits of SeqAuth. ``hard_ctx.seg_n`` must already be set.

        Layout: SEG (1) | AKF | AID (6) | SZMIC (0) | SeqZero (13) |
        SegO (5) | SegN (5).
        """
        node_data = self._load_node(soft_ctx.node_name)
        # SeqAuth = (IV index << 24) | SEQ, so its low 13 bits are SEQ's.
        self.net_layer.hard_ctx.seq_zero = node_data.seq & 0x1fff

        if soft_ctx.is_devkey:
            first_byte = 0x80 | self._key_aid(node_data.devkey)
        else:
            app_key = self._load_application(soft_ctx.application_name).key
            first_byte = 0x80 | 0x40 | self._key_aid(app_key)

        cont = ((self.net_layer.hard_ctx.seq_zero & 0x1fff) << 10) | \
            ((seg_o & 0x1f) << 5) | \
            (self.net_layer.hard_ctx.seg_n & 0x1f)

        return first_byte.to_bytes(1, 'big') + cont.to_bytes(3, 'big')

    def _segmented_transport_pdu(self, pdu: bytes,
                                 soft_ctx: SoftContext) -> List[bytes]:
        """Split *pdu* into ``LT_MTU``-byte chunks, each behind its header.

        Sets ``hard_ctx.seg_n`` to the index of the last segment.
        """
        seg_n = (len(pdu) - 1) // LT_MTU
        self.net_layer.hard_ctx.seg_n = seg_n

        return [self.__header_segmented_transport_pdu(soft_ctx, seg_o) +
                pdu[seg_o * LT_MTU:(seg_o + 1) * LT_MTU]
                for seg_o in range(seg_n + 1)]

    # * Notes
    # *  - In this implementation, every received message whose dst addr is a
    # *  group address is discarded
    # *  - In this implementation, nodes contain only one element
    def __check_addresses(self, recv_ctx: SoftContext,
                          ctx: SoftContext) -> bool:
        """Return True if *recv_ctx* is a reply to the exchange in *ctx*.

        The received message must be addressed (unicast) to ``ctx.src_addr``
        and come from ``ctx.dst_addr``. When ``ctx.dst_addr`` is a group,
        it must come from a subscriber of that group instead.
        """
        if address_type(recv_ctx.dst_addr) != UNICAST_ADDRESS:
            return False
        if recv_ctx.dst_addr != ctx.src_addr:
            return False

        send_dst_type = address_type(ctx.dst_addr)
        if send_dst_type == UNICAST_ADDRESS:
            return recv_ctx.src_addr == ctx.dst_addr
        if send_dst_type == GROUP_ADDRESS:
            group = find_group_by_addr(ctx.dst_addr)
            return bool(group) and recv_ctx.src_addr in group.sub_addrs

        # any other destination type: not a reply we can match
        return False

    async def _wait_ack(self, soft_ctx: SoftContext, segments: List[bytes]):
        """Wait until every segment in *segments* has been acknowledged.

        Each Segment Acknowledgment whose addresses and SeqZero match the
        outgoing message is OR-ed into the accumulated BlockAck bitmap.
        Segments whose bit is still clear are resent. Never returns until
        all ``seg_n + 1`` bits are set, so callers bound it with a timeout.
        """
        ack_bits = 0
        expected_ack_bits = (1 << (self.net_layer.hard_ctx.seg_n + 1)) - 1
        while True:
            self.log.debug('Waiting ack...')
            ack_pdu, r_ctx, _ = await self.net_layer.transport_pdus.get()
            # read the CTL flag of this PDU before any further await
            is_ctrl_msg = self.net_layer.hard_ctx.is_ctrl_msg
            self.log.debug('Got ack')

            # not same src and dst address (discard)
            if not self.__check_addresses(r_ctx, soft_ctx):
                self.log.debug(f'Src {r_ctx.src_addr.hex()}, '
                               f'Dst: {r_ctx.dst_addr.hex()}')
                continue

            # not control message (discard)
            if not is_ctrl_msg:
                self.log.debug('Not control message')
                continue

            # not a (7-byte, opcode 0x00) ack pdu (discard)
            if len(ack_pdu) < 7 or ack_pdu[0] != 0x00:
                self.log.debug('Not ack pdu')
                continue

            # seq zero wrong (discard)
            pdu_seq_zero = self._pdu_seq_zero(ack_pdu)
            if pdu_seq_zero != self.net_layer.hard_ctx.seq_zero:
                self.log.debug(f'Seq_zero: {pdu_seq_zero}')
                continue

            # ignore bits beyond SegN so they cannot block completion
            ack_bits |= int.from_bytes(ack_pdu[3:7], 'big') & expected_ack_bits
            self.log.debug(f'Ack bits: {hex(ack_bits)}')
            if ack_bits == expected_ack_bits:
                return

            # resend missing segments
            for i, seg in enumerate(segments):
                if (ack_bits >> i) & 0x01 == 0:
                    self.log.debug(f'Send segment: {i}|{seg.hex()}')
                    # segments are access PDUs (CTL = 0); the flag may still
                    # describe the ack that was just received
                    self.net_layer.hard_ctx.is_ctrl_msg = False
                    await self.net_layer.send_pdu(seg, soft_ctx)
            self.log.debug(f'Ack bits [a]: {hex(ack_bits)}')

    async def send_pdu(self, access_pdu: bytes, soft_ctx: SoftContext):
        """Encrypt and send *access_pdu*, segmenting it when needed.

        Returns:
            True once an unsegmented PDU is handed to the network layer, or
            once every segment of a segmented PDU has been acknowledged.
            False if the acknowledgments did not arrive within
            ``soft_ctx.ack_timeout``.
        """
        crypt_access_pdu = self._encrypt_access_pdu(access_pdu, soft_ctx)

        if len(crypt_access_pdu) <= LT_MTU:
            transport_pdu = self._unsegmented_transport_pdu(crypt_access_pdu,
                                                            soft_ctx)
            self.net_layer.hard_ctx.is_ctrl_msg = False
            await self.net_layer.send_pdu(transport_pdu, soft_ctx)
            return True

        segments = self._segmented_transport_pdu(crypt_access_pdu, soft_ctx)
        for i, seg in enumerate(segments):
            # set per segment: the flag is shared with the receive path
            self.net_layer.hard_ctx.is_ctrl_msg = False
            await self.net_layer.send_pdu(seg, soft_ctx)
            self.log.debug(f'Send segment: {i}|{seg.hex()}')

        try:
            await asyncio.wait_for(self._wait_ack(soft_ctx, segments),
                                   soft_ctx.ack_timeout)
        except asyncio.TimeoutError:
            self.log.debug('Wait ack timeout')
            return False
        return True

    # * Receive Methods
    async def __send_ack(self, seg_o_table: dict, soft_ctx: SoftContext):
        """Send a Segment Acknowledgment for the segments in *seg_o_table*.

        BlockAck bit ``k`` is set for every key ``k`` (a SegO value) of
        *seg_o_table*; the SeqZero comes from ``hard_ctx.seq_zero``. Does
        nothing when the table is empty. Leaves ``hard_ctx.is_ctrl_msg``
        set to True.
        """
        if not seg_o_table:
            return

        block_ack = 0
        for k in seg_o_table:
            block_ack |= 1 << k
            self.log.debug(f'block ack: {block_ack}, k: {k}')

        # opcode 0x00 | OBO (0) | SeqZero (13 bits) | RFU (2 bits) | BlockAck
        header = (self.net_layer.hard_ctx.seq_zero & 0x1fff) << 2
        pdu = header.to_bytes(3, 'big') + block_ack.to_bytes(4, 'big')

        self.log.debug(f'Ack seq zero: '
                       f'{hex(self.net_layer.hard_ctx.seq_zero)}')
        self.net_layer.hard_ctx.is_ctrl_msg = True
        await self.net_layer.send_pdu(pdu, soft_ctx)

    def __search_application_by_aid(self, aid: int) -> str:
        """Return the name of the first stored application whose key's AID
        equals *aid* (6-bit int), or None when there is none.
        """
        for f in file_helper.list_files(base_dir + app_dir):
            app_data = ApplicationData.load(base_dir + app_dir + f)
            # compare int to int (k4 itself returns bytes)
            if self._key_aid(app_data.key) == aid:
                return app_data.name

        return None

    def __search_node_by_addr(self, addr: bytes) -> str:
        """Return the name of the stored node whose address is *addr*, or
        an empty string when there is none.
        """
        for f in file_helper.list_files(base_dir + node_dir):
            node_data = NodeData.load(base_dir + node_dir + f)
            if addr == node_data.addr:
                return node_data.name

        return ''

    def _fill_hard_ctx(self, start_pdu: bytes):
        """Store SZMIC, SeqZero, SegO and SegN of a segmented *start_pdu*."""
        self.net_layer.hard_ctx.szmic = (start_pdu[1] & 0x80) >> 7
        self.net_layer.hard_ctx.seq_zero = self._pdu_seq_zero(start_pdu)
        self.net_layer.hard_ctx.seg_o = self._pdu_seg_o(start_pdu)
        self.net_layer.hard_ctx.seg_n = start_pdu[3] & 0x1f

        self.log.debug(f'Seq zero: {hex(self.net_layer.hard_ctx.seq_zero)}')

    def _fill_soft_ctx(self, start_pdu: bytes,
                       ctx: SoftContext) -> SoftContext:
        """Fill the key selection of *ctx* from the AKF/AID header byte.

        Returns:
            *ctx* (mutated in place), or None when AKF is 1 but no stored
            application matches the AID.
        """
        # parentheses matter: '>>' binds tighter than '&'
        akf = (start_pdu[0] & 0x40) >> 6
        aid = start_pdu[0] & 0x3f
        if akf == 1:
            app_name = self.__search_application_by_aid(aid)
            if not app_name:
                return None
            ctx.application_name = app_name
            ctx.is_devkey = False
        else:
            ctx.application_name = ''
            ctx.is_devkey = True

        return ctx

    def _join_segments(self, sorted_segments: List[bytes]) -> bytes:
        """Concatenate the payloads (everything after the 4-byte header) of
        *sorted_segments*, which must already be in SegO order.
        """
        return b''.join(seg[4:] for seg in sorted_segments)

    def _decrypt_transport_pdu(self, pdu: bytes, ctx: SoftContext,
                               first_seq: int) -> bytes:
        """Decrypt an upper transport access PDU and check its MIC.

        The MIC is the last 4 bytes when ``hard_ctx.szmic`` is 0 and the
        last 8 otherwise. *first_seq* is the SEQ placed in the nonce.

        Returns:
            The access PDU. None when the device key is required but
            ``ctx.node_name`` is empty, or when the MIC check fails.
        """
        mic_len = 4 if self.net_layer.hard_ctx.szmic == 0 else 8
        encrypted_pdu = pdu[:-mic_len]
        transport_mic = pdu[-mic_len:]

        self.log.debug(f'Encrypted pdu: {encrypted_pdu.hex()}, mic = '
                       f'{transport_mic.hex()}')

        net_data = self._load_network(ctx.network_name)

        if ctx.is_devkey:
            self.log.debug(f'Using devkey, node name [{ctx.node_name}]')
            if not ctx.node_name:
                self.log.debug('No node found')
                return None

            key = self._load_node(ctx.node_name).devkey
            nonce = b'\x02'
        else:
            key = self._load_application(ctx.application_name).key
            nonce = b'\x01'

        # type | ASZMIC (bit 7) + pad | SEQ | SRC | DST | IV index
        nonce += (self.net_layer.hard_ctx.szmic << 7).to_bytes(1, 'big')
        nonce += first_seq.to_bytes(3, 'big')
        nonce += ctx.src_addr
        nonce += ctx.dst_addr
        nonce += net_data.iv_index

        access_pdu, mic_is_ok = crypto.aes_ccm_decrypt(key=key, nonce=nonce,
                                                       text=encrypted_pdu,
                                                       mic=transport_mic)

        self.log.debug(f'Access PDU: {access_pdu.hex()}, first seq: '
                       f'{hex(first_seq)}')

        if not mic_is_ok:
            self.log.debug(f'Mic is wrong, pdu: {access_pdu.hex()}, seq: '
                           f'{hex(first_seq)}')
            return None
        return access_pdu

    async def _collect_segments(self, soft_ctx: SoftContext,
                                start_pdu: bytes = None) -> List[bytes]:
        """Receive the remaining segments of the current segmented message.

        Incoming PDUs must match ``hard_ctx.seq_zero`` and the addresses of
        *soft_ctx*. Control PDUs, unsegmented PDUs, SegO values above
        ``hard_ctx.seg_n`` and duplicates are discarded. An intermediate
        Segment Acknowledgment is sent on every 11th PDU that passes the
        SeqZero/address checks. A final one is sent once all
        ``seg_n + 1`` segments are present. Never returns before that, so
        callers bound it with a timeout.

        Args:
            soft_ctx: context of the exchange; used for the address checks
                and for sending acknowledgments.
            start_pdu: the segment that started the message. When given, it
                counts as received under its own SegO and is included in
                the result. When omitted, segment 0 is assumed to be held by
                the caller: it is acknowledged but not returned.

        Returns:
            The collected segments (headers included), ordered by SegO.
        """
        seg_n = self.net_layer.hard_ctx.seg_n
        if start_pdu is None:
            seg_o_table = {0: None}
        else:
            seg_o_table = {self._pdu_seg_o(start_pdu): start_pdu}

        ack_counter = 0
        while len(seg_o_table) < seg_n + 1:
            self.log.debug('Waiting segment')
            pdu, r_ctx, _ = await self.net_layer.transport_pdus.get()
            # read the CTL flag now: __send_ack below overwrites it with True
            is_ctrl_msg = self.net_layer.hard_ctx.is_ctrl_msg
            self.log.debug(f'Got segment, pdu: {pdu.hex()}')

            if self._pdu_seq_zero(pdu) != self.net_layer.hard_ctx.seq_zero:
                self.log.debug('Seq zero diff')
                continue

            # not same src and dst address (discard)
            if not self.__check_addresses(r_ctx, soft_ctx):
                self.log.debug('Invalid address')
                continue

            # intermediate ack: after 10 counted PDUs the next one triggers
            # an ack and resets the counter
            if ack_counter >= 10:
                await self.__send_ack(seg_o_table, soft_ctx)
                ack_counter = 0
            else:
                ack_counter += 1

            # control message (discard)
            if is_ctrl_msg:
                self.log.debug('Control message')
                continue

            # unsegmented pdu (discard)
            if not pdu or (pdu[0] & 0x80) == 0:
                self.log.debug('unsegmented pdu')
                continue

            seg_o = self._pdu_seg_o(pdu)
            # segment index outside this message (discard)
            if seg_o > seg_n:
                self.log.debug('Invalid seg_o')
                continue

            # segment already received (discard)
            if seg_o in seg_o_table:
                self.log.debug('Segment already received')
                continue

            self.log.debug('correct segment')
            seg_o_table[seg_o] = pdu

        # send ack message
        await self.__send_ack(seg_o_table, soft_ctx)

        # dict order is arrival order; reassembly needs SegO order
        return [seg_o_table[seg_o] for seg_o in sorted(seg_o_table)
                if seg_o_table[seg_o] is not None]

    async def recv_pdu(self, segment_timeout: int,
                       soft_ctx: SoftContext) -> (bytes, SoftContext):
        """Receive and decrypt the next access message of this exchange.

        Control PDUs are skipped. An unsegmented PDU is decrypted directly.
        A segmented one is collected for at most *segment_timeout* seconds,
        acknowledged, reassembled and decrypted.

        Returns:
            ``(access_pdu, received_ctx)`` on success, or ``(None, None)``
            when the message is discarded: no application matches the AID,
            the addresses do not match *soft_ctx*, segment collection timed
            out, or decryption/MIC verification failed.
        """
        start_pdu, r_ctx, seq_num = await self.net_layer.transport_pdus.get()

        while self.net_layer.hard_ctx.is_ctrl_msg:
            start_pdu, r_ctx, seq_num = \
                await self.net_layer.transport_pdus.get()

        self.log.debug('Testing if is segmented...')
        if (start_pdu[0] & 0x80) == 0:
            # unsegmented pdu
            self.log.debug(f'Is unsegmented. PDU: {start_pdu.hex()}')
            # unsegmented access messages always carry a 32-bit MIC; do not
            # reuse SZMIC left over from an earlier segmented message
            self.net_layer.hard_ctx.szmic = 0

            # filling soft context
            r_ctx = self._fill_soft_ctx(start_pdu=start_pdu, ctx=r_ctx)
            if not r_ctx:
                return None, None

            # checking addresses
            if not self.__check_addresses(r_ctx, soft_ctx):
                self.log.debug('Not same address')
                return None, None

            # decrypting pdu
            self.log.debug('Start decrypting...')
            access_pdu = self._decrypt_transport_pdu(start_pdu[1:], r_ctx,
                                                     seq_num)
            self.log.debug('End decrypt')
        else:
            # segmented pdu

            # filling hard context
            self.log.debug(f'Is segmented. PDU: {start_pdu.hex()}')
            self._fill_hard_ctx(start_pdu)

            # SEQ of segment 0 (used in the nonce): SeqZero holds its low 13
            # bits, so it can be recovered from the SEQ of whichever segment
            # arrived first. Equals seq_num when that is segment 0 itself.
            seq_zero = self.net_layer.hard_ctx.seq_zero
            first_seq = (seq_num - ((seq_num - seq_zero) & 0x1fff)) & 0xffffff
            self.log.debug(f'First seq: {hex(first_seq)}')

            # filling soft context
            self.log.debug('fill soft ctx')
            r_ctx = self._fill_soft_ctx(start_pdu=start_pdu, ctx=r_ctx)
            if not r_ctx:
                self.log.debug('not soft ctx')
                return None, None

            # checking addresses
            if not self.__check_addresses(r_ctx, soft_ctx):
                self.log.debug('Not same address')
                return None, None

            # collecting segments (start_pdu included, ordered by SegO)
            self.log.debug('collect segments')
            try:
                sorted_segments = await asyncio.wait_for(
                    self._collect_segments(soft_ctx, start_pdu=start_pdu),
                    segment_timeout)
            except asyncio.TimeoutError:
                self.log.debug('Giving up of segmented message')
                return None, None

            # join segments
            self.log.debug('join segments')
            transport_pdu = self._join_segments(sorted_segments)

            # decrypting pdu
            self.log.debug('decryption pdu')
            access_pdu = self._decrypt_transport_pdu(transport_pdu, r_ctx,
                                                     first_seq)

        if not access_pdu:
            return None, None

        self.log.debug('ret pdu')
        return access_pdu, r_ctx