tlsfuzzer/tlslite-ng

View on GitHub
tlslite/utils/brotlidecpy/decode.py

Summary

Maintainability
F
6 days
Test Coverage
F
52%
# Copyright 2021 Sidney Markowitz All Rights Reserved.
# Distributed under MIT license.
# See file LICENSE for detail or copy at https://opensource.org/licenses/MIT

from .huffman import HuffmanCode, brotli_build_huffman_table
from .prefix import Prefix, kBlockLengthPrefixCode, kInsertLengthPrefixCode, \
    kCopyLengthPrefixCode
from .bit_reader import BrotliBitReader
from .dictionary import BrotliDictionary
from .context import Context
from .transform import Transform, kNumTransforms

# Code length that a "repeat previous length" symbol uses when no non-zero
# length has been read yet.
kDefaultCodeLength = 8
# Code-length symbols from this value upwards are repeat codes rather than
# literal lengths.
kCodeLengthRepeatCode = 16
# Alphabet sizes of the literal, insert-and-copy command and block-length
# Huffman codes.
kNumLiteralCodes = 256
kNumInsertAndCopyCodes = 704
kNumBlockLengthCodes = 26
# log2 of the number of context-map entries per literal / distance block type.
kLiteralContextBits = 6
kDistanceContextBits = 2

# Width in bits of the first-level Huffman lookup table, and the mask that
# extracts that many low bits.
HUFFMAN_TABLE_BITS = 8
HUFFMAN_TABLE_MASK = 0xff
# Maximum possible Huffman table size for an alphabet size of 704, max code
# length 15 and root table bits 8.
HUFFMAN_MAX_TABLE_SIZE = 1080

# Number of symbols in the code-length alphabet, and the order in which their
# own lengths are read from the stream.
CODE_LENGTH_CODES = 18
kCodeLengthCodeOrder = bytearray([
    1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15])

# Distance codes below this value refer to recently used distances: the first
# table picks the ring-buffer slot (as an offset from the current index), the
# second table is the amount added to the distance found in that slot.
NUM_DISTANCE_SHORT_CODES = 16
kDistanceShortCodeIndexOffset = bytearray([
    3, 2, 1, 0, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2])

kDistanceShortCodeValueOffset = [
    0, 0, 0, 0, -1, 1, -2, 2, -3, 3, -1, 1, -2, 2, -3, 3
]

# Per-tree table sizes used when allocating HuffmanTreeGroup.codes, indexed
# by (alphabet_size + 31) >> 5.
kMaxHuffmanTableSize = [
    256, 402, 436, 468, 500, 534, 566, 598, 630, 662, 694, 726, 758, 790, 822,
    854, 886, 920, 952, 984, 1016, 1048, 1080
]


def decode_window_bits(br):
    """
    Read the window-size field from the bit reader and return it in bits.

    A leading 0 bit selects 16. Otherwise a non-zero 3-bit field n selects
    17 + n; if that field is zero a second 3-bit field n selects 8 + n, and
    17 is returned when the second field is zero as well.
    """
    if not br.read_bits(1):
        return 16
    # The two 3-bit fields share the same "non-zero means base + n" rule.
    for base in (17, 8):
        n = br.read_bits(3)
        if n > 0:
            return base + n
    return 17


def decode_var_len_uint8(br):
    """
    Decode a variable-length number in the range [0..255], reading between
    1 and 11 bits from the bit reader.
    """
    # A single 0 bit is the whole encoding of the value 0.
    if not br.read_bits(1):
        return 0
    nbits = br.read_bits(3)
    if nbits == 0:
        return 1
    # The value is 2**nbits plus nbits further bits.
    return (1 << nbits) + br.read_bits(nbits)


class MetaBlockLength:
    """Plain record holding the fields parsed from a meta-block header."""

    def __init__(self):
        # The two flags start as integer 0 because the decoder stores raw
        # 1-bit reads in them; only is_metadata is a real bool.
        self.is_metadata = False
        self.is_uncompressed = 0
        self.input_end = 0
        self.meta_block_length = 0


