digidotcom/python-suitcase

View on GitHub
suitcase/structure.py

Summary

Maintainability
D
2 days
Test Coverage
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# Copyright (c) 2019 Digi International Inc. All Rights Reserved.
import sys
import os
import six
from suitcase.exceptions import SuitcaseException, \
    SuitcasePackException, SuitcaseParseError
from suitcase.fields import FieldArray, FieldPlaceholder, CRCField, SubstructureField, \
    ConditionalField, FieldAccessor, FieldProperty
from six import BytesIO


class ParseError(Exception):
    """Exception raised when there is an error parsing"""


class Packer(object):
    """Object responsible for packing/unpacking bytes into/from fields"""

    def __init__(self, ordered_fields, crc_field):
        self.crc_field = crc_field
        self.ordered_fields = ordered_fields

    def pack(self):
        # type: () -> bytes
        sio = BytesIO()
        self.write(sio)
        return sio.getvalue()

    def write(self, stream):
        # type: (BytesIO) -> None
        # now, pack everything in
        crc_fields = []
        for name, field in self.ordered_fields:
            try:
                if isinstance(field, CRCField):
                    crc_offset = stream.tell()
                    field.pack(stream)
                    crc_fields.append((field, crc_offset))
                else:
                    field.pack(stream)
            except SuitcaseException:
                raise  # just reraise the same exception object
            except Exception:
                # keep the original traceback information, see
                # http://stackoverflow.com/questions/3847503/wrapping-exceptions-in-python
                exc_type = SuitcasePackException
                _, exc_value, exc_traceback = sys.exc_info()
                exc_value = exc_type("Unexpected exception during pack of %r: %s" % (name, str(exc_value)))
                six.reraise(exc_type, exc_value, exc_traceback)

        # if there is a crc value, seek back to the field and
        # pack it with the right value
        if len(crc_fields) > 0:
            data = stream.getvalue()
            for field, offset in crc_fields:
                stream.seek(offset)
                checksum_data = self.crc_field.packed_checksum(data)
                stream.write(checksum_data)

    def unpack(self, data, trailing=False):
        # type: (bytes, bool) -> BytesIO
        stream = BytesIO(data)
        self.unpack_stream(stream)
        stream.tell()
        if trailing:
            return stream
        elif stream.tell() != len(data):
            raise SuitcaseParseError("Structure fully parsed but additional bytes remained.  Parsing "
                                     "consumed %d of %d bytes" %
                                     (stream.tell(), len(data)))
        return stream

    def unpack_stream(self, stream):
        # type: (BytesIO) -> None
        """Unpack bytes from a stream of data field-by-field

        In the most basic case, the basic algorithm here is as follows::

            for _name, field in self.ordered_fields:
               length = field.bytes_required
               data = stream.read(length)
               field.unpack(data)

        This logic is complicated somewhat by the handling of variable length
        greedy fields (there may only be one).  The logic when we see a
        greedy field (bytes_required returns None) in the stream is to
        pivot and parse the remaining fields starting from the last and
        moving through the stream backwards.  There is also some special
        logic present for dealing with checksum fields.

        """
        crc_fields = []
        greedy_field = None
        greedy_field_name = None
        # go through the fields from first to last.  If we hit a greedy
        # field, break out of the loop
        for i, (name, field) in enumerate(self.ordered_fields):
            if isinstance(field, CRCField):
                crc_fields.append((field, stream.tell()))
            length = field.bytes_required
            if field.is_substructure():
                remaining_data = stream.getvalue()[stream.tell():]
                after_unpack = field.unpack(remaining_data, trailing=True)
                consumed = len(remaining_data) - len(after_unpack)
                # We need to fast forward by as much as was consumed by the structure
                stream.seek(stream.tell() + consumed)
                continue
            elif length is None and field.is_greedy:
                # If length is None, this is assumed to be a greedy field.
                # But we check is_greedy so that non-greedy fields are supported.

                if isinstance(field, FieldArray) and field.num_elements is not None:
                    # Read the data greedily now, and we'll backtrack after enough elements have been read.
                    data = stream.read()
                else:
                    greedy_field = field
                    greedy_field_name = name
                    break
            else:
                data = stream.read(length)
                if length is not None and len(data) != length:
                    raise SuitcaseParseError("While attempting to parse field "
                                             "%r we tried to read %s bytes but "
                                             "we were only able to read %s." %
                                             (name, length, len(data)))

            try:
                unused_data = field.unpack(data)
                stream.seek(-len(unused_data or b""), os.SEEK_CUR)
            except SuitcaseException:
                raise  # just re-raise these
            except Exception:
                exc_type = SuitcaseParseError
                _, exc_value, exc_traceback = sys.exc_info()
                exc_value = exc_type("Unexpected exception while unpacking field %r: %s" % (name, str(exc_value)))
                six.reraise(exc_type, exc_value, exc_traceback)

        if greedy_field is not None:
            remaining_data = stream.read()
            inverted_stream = BytesIO(remaining_data[::-1])

            # work through the remaining fields in reverse order in order
            # to narrow in on the right bytes for the greedy field
            reversed_remaining_fields = self.ordered_fields[(i + 1):][::-1]
            for _name, field in reversed_remaining_fields:
                if isinstance(field, CRCField):
                    crc_fields.append(
                        (field, -inverted_stream.tell() - field.bytes_required))
                length = field.bytes_required
                if length is None and field.is_greedy:
                    raise SuitcaseParseError(
                        "While attempting to parse greedy field %r we found "
                        "another greedy field, %r. There can only be one greedy"
                        "field." %
                        (greedy_field_name, _name)
                    )

                data = inverted_stream.read(length)[::-1]
                if len(data) != length:
                    raise SuitcaseParseError("While attempting to parse field "
                                             "%r we tried to read %s bytes but "
                                             "we were only able to read %s." %
                                             (_name, length, len(data)))
                try:
                    field.unpack(data)
                except SuitcaseException:
                    raise  # just re-raise these
                except Exception:
                    exc_type = SuitcaseParseError
                    _, exc_value, exc_traceback = sys.exc_info()
                    exc_value = exc_type("Unexpected exception while unpacking field %r: %s" % (name, str(exc_value)))
                    six.reraise(exc_type, exc_value, exc_traceback)

            greedy_data_chunk = inverted_stream.read()[::-1]
            greedy_field.unpack(greedy_data_chunk)

        if crc_fields:
            data = stream.getvalue()
            for (crc_field, offset) in crc_fields:
                crc_field.validate(data, offset)


