bitcoin/bitcoin

View on GitHub
contrib/seeds/asmap.py

Summary

Maintainability
F
5 days
Test Coverage
# Copyright (c) 2022 Pieter Wuille
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.

"""
This module provides the ASNEntry and ASMap classes.
"""

import copy
import ipaddress
import random
import unittest
from collections.abc import Callable, Iterable
from enum import Enum
from functools import total_ordering
from typing import Optional, Union, overload

def net_to_prefix(net: Union[ipaddress.IPv4Network,ipaddress.IPv6Network]) -> list[bool]:
    """
    Convert an IPv4 or IPv6 network to a prefix represented as a list of bits.

    IPv4 ranges are remapped to their IPv4-mapped IPv6 range (::ffff:0:0/96).
    """
    num_bits = net.prefixlen
    netrange = int.from_bytes(net.network_address.packed, 'big')

    # Map an IPv4 prefix into IPv6 space.
    if isinstance(net, ipaddress.IPv4Network):
        num_bits += 96
        netrange += 0xffff00000000

    # Strip unused bottom bits.
    assert (netrange & ((1 << (128 - num_bits)) - 1)) == 0
    return [((netrange >> (127 - i)) & 1) != 0 for i in range(num_bits)]

def prefix_to_net(prefix: list[bool]) -> Union[ipaddress.IPv4Network,ipaddress.IPv6Network]:
    """The reverse operation of net_to_prefix."""
    # Convert to number
    netrange = sum(b << (127 - i) for i, b in enumerate(prefix))
    num_bits = len(prefix)
    assert num_bits <= 128

    # Return IPv4 range if in ::ffff:0:0/96
    if num_bits >= 96 and (netrange >> 32) == 0xffff:
        return ipaddress.IPv4Network((netrange & 0xffffffff, num_bits - 96), True)

    # Return IPv6 range otherwise.
    return ipaddress.IPv6Network((netrange, num_bits), True)

# Shortcut for (prefix, ASN) entries.
ASNEntry = tuple[list[bool], int]

# Shortcut for (prefix, old ASN, new ASN) entries.
ASNDiff = tuple[list[bool], int, int]

class _VarLenCoder:
    """
    A class representing a custom variable-length binary encoder/decoder for
    integers. Each object represents a different coder, with different parameters
    minval and clsbits.

    The encoding is easiest to describe using an example. Let's say minval=100 and
    clsbits=[4,2,2,3]. In that case:
    - x in [100..115]: encoded as [0] + [4-bit BE encoding of (x-100)].
    - x in [116..119]: encoded as [1,0] + [2-bit BE encoding of (x-116)].
    - x in [120..123]: encoded as [1,1,0] + [2-bit BE encoding of (x-120)].
    - x in [124..131]: encoded as [1,1,1] + [3-bit BE encoding of (x-124)].

    In general, every number is encoded as:
    - First, k "1"-bits, where k is the class the number falls in (there is one class
      per element of clsbits).
    - Then, a "0"-bit, unless k is the highest class, in which case there is nothing.
    - Lastly, clsbits[k] bits encoding in big endian the position in its class that
      number falls into.
    - Every class k consists of 2^clsbits[k] consecutive integers. k=0 starts at minval,
      other classes start one past the last element of the class before it.
    """

    def __init__(self, minval: int, clsbits: list[int]):
        """Construct a new _VarLenCoder."""
        self._minval = minval
        self._clsbits = clsbits
        self._maxval = minval + sum(1 << b for b in clsbits) - 1

    def can_encode(self, val: int) -> bool:
        """Check whether value val is in the range this coder supports."""
        return self._minval <= val <= self._maxval

    def encode(self, val: int, ret: list[int]) -> None:
        """Append encoding of val onto integer list ret."""

        assert self._minval <= val <= self._maxval
        val -= self._minval
        bits = 0
        for k, bits in enumerate(self._clsbits):
            if val >> bits:
                # If the value will not fit in class k, subtract its range from v,
                # emit a "1" bit and continue with the next class.
                val -= 1 << bits
                ret.append(1)
            else:
                if k + 1 < len(self._clsbits):
                    # Unless we're in the last class, emit a "0" bit.
                    ret.append(0)
                break
        # And then encode v (now the position within the class) in big endian.
        ret.extend((val >> (bits - 1 - b)) & 1 for b in range(bits))

    def encode_size(self, val: int) -> int:
        """Compute how many bits are needed to encode val."""
        assert self._minval <= val <= self._maxval
        val -= self._minval
        ret = 0
        bits = 0
        for k, bits in enumerate(self._clsbits):
            if val >> bits:
                val -= 1 << bits
                ret += 1
            else:
                ret += k + 1 < len(self._clsbits)
                break
        return ret + bits

    def decode(self, stream, bitpos) -> tuple[int,int]:
        """Decode a number starting at bitpos in stream, returning value and new bitpos."""
        val = self._minval
        bits = 0
        for k, bits in enumerate(self._clsbits):
            bit = 0
            if k + 1 < len(self._clsbits):
                bit = stream[bitpos]
                bitpos += 1
            if not bit:
                break
            val += 1 << bits
        for i in range(bits):
            bit = stream[bitpos]
            bitpos += 1
            val += bit << (bits - 1 - i)
        return val, bitpos