def decode_meta_block_length(br):
    """
    Parse a meta-block header from the bit reader.

    Returns a MetaBlockLength describing whether this is the last
    meta-block, whether it is a metadata block, whether its payload is
    stored uncompressed, and its length (0 for an empty block).

    Raises Exception when the reserved bit of a metadata header is set or
    when the most significant size nibble/byte is zero although a shorter
    encoding was available.
    """
    header = MetaBlockLength()
    header.input_end = br.read_bits(1)
    # A last meta-block may additionally be flagged as empty.
    if header.input_end and br.read_bits(1):
        return header

    size_nibbles = 4 + br.read_bits(2)
    length = 0
    if size_nibbles != 7:
        # Ordinary block: the length is stored as 4..6 nibbles, low first.
        top = size_nibbles - 1
        for pos in range(size_nibbles):
            nibble = br.read_bits(4)
            if pos == top and size_nibbles > 4 and nibble == 0:
                raise Exception('Invalid size nibble')
            length |= nibble << (4 * pos)
    else:
        # Metadata block: reserved bit, then 0..3 length bytes, low first.
        header.is_metadata = True
        if br.read_bits(1) != 0:
            raise Exception('Invalid reserved bit')

        size_bytes = br.read_bits(2)
        if size_bytes == 0:
            return header

        top = size_bytes - 1
        for pos in range(size_bytes):
            byte = br.read_bits(8)
            if pos == top and size_bytes > 1 and byte == 0:
                raise Exception('Invalid size byte')
            length |= byte << (8 * pos)

    # The stream stores length - 1.
    header.meta_block_length = length + 1

    # Only a non-final, non-metadata block carries the "uncompressed" bit.
    if not (header.input_end or header.is_metadata):
        header.is_uncompressed = br.read_bits(1)

    return header


def read_symbol(table, index, br):
    """
    Decode the next Huffman symbol from the bit stream.

    table is a flat array of HuffmanCode entries and index is the position
    of the root (first-level) table inside it. Returns the decoded symbol
    value.
    """
    # The C reference version assumes 15 is the max needed and uses 16 in
    # this function. The second argument of 0 is presumably "do not advance"
    # (the bits are consumed by the final read_bits call) - confirm against
    # BrotliBitReader.
    peeked = br.read_bits(16, 0)
    slot = index + (peeked & HUFFMAN_TABLE_MASK)
    entry = table[slot]
    consumed = 0
    if entry.bits > HUFFMAN_TABLE_BITS:
        # Code longer than the root table: follow the link to the
        # second-level table, indexed by the bits after the root bits.
        extra_bits = entry.bits - HUFFMAN_TABLE_BITS
        consumed = HUFFMAN_TABLE_BITS
        slot += entry.value + (
            (peeked >> HUFFMAN_TABLE_BITS) & br.kBitMask[extra_bits])
        entry = table[slot]
    br.read_bits(None, consumed + entry.bits)
    return entry.value