class StructureMeta(type):
    """Metaclass for all structure objects

    When a class with this metaclass is created, we look for any
    FieldProperty instances associated with the class and record
    those for use later on.

    """

    def __new__(cls, name, bases, dct):
        #  find all the placeholders in this class declaration and store
        # them away.  Add name mangling to the original fields so they
        # do not get in the way.
        dct['_field_placeholders'] = {}
        dct['_crc_field'] = None
        for key, value in list(dct.items()):  # use a copy, we mutate dct
            if isinstance(value, FieldPlaceholder):
                if issubclass(value.cls, FieldAccessor):
                    # Wrap the accessor in a simple FieldProperty,
                    # so that the following usage model is supported:
                    #
                    #     class S(Structure):
                    #         bits = BitField(...)
                    #         segment = bits.segment
                    value = FieldProperty(value)

                dct['_field_placeholders'][key] = value
                dct['__%s' % key] = value
                del dct[key]
                if value.cls == CRCField:
                    dct['_crc_field'] = value

        sorted_fields = list(sorted(dct['_field_placeholders'].items(),
                                    key=lambda kv: kv[1]._field_seqno))
        dct['_sorted_fields'] = sorted_fields
        return type.__new__(cls, name, bases, dct)


@six.add_metaclass(StructureMeta)
class Structure(object):
    r"""Base class for message schema declaration

    ``Structure`` forms the core of the Suitcase library and allows for
    a declarative syntax for specifying packet schemas and associated
    methods for transforming these schemas into packed bytes (and vice-versa).

    Here's an example showing how one might specify the format for a UDP
    Datagram::


        >>> from suitcase.fields import UBInt16, LengthField, VariableRawPayload
        >>> class UDPDatagram(Structure):
        ...     source_port = UBInt16()
        ...     destination_port = UBInt16()
        ...     length = LengthField(UBInt16())
        ...     checksum = UBInt16()
        ...     data = VariableRawPayload(length)

    From this we have a near-ideal form for packing and parsing packet
    data following the schema::

        >>> def printb(s):
        ...     print(repr(s).replace("b'", "'").replace("u'", "'"))
        ...
        >>> dgram = UDPDatagram()
        >>> dgram.source_port = 9110
        >>> dgram.destination_port = 1001
        >>> dgram.checksum = 27193
        >>> dgram.data = b"Hello, world!"
        >>> printb(dgram.pack())
        '#\x96\x03\xe9\x00\rj9Hello, world!'
        >>> dgram2 = UDPDatagram()
        >>> dgram2.unpack(dgram.pack())
        >>> dgram2
        UDPDatagram (
          source_port=9110,
          destination_port=1001,
          length=13,
          checksum=27193,
          data=...'Hello, world!',
        )

    Initialization via keyword argument is also supported::

        >>> dgram = UDPDatagram(source_port=9110,
        ...                     destination_port=1001,
        ...                     checksum=27193,
        ...                     data=b"Hello, world!")
        ...
        >>> printb(dgram.pack())
        '#\x96\x03\xe9\x00\rj9Hello, world!'
        >>> dgram2 = UDPDatagram()
        >>> dgram2.unpack(dgram.pack())
        >>> dgram2
        UDPDatagram (
          source_port=9110,
          destination_port=1001,
          length=13,
          checksum=27193,
          data=...'Hello, world!',
        )

    """

    @classmethod
    def from_data(cls, data):
        """Create a new, populated message from some data

        This factory method is identical to doing the following, it just takes
        one line instead of two and looks nicer in general::

            m = MyMessage()
            m.unpack(data)

        Can be rewritten as just::

            m = MyMessage.from_data(data)

        """
        m = cls()
        m.unpack(data)
        return m

    def __init__(self, **kwargs):
        self._key_to_field = {}
        self._parent = None
        self._sorted_fields = []
        self._placeholder_to_field = {}
        if self.__class__._crc_field is None:
            self._crc_field = None
        else:
            self._crc_field = self.__class__._crc_field.create_instance(self)
        for key, field_placeholder in self.__class__._sorted_fields:
            field = field_placeholder.create_instance(self)
            self._key_to_field[key] = field
            self._placeholder_to_field[field_placeholder] = field
            self._sorted_fields.append((key, field))
        self._packer = Packer(self._sorted_fields, self._crc_field)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        k2f = self.__dict__.get('_key_to_field', {})
        if key in k2f:
            field = self._key_to_field[key]
            return field.getval()
        raise AttributeError

    def __setattr__(self, key, value):
        k2f = self.__dict__.get('_key_to_field', {})
        if key in k2f:
            field = self._key_to_field[key]
            return field.setval(value)
        return object.__setattr__(self, key, value)

    def __dir__(self):
        return dir(type(self)) + [str(k) for k, v in self._sorted_fields]

    def __iter__(self):
        return iter(self._sorted_fields)

    def __repr__(self):
        output = "%s (\n" % self.__class__.__name__
        for field_name, field in self:
            output += "  %s=%s,\n" % (field_name, field)
        output += ")"
        return output

    def lookup_field_by_name(self, name):
        for fname, field in self:
            if name == fname:
                return field
        raise KeyError

    def lookup_field_by_placeholder(self, placeholder):
        return self._placeholder_to_field[placeholder]

    def unpack(self, data, trailing=False):
        # type: (bytes, bool) -> BytesIO
        # If we asked to unpack while leaving any trailing bytes,
        # make sure to specify it's okay for there to be trailing bytes.
        return self._packer.unpack(data, trailing=trailing)

    def pack(self):
        # type: () -> bytes
        return self._packer.pack()