# Variable-length encoders used in the binary asmap format.
_CODER_INS = _VarLenCoder(0, [0, 0, 1])
_CODER_ASN = _VarLenCoder(1, list(range(15, 25)))
_CODER_MATCH = _VarLenCoder(2, list(range(1, 9)))
_CODER_JUMP = _VarLenCoder(17, list(range(5, 31)))

class _Instruction(Enum):
    """One instruction in the binary asmap format."""
    # A return instruction, encoded as [0], returns a constant ASN. It is followed by
    # an integer using the ASN encoding.
    RETURN = 0
    # A jump instruction, encoded as [1,0] inspects the next unused bit in the input
    # and either continues execution (if 0), or skips a specified number of bits (if 1).
    # It is followed by an integer, and then two subprograms. The integer uses jump encoding
    # and corresponds to the length of the first subprogram (so it can be skipped).
    JUMP = 1
    # A match instruction, encoded as [1,1,0] inspects 1 or more of the next unused bits
    # in the input with its argument. If they all match, execution continues. If they do
    # not, failure is returned. If a default instruction has been executed before, instead
    # of failure the default instruction's argument is returned. It is followed by an
    # integer in match encoding, and a subprogram. That value is at least 2 bits and at
    # most 9 bits. An n-bit value signifies matching (n-1) bits in the input with the lower
    # (n-1) bits in the match value.
    MATCH = 2
    # A default instruction, encoded as [1,1,1] sets the default variable to its argument,
    # and continues execution. It is followed by an integer in ASN encoding, and a subprogram.
    DEFAULT = 3
    # Not an actual instruction, but a way to encode the empty program that fails. In the
    # encoder, it is used more generally to represent the failure case inside MATCH instructions,
    # which may (if used inside the context of a DEFAULT instruction) actually correspond to
    # a successful return. In this usage, they're always converted to an actual MATCH or RETURN
    # before the top level is reached (see make_default below).
    END = 4