def read_huffman_code_lengths(code_length_code_lengths, num_symbols,
                              code_lengths, br):
    """
    Read the code lengths of a Huffman code and store them in code_lengths.

    code_length_code_lengths: lengths of the CODE_LENGTH_CODES symbols of
        the code-length code; a 5-bit lookup table is built from them and
        used to decode the lengths themselves
    num_symbols: number of entries of code_lengths to fill
    code_lengths: output sequence, written in place; entries not covered by
        the stream are set to 0
    br: bit reader the symbols and their extra bits are read from

    Raises Exception if a repeat would run past num_symbols, or if the
    decoded lengths do not use up the code space exactly.
    """
    symbol = 0
    prev_code_len = kDefaultCodeLength
    repeat = 0
    repeat_code_len = 0
    # Remaining code space in units of 2**-15; a length L uses 0x8000 >> L.
    space = 32768

    table = [HuffmanCode(0, 0) for _ in range(0, 32)]

    brotli_build_huffman_table(table, 0, 5, code_length_code_lengths,
                               CODE_LENGTH_CODES)

    while (symbol < num_symbols) and (space > 0):
        # Look at 5 bits to index the table, then consume only as many bits
        # as the matched entry says the code is long.
        p = 0
        p += br.read_bits(5, 0)
        br.read_bits(None, table[p].bits)
        code_len = table[p].value & 0xff
        if code_len < kCodeLengthRepeatCode:
            # Literal code length for the next symbol.
            repeat = 0
            code_lengths[symbol] = code_len
            symbol += 1
            if code_len != 0:
                prev_code_len = code_len
                space -= 0x8000 >> code_len
        else:
            # Repeat code: 16 repeats the previous non-zero length (2 extra
            # bits), anything else repeats zero (code_len - 14 extra bits).
            extra_bits = code_len - 14
            new_len = 0
            if code_len == kCodeLengthRepeatCode:
                new_len = prev_code_len
            if repeat_code_len != new_len:
                repeat = 0
                repeat_code_len = new_len
            # Consecutive repeat codes for the same length extend the
            # previous run instead of starting a new one.
            old_repeat = repeat
            if repeat > 0:
                repeat -= 2
                repeat <<= extra_bits
            repeat += br.read_bits(extra_bits) + 3
            repeat_delta = repeat - old_repeat
            if symbol + repeat_delta > num_symbols:
                raise Exception(
                    '[read_huffman_code_lengths] symbol + repeat_delta > ' +
                    'num_symbols'
                )

            for x in range(0, repeat_delta):
                code_lengths[symbol + x] = repeat_code_len

            symbol += repeat_delta

            if repeat_code_len != 0:
                space -= repeat_delta << (15 - repeat_code_len)

    # The lengths must describe a complete code: nothing left, nothing over.
    if space != 0:
        raise Exception('[read_huffman_code_lengths] space = %s' % space)

    for i in range(symbol, num_symbols):
        code_lengths[i] = 0


def read_huffman_code(alphabet_size, tables, table, br):
    """
    Read one Huffman code description from the stream and build its lookup
    table.

    alphabet_size: number of symbols in the alphabet
    tables: flat array the lookup table is written into
    table: start position of this code's table inside tables
    br: bit reader

    Returns the size of the table that was built. Raises Exception when the
    description is invalid or the table cannot be built.
    """
    code_lengths = bytearray([0] * alphabet_size)

    # The 2-bit selector is 1 for a "simple" code; any other value is the
    # number of leading code-length code lengths that are skipped (0, 2, 3).
    selector = br.read_bits(2)
    if selector == 1:
        # Simple code: up to four symbols are listed explicitly.
        symbol_count = br.read_bits(2) + 1

        # Number of bits needed to represent alphabet_size - 1.
        symbol_bits = 0
        remaining = alphabet_size - 1
        while remaining:
            remaining >>= 1
            symbol_bits += 1

        symbols = [0, 0, 0, 0]
        for slot in range(symbol_count):
            symbols[slot] = br.read_bits(symbol_bits) % alphabet_size
            code_lengths[symbols[slot]] = 2
        code_lengths[symbols[0]] = 1

        # All listed symbols have to be distinct.
        if len(set(symbols[:symbol_count])) != symbol_count:
            raise Exception('[read_huffman_code] invalid symbols')

        if symbol_count == 2:
            code_lengths[symbols[1]] = 1
        elif symbol_count == 4:
            # One more bit picks between lengths 1,2,3,3 and 2,2,2,2.
            if br.read_bits(1):
                code_lengths[symbols[2]] = 3
                code_lengths[symbols[3]] = 3
            else:
                code_lengths[symbols[0]] = 2
    else:  # Decode Huffman-coded code lengths
        cl_code_lengths = bytearray([0] * CODE_LENGTH_CODES)
        budget = 32
        nonzero = 0
        # Static Huffman code for the code length code lengths
        static_code = [HuffmanCode(bits, value) for bits, value in (
            (2, 0), (2, 4), (2, 3), (3, 2),
            (2, 0), (2, 4), (2, 3), (4, 1),
            (2, 0), (2, 4), (2, 3), (3, 2),
            (2, 0), (2, 4), (2, 3), (4, 5),
        )]
        for order_idx in range(selector, CODE_LENGTH_CODES):
            if budget <= 0:
                break
            # Index with 4 bits, then consume only the matched code's bits.
            entry = static_code[br.read_bits(4, 0)]
            br.read_bits(None, entry.bits)
            length = entry.value
            cl_code_lengths[kCodeLengthCodeOrder[order_idx]] = length
            if length != 0:
                budget -= (32 >> length)
                nonzero += 1

        # Valid only with exactly one used code or a fully used code space.
        if not (nonzero == 1 or budget == 0):
            raise Exception('[read_huffman_code] invalid num_codes or space')

        read_huffman_code_lengths(cl_code_lengths, alphabet_size,
                                  code_lengths, br)

    built = brotli_build_huffman_table(tables, table, HUFFMAN_TABLE_BITS,
                                       code_lengths, alphabet_size)
    if built == 0:
        raise Exception('[read_huffman_code] BuildHuffmanTable failed: ')
    return built