class _BinNode:
    """A class representing a (node of) the parsed binary asmap format."""

    @overload
    def __init__(self, ins: _Instruction): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: "_BinNode", arg2: "_BinNode"): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int, arg2: "_BinNode"): ...

    def __init__(self, ins: _Instruction, arg1=None, arg2=None):
        """
        Construct a new asmap node. Possibilities are:
        - _BinNode(_Instruction.RETURN, asn)
        - _BinNode(_Instruction.JUMP, node_0, node_1)
        - _BinNode(_Instruction.MATCH, val, node)
        - _BinNode(_Instruction.DEFAULT, asn, node)
        - _BinNode(_Instruction.END)
        """
        self.ins = ins
        self.arg1 = arg1
        self.arg2 = arg2
        if ins == _Instruction.RETURN:
            assert isinstance(arg1, int)
            assert arg2 is None
            self.size = _CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1)
        elif ins == _Instruction.JUMP:
            assert isinstance(arg1, _BinNode)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_JUMP.encode_size(arg1.size) +
                         arg1.size + arg2.size)
        elif ins == _Instruction.DEFAULT:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = _CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1) + arg2.size
        elif ins == _Instruction.MATCH:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_MATCH.encode_size(arg1)
                         + arg2.size)
        elif ins == _Instruction.END:
            assert arg1 is None
            assert arg2 is None
            self.size = 0
        else:
            assert False

    @staticmethod
    def make_end() -> "_BinNode":
        """Constructor for a _BinNode with just an END instruction."""
        return _BinNode(_Instruction.END)

    @staticmethod
    def make_leaf(val: int) -> "_BinNode":
        """Constructor for a _BinNode of just a RETURN instruction."""
        assert val is not None and val > 0
        return _BinNode(_Instruction.RETURN, val)

    @staticmethod
    def make_branch(node0: "_BinNode", node1: "_BinNode") -> "_BinNode":
        """
        Construct a _BinNode corresponding to running either the node0 or node1 subprogram,
        based on the next input bit. It exploits shortcuts that are possible in the encoding,
        and uses either a JUMP, MATCH, or END instruction.
        """
        if node0.ins == _Instruction.END and node1.ins == _Instruction.END:
            return node0
        if node0.ins == _Instruction.END:
            if node1.ins == _Instruction.MATCH and node1.arg1 <= 0xFF:
                return _BinNode(node1.ins, node1.arg1 + (1 << node1.arg1.bit_length()), node1.arg2)
            return _BinNode(_Instruction.MATCH, 3, node1)
        if node1.ins == _Instruction.END:
            if node0.ins == _Instruction.MATCH and node0.arg1 <= 0xFF:
                return _BinNode(node0.ins, node0.arg1 + (1 << (node0.arg1.bit_length() - 1)),
                                node0.arg2)
            return _BinNode(_Instruction.MATCH, 2, node0)
        return _BinNode(_Instruction.JUMP, node0, node1)

    @staticmethod
    def make_default(val: int, sub: "_BinNode") -> "_BinNode":
        """
        Construct a _BinNode that corresponds to the specified subprogram, with the specified
        default value. It exploits shortcuts that are possible in the encoding, and will use
        either a DEFAULT or a RETURN instruction."""
        assert val is not None and val > 0
        if sub.ins == _Instruction.END:
            return _BinNode(_Instruction.RETURN, val)
        if sub.ins in (_Instruction.RETURN, _Instruction.DEFAULT):
            return sub
        return _BinNode(_Instruction.DEFAULT, val, sub)