def read_block_length(table, index, br):
    """
    Decode a block length: a Huffman symbol selects a prefix code entry,
    whose offset is added to the entry's number of extra bits read from br.
    """
    prefix = kBlockLengthPrefixCode[read_symbol(table, index, br)]
    extra = br.read_bits(prefix.nbits)
    return prefix.offset + extra


def translate_short_codes(code, ringbuffer, index):
    """
    Turn a distance code into a distance.

    Codes below NUM_DISTANCE_SHORT_CODES pick one of the four distances in
    ringbuffer (relative to index) and adjust it by a small offset; larger
    codes map directly to code - NUM_DISTANCE_SHORT_CODES + 1.
    """
    if code >= NUM_DISTANCE_SHORT_CODES:
        return code - NUM_DISTANCE_SHORT_CODES + 1
    slot = (index + kDistanceShortCodeIndexOffset[code]) & 3
    return ringbuffer[slot] + kDistanceShortCodeValueOffset[code]


def move_to_front(v, index):
    """Move the element at position index of list v to position 0, in place."""
    item = v.pop(index)
    v.insert(0, item)


def inverse_move_to_front_transform(v, v_len):
    """
    Undo a move-to-front transform on the first v_len entries of v, in place.

    Each entry is a rank into a list that starts out as 0..255; the entry is
    replaced by the value at that rank and the value moves to the front.
    """
    order = list(range(256))
    for pos in range(v_len):
        rank = v[pos]
        v[pos] = order[rank]
        # Rank 0 is already at the front, nothing to move.
        if rank:
            order.insert(0, order.pop(rank))


class HuffmanTreeGroup:
    """
    A collection of Huffman trees that share the same alphabet size.

    codes is one flat array holding the lookup tables of all trees;
    huff_trees[i] is the position in codes where the table of tree i starts.
    """

    def __init__(self, alphabet_size, num_huff_trees):
        self.alphabet_size = alphabet_size
        self.num_huff_trees = num_huff_trees
        # Reserve the per-tree maximum for this alphabet size, plus one
        # spare entry per tree.
        per_tree = kMaxHuffmanTableSize[(alphabet_size + 31) >> 5]
        self.codes = [0] * (num_huff_trees * (1 + per_tree))
        self.huff_trees = [0] * num_huff_trees

    def decode(self, br):
        """Read every tree of the group from br, packing the tables
        back to back into codes."""
        offset = 0
        for tree in range(self.num_huff_trees):
            self.huff_trees[tree] = offset
            offset += read_huffman_code(self.alphabet_size, self.codes,
                                        offset, br)


class DecodeContextMap:
    """
    Reads a context map from the bit stream.

    After construction num_huff_trees holds the number of Huffman trees the
    map refers to and context_map holds context_map_size tree indexes.
    """

    def __init__(self, context_map_size, br):
        self.num_huff_trees = decode_var_len_uint8(br) + 1
        self.context_map = bytearray([0] * context_map_size)

        # With a single tree every entry is 0 and nothing else is stored.
        if self.num_huff_trees <= 1:
            return

        # Optional run-length coding of zeros: symbols 1..rle_max are runs.
        rle_max = 0
        if br.read_bits(1):
            rle_max = br.read_bits(4) + 1

        table = [HuffmanCode(0, 0) for _ in range(HUFFMAN_MAX_TABLE_SIZE)]
        read_huffman_code(self.num_huff_trees + rle_max, table, 0, br)

        cmap = self.context_map
        filled = 0
        while filled < context_map_size:
            code = read_symbol(table, 0, br)
            if 0 < code <= rle_max:
                # Run of (1 << code) + extra zeros; it must fit in the map.
                run = (1 << code) + br.read_bits(code)
                for _ in range(run):
                    if filled >= context_map_size:
                        raise Exception(
                            '[DecodeContextMap] i >= context_map_size')
                    cmap[filled] = 0
                    filled += 1
            else:
                # Symbol 0 is a single zero; the rest are shifted indexes.
                cmap[filled] = code - rle_max if code else 0
                filled += 1

        if br.read_bits(1):
            inverse_move_to_front_transform(cmap, context_map_size)


def decode_block_type(max_block_type, trees, tree_type, block_types,
                      ring_buffers, indexes, br):
    """
    Read a block-switch command and update the block type of one category.

    tree_type selects the category; it picks the Huffman table inside
    trees, the two-entry history inside ring_buffers and the counter inside
    indexes. The new type is stored in block_types[tree_type] and in the
    history, and the counter is incremented.
    """
    rb_base = 2 * tree_type
    count = indexes[tree_type]
    type_code = read_symbol(trees, tree_type * HUFFMAN_MAX_TABLE_SIZE, br)

    if type_code == 1:
        # The other history slot, plus one.
        new_type = ring_buffers[rb_base + ((count - 1) & 1)] + 1
    elif type_code == 0:
        # The type in the slot that is about to be overwritten.
        new_type = ring_buffers[rb_base + (count & 1)]
    else:
        new_type = type_code - 2

    # Wrap around once past the last valid type.
    if new_type >= max_block_type:
        new_type -= max_block_type

    block_types[tree_type] = new_type
    ring_buffers[rb_base + (count & 1)] = new_type
    indexes[tree_type] = count + 1


def copy_uncompressed_block_to_output(length, pos, output_buffer, br):
    """
    Copy length raw bytes from the input to output_buffer starting at pos.

    The work is delegated to br.copy_bytes. This is only called for
    uncompressed meta-blocks, where the input is on a byte boundary.
    """
    br.copy_bytes(output_buffer, pos, length)


def jump_to_byte_boundary(br):
    """
    Advance the bit reader, if needed, so that it sits on a byte boundary.

    Done by asking br.copy_bytes for a zero-length copy into an empty
    buffer, which moves no data.
    """
    br.copy_bytes(b'', 0, 0)