@total_ordering
class ASMap:
    """
    A class whose objects represent a mapping from subnets to ASNs.

    Internally the mapping is stored as a binary trie, but can be converted
    from/to a list of ASNEntry objects, and from/to the binary asmap file format.

    In the trie representation, nodes are represented as bare lists for efficiency
    and ease of manipulation:
    - [0] means an unassigned subnet (no ASN mapping for it is present)
    - [int] means a subnet mapped entirely to the specified ASN.
    - [node,node] means a subnet whose lower half and upper half have different
    -             mappings, represented by new trie nodes.
    """

    def update(self, prefix: list[bool], asn: int) -> None:
        """Update this ASMap object to map prefix to the specified asn."""
        assert asn == 0 or _CODER_ASN.can_encode(asn)

        def recurse(node: list, offset: int) -> None:
            if offset == len(prefix):
                # Reached the end of prefix; overwrite this node.
                node.clear()
                node.append(asn)
                return
            if len(node) == 1:
                # Need to descend into a leaf node; split it up.
                oldasn = node[0]
                node.clear()
                node.append([oldasn])
                node.append([oldasn])
            # Descend into the node.
            recurse(node[prefix[offset]], offset + 1)
            # If the result is two identical leaf children, merge them.
            if len(node[0]) == 1 and len(node[1]) == 1 and node[0] == node[1]:
                oldasn = node[0][0]
                node.clear()
                node.append(oldasn)
        recurse(self._trie, 0)

    def update_multi(self, entries: list[tuple[list[bool], int]]) -> None:
        """Apply multiple update operations, where longer prefixes take precedence."""
        entries.sort(key=lambda entry: len(entry[0]))
        for prefix, asn in entries:
            self.update(prefix, asn)

    def _set_trie(self, trie) -> None:
        """Set trie directly. Internal use only."""
        def recurse(node: list) -> None:
            if len(node) < 2:
                return
            recurse(node[0])
            recurse(node[1])
            if len(node[0]) == 2:
                return
            if node[0] == node[1]:
                if len(node[0]) == 0:
                    node.clear()
                else:
                    asn = node[0][0]
                    node.clear()
                    node.append(asn)
        recurse(trie)
        self._trie = trie

    def __init__(self, entries: Optional[Iterable[ASNEntry]] = None) -> None:
        """Construct an ASMap object from an optional list of entries."""
        self._trie = [0]
        if entries is not None:
            def entry_key(entry):
                """Sort function that places shorter prefixes first."""
                prefix, asn = entry
                return len(prefix), prefix, asn
            for prefix, asn in sorted(entries, key=entry_key):
                self.update(prefix, asn)

    def lookup(self, prefix: list[bool]) -> Optional[int]:
        """Look up a prefix. Returns ASN, or 0 if unassigned, or None if indeterminate."""
        node = self._trie
        for bit in prefix:
            if len(node) == 1:
                break
            node = node[bit]
        if len(node) == 1:
            return node[0]
        return None

    def _to_entries_flat(self, fill: bool = False) -> list[ASNEntry]:
        """Convert an ASMap object to a list of non-overlapping (prefix, asn) objects."""
        prefix : list[bool] = []

        def recurse(node: list) -> list[ASNEntry]:
            ret = []
            if len(node) == 1:
                if node[0] > 0:
                    ret = [(list(prefix), node[0])]
            elif len(node) == 2:
                prefix.append(False)
                ret = recurse(node[0])
                prefix[-1] = True
                ret += recurse(node[1])
                prefix.pop()
                if fill and len(ret) > 1:
                    asns = set(x[1] for x in ret)
                    if len(asns) == 1:
                        ret = [(list(prefix), list(asns)[0])]
            return ret
        return recurse(self._trie)

    def _to_entries_minimal(self, fill: bool = False) -> list[ASNEntry]:
        """Convert a trie to a minimal list of ASNEntry objects, exploiting overlap."""
        prefix : list[bool] = []

        def recurse(node: list) -> (tuple[dict[Optional[int], list[ASNEntry]], bool]):
            if len(node) == 1 and node[0] == 0:
                return {None if fill else 0: []}, True
            if len(node) == 1:
                return {node[0]: [], None: [(list(prefix), node[0])]}, False
            ret: dict[Optional[int], list[ASNEntry]] = {}
            prefix.append(False)
            left, lhole = recurse(node[0])
            prefix[-1] = True
            right, rhole = recurse(node[1])
            prefix.pop()
            hole = not fill and (lhole or rhole)
            def candidate(ctx: Optional[int], res0: Optional[list[ASNEntry]],
                    res1: Optional[list[ASNEntry]]):
                if res0 is not None and res1 is not None:
                    if ctx not in ret or len(res0) + len(res1) < len(ret[ctx]):
                        ret[ctx] = res0 + res1
            for ctx in set(left) | set(right):
                candidate(ctx, left.get(ctx), right.get(ctx))
                candidate(ctx, left.get(None), right.get(ctx))
                candidate(ctx, left.get(ctx), right.get(None))
            if not hole:
                for ctx in list(ret):
                    if ctx is not None:
                        candidate(None, [(list(prefix), ctx)], ret[ctx])
            if None in ret:
                ret = {ctx:entries for ctx, entries in ret.items()
                       if ctx is None or len(entries) < len(ret[None])}
            if hole:
                ret = {ctx:entries for ctx, entries in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    def __str__(self) -> str:
        """Convert this ASMap object to a string containing Python code constructing it."""
        return f"ASMap({self._trie})"

    def to_entries(self, overlapping: bool = True, fill: bool = False) -> list[ASNEntry]:
        """
        Convert the mappings in this ASMap object to a list of ASNEntry objects.

        Arguments:
            overlapping: Permit the subnets in the resulting ASNEntry to overlap.
                         Setting this can result in a shorter list.
            fill:        Permit the resulting ASNEntry objects to cover subnets that
                         are unassigned in this ASMap object. Setting this can
                         result in a shorter list.
        """
        if overlapping:
            return self._to_entries_minimal(fill)
        return self._to_entries_flat(fill)

    @staticmethod
    def from_random(num_leaves: int = 10, max_asn: int = 6,
                    unassigned_prob: float = 0.5) -> "ASMap":
        """
        Construct a random ASMap object, with specified:
         - Number of leaves in its trie (at least 1)
         - Maximum ASN value (at least 1)
         - Probability for leaf nodes to be unassigned

        The number of leaves in the resulting object may be less than what is
        requested. This method is mostly intended for testing.
        """
        assert num_leaves >= 1
        assert max_asn >= 1 or unassigned_prob == 1
        assert _CODER_ASN.can_encode(max_asn)
        assert 0.0 <= unassigned_prob <= 1.0
        trie: list = []
        leaves = [trie]
        ret = ASMap()
        for i in range(1, num_leaves):
            idx = random.randrange(i)
            leaf = leaves[idx]
            lastleaf = leaves.pop()
            if idx + 1 < i:
                leaves[idx] = lastleaf
            leaf.append([])
            leaf.append([])
            leaves.append(leaf[0])
            leaves.append(leaf[1])
        for leaf in leaves:
            if random.random() >= unassigned_prob:
                leaf.append(random.randrange(1, max_asn + 1))
            else:
                leaf.append(0)
        #pylint: disable=protected-access
        ret._set_trie(trie)
        return ret

    def _to_binnode(self, fill: bool = False) -> _BinNode:
        """Convert a trie to a _BinNode object."""
        def recurse(node: list) -> tuple[dict[Optional[int], _BinNode], bool]:
            if len(node) == 1 and node[0] == 0:
                return {(None if fill else 0): _BinNode.make_end()}, True
            if len(node) == 1:
                return {None: _BinNode.make_leaf(node[0]), node[0]: _BinNode.make_end()}, False
            ret: dict[Optional[int], _BinNode] = {}
            left, lhole = recurse(node[0])
            right, rhole = recurse(node[1])
            hole = (lhole or rhole) and not fill

            def candidate(ctx: Optional[int], arg1, arg2, func: Callable):
                if arg1 is not None and arg2 is not None:
                    cand = func(arg1, arg2)
                    if ctx not in ret or cand.size < ret[ctx].size:
                        ret[ctx] = cand

            for ctx in set(left) | set(right):
                candidate(ctx, left.get(ctx), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(None), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(ctx), right.get(None), _BinNode.make_branch)
            if not hole:
                for ctx in set(ret) - set([None]):
                    candidate(None, ctx, ret[ctx], _BinNode.make_default)
            if None in ret:
                ret = {ctx:enc for ctx, enc in ret.items()
                       if ctx is None or enc.size < ret[None].size}
            if hole:
                ret = {ctx:enc for ctx, enc in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    @staticmethod
    def _from_binnode(binnode: _BinNode) -> "ASMap":
        """Construct an ASMap object from a _BinNode. Internal use only."""
        def recurse(node: _BinNode, default: int) -> list:
            if node.ins == _Instruction.RETURN:
                return [node.arg1]
            if node.ins == _Instruction.JUMP:
                return [recurse(node.arg1, default), recurse(node.arg2, default)]
            if node.ins == _Instruction.MATCH:
                val = node.arg1
                sub = recurse(node.arg2, default)
                while val >= 2:
                    bit = val & 1
                    val >>= 1
                    if bit:
                        sub = [[default], sub]
                    else:
                        sub = [sub, [default]]
                return sub
            assert node.ins == _Instruction.DEFAULT
            return recurse(node.arg2, node.arg1)
        ret = ASMap()
        if binnode.ins != _Instruction.END:
            #pylint: disable=protected-access
            ret._set_trie(recurse(binnode, 0))
        return ret

    def to_binary(self, fill: bool = False) -> bytes:
        """
        Convert this ASMap object to binary.

        Argument:
            fill: permit the resulting binary encoder to contain mappers for
                  unassigned subnets in this ASMap object. Doing so may
                  reduce the size of the encoding.
        Returns:
            A bytes object with the encoding of this ASMap object.
        """
        bits: list[int] = []

        def recurse(node: _BinNode) -> None:
            _CODER_INS.encode(node.ins.value, bits)
            if node.ins == _Instruction.RETURN:
                _CODER_ASN.encode(node.arg1, bits)
            elif node.ins == _Instruction.JUMP:
                _CODER_JUMP.encode(node.arg1.size, bits)
                recurse(node.arg1)
                recurse(node.arg2)
            elif node.ins == _Instruction.DEFAULT:
                _CODER_ASN.encode(node.arg1, bits)
                recurse(node.arg2)
            else:
                assert node.ins == _Instruction.MATCH
                _CODER_MATCH.encode(node.arg1, bits)
                recurse(node.arg2)

        binnode = self._to_binnode(fill)
        if binnode.ins != _Instruction.END:
            recurse(binnode)

        val = 0
        nbits = 0
        ret = []
        for bit in bits:
            val += (bit << nbits)
            nbits += 1
            if nbits == 8:
                ret.append(val)
                val = 0
                nbits = 0
        if nbits:
            ret.append(val)
        return bytes(ret)

    @staticmethod
    def from_binary(bindata: bytes) -> Optional["ASMap"]:
        """Decode an ASMap object from the provided binary encoding."""

        bits: list[int] = []
        for byte in bindata:
            bits.extend((byte >> i) & 1 for i in range(8))

        def recurse(bitpos: int) -> tuple[_BinNode, int]:
            insval, bitpos = _CODER_INS.decode(bits, bitpos)
            ins = _Instruction(insval)
            if ins == _Instruction.RETURN:
                asn, bitpos = _CODER_ASN.decode(bits, bitpos)
                return _BinNode(ins, asn), bitpos
            if ins == _Instruction.JUMP:
                jump, bitpos = _CODER_JUMP.decode(bits, bitpos)
                left, bitpos1 = recurse(bitpos)
                if bitpos1 != bitpos + jump:
                    raise ValueError("Inconsistent jump")
                right, bitpos = recurse(bitpos1)
                return _BinNode(ins, left, right), bitpos
            if ins == _Instruction.MATCH:
                match, bitpos = _CODER_MATCH.decode(bits, bitpos)
                sub, bitpos = recurse(bitpos)
                return _BinNode(ins, match, sub), bitpos
            assert ins == _Instruction.DEFAULT
            asn, bitpos = _CODER_ASN.decode(bits, bitpos)
            sub, bitpos = recurse(bitpos)
            return _BinNode(ins, asn, sub), bitpos

        if len(bits) == 0:
            binnode = _BinNode(_Instruction.END)
        else:
            try:
                binnode, bitpos = recurse(0)
            except (ValueError, IndexError):
                return None
            if bitpos < len(bits) - 7:
                return None
            if not all(bit == 0 for bit in bits[bitpos:]):
                return None

        return ASMap._from_binnode(binnode)

    def __lt__(self, other: "ASMap") -> bool:
        return self._trie < other._trie

    def __eq__(self, other: object) -> bool:
        if isinstance(other, ASMap):
            return self._trie == other._trie
        return False

    def extends(self, req: "ASMap") -> bool:
        """Determine whether this matches req for all subranges where req is assigned."""
        def recurse(actual: list, require: list) -> bool:
            if len(require) == 1 and require[0] == 0:
                return True
            if len(require) == 1:
                if len(actual) == 1:
                    return bool(require[0] == actual[0])
                return recurse(actual[0], require) and recurse(actual[1], require)
            if len(actual) == 2:
                return recurse(actual[0], require[0]) and recurse(actual[1], require[1])
            return recurse(actual, require[0]) and recurse(actual, require[1])
        assert isinstance(req, ASMap)
        #pylint: disable=protected-access
        return recurse(self._trie, req._trie)

    def diff(self, other: "ASMap") -> list[ASNDiff]:
        """Compute the diff from self to other."""
        prefix: list[bool] = []
        ret: list[ASNDiff] = []

        def recurse(old_node: list, new_node: list):
            if len(old_node) == 1 and len(new_node) == 1:
                if old_node[0] != new_node[0]:
                    ret.append((list(prefix), old_node[0], new_node[0]))
            else:
                old_left: list = old_node if len(old_node) == 1 else old_node[0]
                old_right: list = old_node if len(old_node) == 1 else old_node[1]
                new_left: list = new_node if len(new_node) == 1 else new_node[0]
                new_right: list = new_node if len(new_node) == 1 else new_node[1]
                prefix.append(False)
                recurse(old_left, new_left)
                prefix[-1] = True
                recurse(old_right, new_right)
                prefix.pop()
        assert isinstance(other, ASMap)
        #pylint: disable=protected-access
        recurse(self._trie, other._trie)
        return ret

    def __copy__(self) -> "ASMap":
        """Construct a copy of this ASMap object. Its state will not be shared."""
        ret = ASMap()
        #pylint: disable=protected-access
        ret._set_trie(copy.deepcopy(self._trie))
        return ret

    def __deepcopy__(self, _) -> "ASMap":
        # ASMap objects do not allow sharing of the _trie member, so we don't need the memoization.
        return self.__copy__()


class TestASMap(unittest.TestCase):
    """Unit tests for this module."""

    def test_ipv6_prefix_roundtrips(self) -> None:
        """Test that random IPv6 network ranges roundtrip through prefix encoding."""
        for _ in range(20):
            net_bits = random.getrandbits(128)
            for prefix_len in range(0, 129):
                masked_bits = (net_bits >> (128 - prefix_len)) << (128 - prefix_len)
                net = ipaddress.IPv6Network((masked_bits.to_bytes(16, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                self.assertTrue(len(prefix) <= 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_ipv4_prefix_roundtrips(self) -> None:
        """Test that random IPv4 network ranges roundtrip through prefix encoding."""
        for _ in range(100):
            net_bits = random.getrandbits(32)
            for prefix_len in range(0, 33):
                masked_bits = (net_bits >> (32 - prefix_len)) << (32 - prefix_len)
                net = ipaddress.IPv4Network((masked_bits.to_bytes(4, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                self.assertTrue(32 <= len(prefix) <= 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_asmap_roundtrips(self) -> None:
        """Test case that verifies random ASMap objects roundtrip to/from entries/binary."""
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 24):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Run tests for to_entries and construction from those entries, both
                    # for overlapping and non-overlapping ones.
                    for overlapping in [False, True]:
                        entries = asmap.to_entries(overlapping=overlapping, fill=False)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertEqual(asmap2, asmap)
                        entries = asmap.to_entries(overlapping=overlapping, fill=True)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertTrue(asmap2.extends(asmap))

                    # Run tests for to_binary and construction from binary.
                    enc = asmap.to_binary(fill=False)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertEqual(asmap3, asmap)
                    enc = asmap.to_binary(fill=True)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertTrue(asmap3.extends(asmap))

    def test_patching(self) -> None:
        """Test behavior of update, lookup, extends, and diff."""
        #pylint: disable=too-many-locals,too-many-nested-blocks
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 10):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(0, 101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Make a copy of that asmap object to which patches will be applied.
                    # It starts off being equal to asmap.
                    patched = copy.copy(asmap)
                    # Keep a list of patches performed.
                    patches: list[ASNEntry] = []
                    # Initially there cannot be any difference.
                    self.assertEqual(asmap.diff(patched), [])
                    # Make 5 patches, each building on top of the previous ones.
                    for _ in range(0, 5):
                        # Construct a random path and new ASN to assign it to, apply it to patched,
                        # and remember it in patches.
                        pathlen = random.randrange(5)
                        path = [random.getrandbits(1) != 0 for _ in range(pathlen)]
                        newasn = random.randrange(1 + (1 << asnbits))
                        patched.update(path, newasn)
                        patches = [(path, newasn)] + patches

                        # Compute the diff, and whether asmap extends patched, and the other way
                        # around.
                        diff = asmap.diff(patched)
                        self.assertEqual(asmap == patched, len(diff) == 0)
                        extends = asmap.extends(patched)
                        back_extends = patched.extends(asmap)
                        # Determine whether those extends results are consistent with the diff
                        # result.
                        self.assertEqual(extends, all(d[2] == 0 for d in diff))
                        self.assertEqual(back_extends, all(d[1] == 0 for d in diff))
                        # For every diff found:
                        for path, old_asn, new_asn in diff:
                            # Verify asmap and patched actually differ there.
                            self.assertTrue(old_asn != new_asn)
                            self.assertEqual(asmap.lookup(path), old_asn)
                            self.assertEqual(patched.lookup(path), new_asn)
                            for _ in range(2):
                                # Extend the path far enough that it's smaller than any mapped
                                # range, and check the lookup holds there too.
                                spec_path = list(path)
                                while len(spec_path) < 32:
                                    spec_path.append(random.getrandbits(1) != 0)
                                self.assertEqual(asmap.lookup(spec_path), old_asn)
                                self.assertEqual(patched.lookup(spec_path), new_asn)
                                # Search through the list of performed patches to find the last one
                                # applying to the extended path (note that patches is in reverse
                                # order, so the first match should work).
                                found = False
                                for patch_path, patch_asn in patches:
                                    if spec_path[:len(patch_path)] == patch_path:
                                        # When found, it must match whatever the result was patched
                                        # to.
                                        self.assertEqual(new_asn, patch_asn)
                                        found = True
                                        break
                                # And such a patch must exist.
                                self.assertTrue(found)

if __name__ == '__main__':
    unittest.main()