def brotli_decompress_buffer(input_buffer, buffer_limit=None):
    """
    Decompress a complete Brotli stream held in memory.

    input_buffer: the compressed data, handed to BrotliBitReader
    buffer_limit: optional maximum size in bytes of the decompressed
        output; None (the default) disables the check

    Returns a bytearray with the decompressed data.

    Raises BufferError when the output would grow past buffer_limit. The
    limit is checked before the buffer for a meta-block is allocated, so
    the final (or only) meta-block cannot exceed it either.
    Raises Exception when the stream is malformed.
    """
    br = BrotliBitReader(input_buffer)
    output_buffer = bytearray([])
    pos = 0
    input_end = 0
    max_distance = 0
    # This ring buffer holds a few past copy distances that will be used by
    # some special distance codes.
    dist_rb = [16, 15, 11, 4]
    dist_rb_idx = 0
    hgroup = [
        HuffmanTreeGroup(0, 0), HuffmanTreeGroup(0, 0), HuffmanTreeGroup(0, 0)]

    # Decode window size.
    window_bits = decode_window_bits(br)
    max_backward_distance = (1 << window_bits) - 16

    # One table region per category: 0 literals, 1 commands, 2 distances.
    block_type_trees = [
        HuffmanCode(0, 0) for _ in range(0, 3 * HUFFMAN_MAX_TABLE_SIZE)]
    block_len_trees = [
        HuffmanCode(0, 0) for _ in range(0, 3 * HUFFMAN_MAX_TABLE_SIZE)]

    while not input_end:
        # Per-meta-block state, indexed by category like the trees above.
        block_length = [1 << 28, 1 << 28, 1 << 28]
        block_type = [0] * 3
        num_block_types = [1] * 3
        block_type_rb = [0, 1, 0, 1, 0, 1]
        block_type_rb_index = [0] * 3

        if buffer_limit is not None and len(output_buffer) > buffer_limit:
            raise BufferError(
                "Trying to obtain buffer larger than {0}".format(buffer_limit)
            )

        # Drop the tables of the previous meta-block.
        for i in range(0, 3):
            hgroup[i].codes = None
            hgroup[i].huff_trees = None

        _out = decode_meta_block_length(br)
        meta_block_remaining_len = _out.meta_block_length
        input_end = _out.input_end
        is_uncompressed = _out.is_uncompressed

        if _out.is_metadata:
            jump_to_byte_boundary(br)

            while meta_block_remaining_len > 0:
                # Read one byte and ignore it
                br.read_bits(8)
                meta_block_remaining_len -= 1
            continue

        if meta_block_remaining_len == 0:
            continue

        # Enforce the limit before growing the buffer: the check at the top
        # of the loop only sees output that was already produced, so on its
        # own it would let the last meta-block exceed the limit unnoticed.
        if (buffer_limit is not None and
                pos + meta_block_remaining_len > buffer_limit):
            raise BufferError(
                "Trying to obtain buffer larger than {0}".format(buffer_limit)
            )

        if len(output_buffer) < (pos + meta_block_remaining_len):
            # bytearray(n) is already zero filled; no temporary list needed.
            output_buffer.extend(bytearray(meta_block_remaining_len))

        if is_uncompressed:
            copy_uncompressed_block_to_output(meta_block_remaining_len, pos,
                                              output_buffer, br)
            pos += meta_block_remaining_len
            continue

        # Block-switch codes and first block length, only stored for
        # categories that use more than one block type.
        for i in range(0, 3):
            num_block_types[i] = decode_var_len_uint8(br) + 1
            if num_block_types[i] >= 2:
                read_huffman_code(num_block_types[i] + 2, block_type_trees,
                                  i * HUFFMAN_MAX_TABLE_SIZE, br)
                read_huffman_code(kNumBlockLengthCodes, block_len_trees,
                                  i * HUFFMAN_MAX_TABLE_SIZE, br)
                block_length[i] = read_block_length(block_len_trees,
                                                    i * HUFFMAN_MAX_TABLE_SIZE,
                                                    br)
                block_type_rb_index[i] = 1

        # Distance code parameters.
        distance_postfix_bits = br.read_bits(2)
        num_direct_distance_codes = NUM_DISTANCE_SHORT_CODES + (
            br.read_bits(4) << distance_postfix_bits)
        distance_postfix_mask = (1 << distance_postfix_bits) - 1
        num_distance_codes = (num_direct_distance_codes + (
            48 << distance_postfix_bits))
        context_modes = bytearray([0] * num_block_types[0])

        # Stored doubled because each mode owns two consecutive entries of
        # Context.lookupOffsets.
        for i in range(0, num_block_types[0]):
            context_modes[i] = (br.read_bits(2) << 1)

        _o1 = DecodeContextMap(num_block_types[0] << kLiteralContextBits, br)
        num_literal_huff_trees = _o1.num_huff_trees
        context_map = _o1.context_map

        _o2 = DecodeContextMap(num_block_types[2] << kDistanceContextBits, br)
        num_dist_huff_trees = _o2.num_huff_trees
        dist_context_map = _o2.context_map

        hgroup[0] = HuffmanTreeGroup(kNumLiteralCodes, num_literal_huff_trees)
        hgroup[1] = HuffmanTreeGroup(
            kNumInsertAndCopyCodes, num_block_types[1])
        hgroup[2] = HuffmanTreeGroup(num_distance_codes, num_dist_huff_trees)

        for i in range(0, 3):
            hgroup[i].decode(br)

        context_map_slice = 0
        dist_context_map_slice = 0
        context_mode = context_modes[block_type[0]]
        context_lookup_offset1 = Context.lookupOffsets[context_mode]
        context_lookup_offset2 = Context.lookupOffsets[context_mode + 1]
        huff_tree_command = hgroup[1].huff_trees[0]

        # Each iteration decodes one command: a run of literals followed by
        # a copy from earlier output or from the static dictionary.
        while meta_block_remaining_len > 0:

            if block_length[1] == 0:
                decode_block_type(num_block_types[1], block_type_trees, 1,
                                  block_type, block_type_rb,
                                  block_type_rb_index, br)
                block_length[1] = read_block_length(
                    block_len_trees, HUFFMAN_MAX_TABLE_SIZE, br)
                huff_tree_command = hgroup[1].huff_trees[block_type[1]]
            block_length[1] -= 1
            cmd_code = read_symbol(hgroup[1].codes, huff_tree_command, br)
            range_idx = cmd_code >> 6
            # Ranges 0 and 1 keep distance code 0 (no distance symbol is
            # read); the others mark the distance as "to be read" with -1.
            distance_code = 0
            if range_idx >= 2:
                range_idx -= 2
                distance_code = -1
            insert_code = Prefix.kInsertRangeLut[range_idx] + (
                (cmd_code >> 3) & 7)
            copy_code = Prefix.kCopyRangeLut[range_idx] + (cmd_code & 7)
            insert_length = kInsertLengthPrefixCode[insert_code].offset + \
                br.read_bits(kInsertLengthPrefixCode[insert_code].nbits)
            copy_length = kCopyLengthPrefixCode[copy_code].offset + \
                br.read_bits(kCopyLengthPrefixCode[copy_code].nbits)
            # The two previous output bytes select the literal context.
            prev_byte1 = output_buffer[pos - 1] if pos > 0 else 0
            prev_byte2 = output_buffer[pos - 2] if pos > 1 else 0
            for j in range(0, insert_length):
                if block_length[0] == 0:
                    decode_block_type(num_block_types[0], block_type_trees, 0,
                                      block_type, block_type_rb,
                                      block_type_rb_index, br)
                    block_length[0] = read_block_length(block_len_trees, 0, br)
                    context_offset = block_type[0] << kLiteralContextBits
                    context_map_slice = context_offset
                    context_mode = context_modes[block_type[0]]
                    context_lookup_offset1 = Context.lookupOffsets[
                        context_mode]
                    context_lookup_offset2 = Context.lookupOffsets[
                        context_mode + 1]
                context = Context.lookup[context_lookup_offset1 + prev_byte1] \
                    | Context.lookup[context_lookup_offset2 + prev_byte2]
                literal_huff_tree_index = context_map[
                    context_map_slice + context]
                block_length[0] -= 1
                prev_byte2 = prev_byte1
                prev_byte1 = read_symbol(
                    hgroup[0].codes,
                    hgroup[0].huff_trees[literal_huff_tree_index], br)
                output_buffer[pos] = prev_byte1
                pos += 1
            meta_block_remaining_len -= insert_length
            # The last command of a meta-block may consist of literals only.
            if meta_block_remaining_len <= 0:
                break

            if distance_code < 0:
                if block_length[2] == 0:
                    decode_block_type(num_block_types[2], block_type_trees, 2,
                                      block_type, block_type_rb,
                                      block_type_rb_index, br)
                    block_length[2] = read_block_length(
                        block_len_trees, 2 * HUFFMAN_MAX_TABLE_SIZE, br)
                    dist_context_offset = block_type[2] << kDistanceContextBits
                    dist_context_map_slice = dist_context_offset
                block_length[2] -= 1
                # The distance context depends on the copy length only.
                context = (3 if copy_length > 4 else copy_length - 2) & 0xff
                dist_huff_tree_index = dist_context_map[
                    dist_context_map_slice + context]
                distance_code = read_symbol(
                    hgroup[2].codes,
                    hgroup[2].huff_trees[dist_huff_tree_index], br)
                if distance_code >= num_direct_distance_codes:
                    # Codes past the direct range carry extra bits and a
                    # postfix.
                    distance_code -= num_direct_distance_codes
                    postfix = distance_code & distance_postfix_mask
                    distance_code >>= distance_postfix_bits
                    nbits = (distance_code >> 1) + 1
                    offset = ((2 + (distance_code & 1)) << nbits) - 4
                    distance_code = num_direct_distance_codes + (
                            (offset + br.read_bits(nbits))
                            << distance_postfix_bits) + postfix

            # Convert distance code to actual distance by possibly looking up
            # past distances from the ringbuffer
            distance = translate_short_codes(distance_code, dist_rb,
                                             dist_rb_idx)
            if distance < 0:
                raise Exception('[brotli_decompress] invalid distance')

            # Largest distance that still points into produced output; once
            # it has reached the window size it stays there.
            if (pos < max_backward_distance
                    and max_distance != max_backward_distance):
                max_distance = pos
            else:
                max_distance = max_backward_distance

            copy_dst = pos

            if distance > max_distance:
                # Too far back to be earlier output: a reference to a
                # (possibly transformed) word of the static dictionary.
                if BrotliDictionary.minDictionaryWordLength <= copy_length <= \
                        BrotliDictionary.maxDictionaryWordLength:
                    offset = BrotliDictionary.offsetsByLength[copy_length]
                    word_id = distance - max_distance - 1
                    shift = BrotliDictionary.sizeBitsByLength[copy_length]
                    mask = (1 << shift) - 1
                    word_idx = word_id & mask
                    transform_idx = word_id >> shift
                    offset += word_idx * copy_length
                    if transform_idx < kNumTransforms:
                        length = Transform.transformDictionaryWord(
                            output_buffer, copy_dst, offset, copy_length,
                            transform_idx)
                        copy_dst += length
                        pos += length
                        meta_block_remaining_len -= length
                    else:
                        raise Exception(
                            ("Invalid backward reference. pos: %s " +
                             "distance: %s len: %s bytes left: %s") % (
                             pos, distance, copy_length,
                             meta_block_remaining_len))
                else:
                    raise Exception(("Invalid backward reference. pos: %s " +
                                     "distance: %s len: %s bytes left: %s") % (
                                     pos, distance, copy_length,
                                     meta_block_remaining_len))
            else:
                # Distance code 0 repeats the last distance, so only other
                # codes are recorded in the ring buffer.
                if distance_code > 0:
                    dist_rb[dist_rb_idx & 3] = distance
                    dist_rb_idx += 1

                if copy_length > meta_block_remaining_len:
                    raise Exception(("Invalid backward reference. pos: %s " +
                                     "distance: %s len: %s bytes left: %s") % (
                        pos, distance, copy_length, meta_block_remaining_len))

                # don't try to optimize with a slice. source and dest may
                # overlap
                for j in range(0, copy_length):
                    output_buffer[pos] = output_buffer[pos - distance]
                    pos += 1
                    meta_block_remaining_len -= 1
    return output_buffer