cea-sec/miasm

View on GitHub
miasm/core/cpu.py

Summary

Maintainability
F
2 wks
Test Coverage
#-*- coding:utf-8 -*-

from builtins import range
import re
import struct
import logging
from collections import defaultdict


from future.utils import viewitems, viewvalues

import pyparsing

from miasm.core.utils import decode_hex
import miasm.expression.expression as m2_expr
from miasm.core.bin_stream import bin_stream, bin_stream_str
from miasm.core.utils import Disasm_Exception
from miasm.expression.simplifications import expr_simp


from miasm.core.asm_ast import AstNode, AstInt, AstId, AstOp
from miasm.core import utils
from future.utils import with_metaclass

# Module logger ("cpuhelper"): messages are written through a dedicated
# stream handler using a "[LEVEL   ]: message" layout, and the logger level is
# set to WARN.
log = logging.getLogger("cpuhelper")
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter("[%(levelname)-8s]: %(message)s"))
log.addHandler(console_handler)
log.setLevel(logging.WARN)


class bitobj(object):
    """Bit-level buffer: a list of 0/1 integers plus a read cursor.

    ``bits`` holds the bits most significant first; ``offset`` is the index
    of the next bit returned by :meth:`getbits`.
    """

    def __init__(self, s=b""):
        """Initialise the buffer from the byte string *s*.

        Every input byte contributes exactly 8 bits, so leading zero bytes
        are preserved.
        """
        bits = []
        if s:
            # bytearray() yields integers on both Python 2 and Python 3
            for byte in bytearray(s):
                bits.extend((byte >> shift) & 1 for shift in range(7, -1, -1))
        self.bits = bits
        self.offset = 0

    def __len__(self):
        """Return the number of bits not consumed yet."""
        return len(self.bits) - self.offset

    def getbits(self, n):
        """Consume *n* bits and return them as an unsigned integer.

        Returns 0 without moving the cursor when *n* is 0.
        Raises ValueError when fewer than *n* bits remain.
        """
        if not n:
            return 0
        if n > len(self.bits) - self.offset:
            raise ValueError('not enough bits %r %r' % (n, len(self.bits)))
        chunk = self.bits[self.offset:self.offset + n]
        value = int("".join(str(x) for x in chunk), 2)
        self.offset += n
        return value

    def putbits(self, b, n):
        """Append the integer *b* to the buffer, left-padded with zeros to *n* bits.

        Nothing is appended when *n* is 0. The value is not truncated: if *b*
        needs more than *n* bits, all of its bits are appended.
        """
        if not n:
            return
        bits = [int(x) for x in bin(b)[2:]]
        bits = [0 for x in range(n - len(bits))] + bits
        self.bits += bits

    def tostring(self):
        """Return the whole buffer (ignoring the cursor) as a byte string.

        Raises ValueError when the number of bits is not a multiple of 8.
        An empty buffer yields an empty byte string.
        """
        if len(self.bits) % 8:
            raise ValueError(
                'num bits must be 8 bit aligned: %d' % len(self.bits)
            )
        out = bytearray()
        for start in range(0, len(self.bits), 8):
            byte = 0
            for bit in self.bits[start:start + 8]:
                byte = (byte << 1) | bit
            out.append(byte)
        return bytes(out)

    def reset(self):
        """Move the read cursor back to the first bit."""
        self.offset = 0

    def copy_state(self):
        """Return a new object with the same cursor.

        The bit list itself is shared with the original, not copied.
        """
        b = self.__class__()
        b.bits = self.bits
        b.offset = self.offset
        return b


def literal_list(l):
    """Build a pyparsing alternation of Literal parsers for the strings in *l*.

    The alternatives are combined in reverse sorted order; the input list is
    left untouched.
    """
    names = sorted(l)
    names.reverse()
    parser = pyparsing.Literal(names[0])
    for name in names[1:]:
        parser |= pyparsing.Literal(name)
    return parser


class reg_info(object):
    """Associate register names with expressions through two parallel lists.

    ``str[i]`` is the name of the register whose expression is ``expr[i]``.
    """

    def __init__(self, reg_str, reg_expr):
        self.str = reg_str
        self.expr = reg_expr
        # Parser accepting any of the register names
        self.parser = literal_list(reg_str).setParseAction(self.cb_parse)

    def cb_parse(self, tokens):
        """Parse action: wrap the expression of the matched name in an AstId."""
        assert len(tokens) == 1
        position = self.str.index(tokens[0])
        return AstId(self.expr[position])

    def reg2expr(self, s):
        """Return the expression whose register name is ``s[0]``."""
        return self.expr[self.str.index(s[0])]

    def expr2regi(self, e):
        """Return the index of expression *e* in the expression list."""
        return self.expr.index(e)


class reg_info_dct(object):
    """Dictionary-based counterpart of reg_info.

    *reg_expr* maps a register index to an expression exposing a ``name``
    attribute.
    """

    def __init__(self, reg_expr):
        # name -> index
        self.dct_str_inv = {
            expr.name: index for index, expr in viewitems(reg_expr)
        }
        # index -> expression
        self.dct_expr = reg_expr
        # expression -> index
        self.dct_expr_inv = {
            expr: index for index, expr in viewitems(reg_expr)
        }
        names = [expr.name for expr in viewvalues(reg_expr)]
        self.parser = literal_list(names).setParseAction(self.cb_parse)

    def cb_parse(self, tokens):
        """Parse action: wrap the expression of the matched name in an AstId."""
        assert len(tokens) == 1
        index = self.dct_str_inv[tokens[0]]
        return AstId(self.dct_expr[index])

    def reg2expr(self, s):
        """Return the expression whose register name is ``s[0]``."""
        return self.dct_expr[self.dct_str_inv[s[0]]]

    def expr2regi(self, e):
        """Return the index under which expression *e* is registered."""
        return self.dct_expr_inv[e]


def gen_reg(reg_name, sz=32):
    """Create the ExprId of size *sz* for *reg_name* and a reg_info for it.

    Returns the pair (expression, reg_info).
    """
    expr = m2_expr.ExprId(reg_name, sz)
    return expr, reg_info([reg_name], [expr])


def gen_reg_bs(reg_name, reg_info, base_cls):
    """
    Generate:
        class bs_reg_name(base_cls):
            reg = reg_info

        bs_reg_name = bs(l=0, cls=(bs_reg_name,))

    Returns the pair (generated class, bs object).
    """
    new_cls = type("bs_%s" % reg_name, base_cls, {'reg': reg_info})
    return new_cls, bs(l=0, cls=(new_cls,))


def gen_regs(rnames, env, sz=32):
    """Create ExprId registers of size *sz* for every name in *rnames*.

    Each register is also stored in *env* under its name. Returns the tuple
    (register expressions, "<name>_init" expressions, reg_info).
    """
    regs_str = []
    regs_expr = []
    regs_init = []
    for rname in rnames:
        reg = m2_expr.ExprId(rname, sz)
        reg_init = m2_expr.ExprId('%s_init' % rname if False else rname+'_init', sz)
        regs_str += [rname]
        regs_expr += [reg]
        regs_init += [reg_init]
        env[rname] = reg

    return regs_expr, regs_init, reg_info(regs_str, regs_expr)


# Literal parsers for the opening and closing parenthesis characters
LPARENTHESIS = pyparsing.Literal("(")
RPARENTHESIS = pyparsing.Literal(")")


def int2expr(tokens):
    """Parse action: pair the first token with the ExprInt class."""
    return m2_expr.ExprInt, tokens[0]


def parse_op(tokens):
    """Parse action: pair the first token with the ExprOp class."""
    return m2_expr.ExprOp, tokens[0]


def parse_id(tokens):
    """Parse action: pair the first token with the ExprId class."""
    return m2_expr.ExprId, tokens[0]


def ast_parse_op(tokens):
    """Fold a flat ``operand (operator operand)*`` token list into ExprOp nodes.

    * one token: returned unchanged;
    * two tokens starting with '-', '+' or '!': unary ExprOp;
    * three tokens: binary ExprOp, with ``a - b`` rewritten as ``a + (-b)``
      (the token list is updated in place in that case);
    * more tokens: folded from left to right, applying the same rewrite.

    Raises NotImplementedError when the tokens do not reduce to one value.
    """
    count = len(tokens)
    if count == 1:
        return tokens[0]
    if count == 2 and tokens[0] in ['-', '+', '!']:
        return m2_expr.ExprOp(tokens[0], tokens[1])
    if count == 3:
        if tokens[1] == '-':
            # a - b => a + (-b)
            tokens[1] = '+'
            tokens[2] = - tokens[2]
        return m2_expr.ExprOp(tokens[1], tokens[0], tokens[2])
    if count == 0:
        raise NotImplementedError('strange op')

    # Left fold over the (operator, operand) pairs following the first operand
    result = tokens[0]
    pos = 1
    while pos + 1 < count:
        op = tokens[pos]
        operand = tokens[pos + 1]
        if op == '-':
            # a - b => a + (-b)
            op = '+'
            operand = - operand
        result = m2_expr.ExprOp(op, result, operand)
        pos += 2
    if pos != count:
        # A dangling token is left over
        raise NotImplementedError('strange op')
    return result


def ast_id2expr(a):
    """Return the 32-bit ExprId built from *a*."""
    size = 32
    return m2_expr.ExprId(a, size)


def ast_int2expr(a):
    """Return the 32-bit ExprInt built from *a*."""
    size = 32
    return m2_expr.ExprInt(a, size)


def neg_int(tokens):
    """Parse action: return the negation of the first token."""
    return -tokens[0]


# Integer literal parsers: decimal digits, or a "0x" prefix followed by
# hexadecimal digits; each parse action converts the matched text to an int.
# NOTE: integer, hex_word, hex_int, str_int_pos, str_int, signop and plusop
# are all assigned again further down in this module.
integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0]))
hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums)
hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16))

# str_int = (Optional('-') + (hex_int | integer))
# Positive integer, or an integer preceded by '-' (the sign is dropped from
# the tokens and the value negated by neg_int)
str_int_pos = (hex_int | integer)
str_int_neg = (pyparsing.Suppress('-') + \
                   (hex_int | integer)).setParseAction(neg_int)

# Either form; the result is tagged by int2expr
str_int = str_int_pos | str_int_neg
str_int.setParseAction(int2expr)

# Operator token parsers, grouped by family
logicop = pyparsing.oneOf('& | ^ >> << <<< >>>')
signop = pyparsing.oneOf('+ -')
multop = pyparsing.oneOf('* / %')
plusop = pyparsing.oneOf('+ -')


##########################

def literal_list(l):
    """Return a pyparsing alternation matching any of the strings in *l*.

    A sorted copy of *l* is walked from its last element to its first, so the
    caller's list is not modified.
    """
    ordered = sorted(l)[::-1]
    alternation = pyparsing.Literal(ordered[0])
    for text in ordered[1:]:
        alternation |= pyparsing.Literal(text)
    return alternation


def cb_int(tokens):
    """Parse action: wrap the single token in an AstInt node."""
    assert len(tokens) == 1
    return AstInt(tokens[0])


def cb_parse_id(tokens):
    """Parse action: wrap the single token in an AstId node."""
    assert len(tokens) == 1
    return AstId(tokens[0])


def cb_op_not(tokens):
    """Parse action for unary '!': build AstOp("!", operand).

    ``tokens[0]`` must be the two-element group ["!", operand].
    """
    group = tokens[0]
    assert len(group) == 2
    assert group[0] == "!"
    return AstOp("!", group[1])


def merge_ops(tokens, op):
    """Merge ``a op b op c ...`` into a single n-ary AstOp(op, a, b, c, ...).

    With fewer than three tokens the AstOp is built without arguments.
    The first operand is removed from *tokens* in place.
    Raises ValueError when an operator other than *op* is found.
    """
    operands = []
    if len(tokens) >= 3:
        operands.append(tokens.pop(0))
        # What remains is a sequence of (operator, operand) pairs
        for pos in range(0, len(tokens), 2):
            found_op = tokens[pos]
            operand = tokens[pos + 1]
            if found_op != op:
                raise ValueError("Bad operator")
            operands.append(operand)
    return AstOp(op, *operands)


def cb_op_and(tokens):
    """Parse action for '&': merge the operands of ``tokens[0]`` in one AstOp."""
    return merge_ops(tokens[0], "&")


def cb_op_xor(tokens):
    """Parse action for '^': merge the operands of ``tokens[0]`` in one AstOp."""
    return merge_ops(tokens[0], "^")


def cb_op_sign(tokens):
    """Parse action for a unary sign.

    ``tokens[0]`` is the pair (operator, operand). The operand is negated for
    '-' and returned unchanged for any other operator (unary '+').
    """
    assert len(tokens) == 1
    op, value = tokens[0]
    if op == '-':
        return -value
    # Unary plus leaves the operand as it is
    return value


def cb_op_div(tokens):
    """Parse action for binary '/': build AstOp("/", left, right).

    ``tokens[0]`` must be the three-element group [left, "/", right].
    """
    group = tokens[0]
    assert len(group) == 3
    assert group[1] == "/"
    left, right = group[0], group[2]
    return AstOp("/", left, right)


def cb_op_plusminus(tokens):
    """Parse action for chains of '+' and '-'.

    * three tokens: binary AstOp using the operator as written;
    * more than three: a single n-ary '+' AstOp in which every operand that
      follows a '-' is negated (the first operand is removed from the token
      group in place).

    Raises ValueError on fewer than three tokens or on an unknown operator.
    """
    group = tokens[0]
    count = len(group)
    if count < 3:
        raise ValueError("Parsing error")

    if count == 3:
        # binary op
        assert isinstance(group[0], AstNode)
        assert isinstance(group[2], AstNode)
        op = group[1]
        args = [group[0], group[2]]
    else:
        args = [group.pop(0)]
        # What remains is a sequence of (sign, term) pairs
        for pos in range(0, len(group), 2):
            sign = group[pos]
            term = group[pos + 1]
            if sign == '-':
                term = -term
            elif sign != '+':
                raise ValueError("Bad operator")
            args.append(term)
        op = '+'

    assert all(isinstance(arg, AstNode) for arg in args)
    return AstOp(op, *args)


def cb_op_mul(tokens):
    """Parse action for a binary operator group [left, op, right].

    Both operands must be AstNode instances; the result is
    AstOp(op, left, right).
    """
    group = tokens[0]
    assert len(group) == 3
    left, op, right = group[0], group[1], group[2]
    assert isinstance(left, AstNode)
    assert isinstance(right, AstNode)

    # binary op
    return AstOp(op, left, right)


# Integer literal parsers (these rebind the names defined earlier in the
# module): decimal digits, or "0x" followed by hexadecimal digits, converted
# to Python ints by their parse actions.
integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda tokens: int(tokens[0]))
hex_word = pyparsing.Literal('0x') + pyparsing.Word(pyparsing.hexnums)
hex_int = pyparsing.Combine(hex_word).setParseAction(lambda tokens: int(tokens[0], 16))

str_int_pos = (hex_int | integer)

# str_int is the same object as str_int_pos, so the cb_int parse action
# (which wraps the value in an AstInt) is attached to both names.
str_int = str_int_pos
str_int.setParseAction(cb_int)

# One token parser per operator family
notop = pyparsing.oneOf('!')
andop = pyparsing.oneOf('&')
orop = pyparsing.oneOf('|')
xorop = pyparsing.oneOf('^')
shiftop = pyparsing.oneOf('>> <<')
rotop = pyparsing.oneOf('<<< >>>')
signop = pyparsing.oneOf('+ -')
mulop = pyparsing.oneOf('*')
plusop = pyparsing.oneOf('+ -')
divop = pyparsing.oneOf('/')


# Identifier: first character is a letter, '_', '$' or '.'; the following
# characters are letters, digits or '_'. The match is wrapped in an AstId.
variable = pyparsing.Word(pyparsing.alphas + "_$.", pyparsing.alphanums + "_")
variable.setParseAction(cb_parse_id)
# An operand is an integer literal or an identifier
operand = str_int | variable

# Expression grammar built on the operands. Each tuple is
# (operator parser, arity, associativity, parse action).
# NOTE: orop, shiftop and rotop are defined above but are not part of this
# operator table.
base_expr = pyparsing.infixNotation(operand,
                               [(notop,   1, pyparsing.opAssoc.RIGHT, cb_op_not),
                                (andop, 2, pyparsing.opAssoc.RIGHT, cb_op_and),
                                (xorop, 2, pyparsing.opAssoc.RIGHT, cb_op_xor),
                                (signop,  1, pyparsing.opAssoc.RIGHT, cb_op_sign),
                                (mulop,  2, pyparsing.opAssoc.RIGHT, cb_op_mul),
                                (divop,  2, pyparsing.opAssoc.RIGHT, cb_op_div),
                                (plusop,  2, pyparsing.opAssoc.LEFT, cb_op_plusminus),
                                ])


# Default value of the `prio` class attribute of field descriptors. metamn
# sorts the fields by (prio, position) before applying the diverting fields,
# so fields with a lower prio (e.g. bs_name, bs_mod_name) are handled first.
default_prio = 0x1337


def isbin(s):
    """Return a match object when *s* is made only of '0'/'1', else None."""
    binary_pattern = r'[0-1]+$'
    return re.match(binary_pattern, s)


def int2bin(i, l):
    """Return the *l* low-order bits of *i* as a string of '0'/'1'.

    The result is always exactly *l* characters long: values that need fewer
    bits are left-padded with '0', wider values are truncated to their low
    bits, and negative values are rendered in two's complement. A width of 0
    (or less) yields an empty string.
    """
    if l <= 0:
        return ''
    mask = (1 << l) - 1
    return format(i & mask, '0%db' % l)


def myror32(v, r):
    """Rotate the 32-bit value *v* right by *r* bits."""
    mask = 0xFFFFFFFF
    shifted_out = (v & mask) >> r
    wrapped = (v << (32 - r)) & mask
    return shifted_out | wrapped


def myrol32(v, r):
    """Rotate the 32-bit value *v* left by *r* bits."""
    mask = 0xFFFFFFFF
    wrapped = (v & mask) >> (32 - r)
    shifted = (v << r) & mask
    return wrapped | shifted


class bs(object):
    """Static description of one bit field of an instruction encoding.

    A field is described by a bit pattern string (*strbits*), a width (*l*),
    optional mixin classes (*cls*), a name (*fname*), an *order* and an
    optional *flen* value. :meth:`gen` creates the runtime object for it.
    """
    # Cache of the classes created by gen(), keyed by (class name, bases);
    # the dictionary lives on the class and is shared by all instances.
    all_new_c = {}
    prio = default_prio

    def __init__(self, strbits=None, l=None, cls=None,
                 fname=None, order=0, flen=None, **kargs):
        if fname is None:
            # Anonymous field: generate a name from the id of a temporary
            # string describing the arguments
            fname = hex(id(str((strbits, l, cls, fname, order, flen, kargs))))
        if strbits is None:
            strbits = ""  # "X"*l
        elif l is None:
            # The width defaults to the length of the pattern
            l = len(strbits)

        # Fixed value: a fully binary pattern, else the 'default_val' keyword
        if strbits and isbin(strbits):
            value = int(strbits, 2)
        elif 'default_val' in kargs:
            value = int(kargs['default_val'], 2)
        else:
            value = None

        # fbits holds the value of the fixed ('0'/'1') bits, fmask marks
        # their positions; spaces in the pattern are skipped.
        fbits = 0
        fmask = 0
        for char in strbits:
            if char == " ":
                continue
            fbits <<= 1
            fmask <<= 1
            if char in '01':
                fbits |= int(char)
                fmask |= 1
        # The result is not used (see the lmask property), but the
        # computation fails early with a TypeError when l is still None.
        lmask = (1 << l) - 1
        # gen conditional field
        if cls:
            for base in cls:
                if 'flen' in base.__dict__:
                    flen = getattr(base, 'flen')

        self.strbits = strbits
        self.l = l
        self.cls = cls
        self.fname = fname
        self.order = order
        self.fbits = fbits
        self.fmask = fmask
        self.flen = flen
        self.value = value
        self.kargs = kargs

    # Mask covering the l bits of the field
    lmask = property(lambda self:(1 << self.l) - 1)

    def __getitem__(self, item):
        """Expose attributes through item access (used for %-formatting)."""
        return getattr(self, item)

    def __repr__(self):
        parts = [self.__class__.__name__]
        if self.fname:
            parts.append("%s" % self.fname)
        parts.append("%(strbits)s" % self)
        if self.cls:
            parts.append('_'.join([x.__name__ for x in self.cls]))
        return '_'.join(parts)

    def gen(self, parent):
        """Instantiate the runtime field object attached to *parent*.

        The object's class derives from the mixin classes of this field
        followed by bsi; it is created once per (name, bases) pair.
        """
        c_name = 'nbsi'
        bases = []
        if self.cls:
            c_name += '_' + '_'.join([x.__name__ for x in self.cls])
            bases = list(self.cls)
        # bsi added at end of list
        # used to use first function of added class
        bases.append(bsi)
        key = c_name, tuple(bases)
        new_c = self.all_new_c.get(key)
        if new_c is None:
            new_c = type(c_name, tuple(bases), {})
            self.all_new_c[key] = new_c
        return new_c(parent,
                     self.strbits, self.l, self.cls,
                     self.fname, self.order, self.lmask, self.fbits,
                     self.fmask, self.value, self.flen, **self.kargs)

    def check_fbits(self, v):
        """Tell whether the fixed bits of *v* match the field pattern."""
        return (v & self.fmask) == self.fbits

    @classmethod
    def flen(cls, v):
        # Placeholder; instances get their own `flen` attribute in __init__
        raise NotImplementedError('not fully functional')


class dum_arg(object):
    """Minimal argument object that only carries an expression."""

    def __init__(self, e=None):
        # Expression held by the argument (None when not provided)
        self.expr = e


class bsopt(bs):
    """bs variant exposing an ispresent() query that always answers True."""

    def ispresent(self):
        """Return True."""
        return True


class bsi(object):
    """Runtime counterpart of a bs field, attached to a parent object.

    The extra keyword arguments are kept in ``kargs`` and also exposed as
    instance attributes.
    """

    def __init__(self, parent, strbits, l, cls, fname, order,
                 lmask, fbits, fmask, value, flen, **kargs):
        # NOTE: the lmask argument is not stored; lmask is a property
        # computed from l.
        self.parent = parent
        self.strbits = strbits
        self.l = l
        self.cls = cls
        self.fname = fname
        self.order = order
        self.fbits = fbits
        self.fmask = fmask
        self.flen = flen
        self.value = value
        self.kargs = kargs
        self.__dict__.update(self.kargs)

    # Mask covering the l bits of the field
    lmask = property(lambda self:(1 << self.l) - 1)

    def decode(self, v):
        """Store the l low-order bits of *v* as the field value; return True."""
        self.value = v & self.lmask
        return True

    def encode(self):
        """Default encoder: nothing to compute, always succeeds."""
        return True

    def clone(self):
        """Return a new object of the same class built from the same fields.

        The ``expr`` attribute is carried over when it is set.
        """
        s = self.__class__(self.parent,
                           self.strbits, self.l, self.cls,
                           self.fname, self.order, self.lmask, self.fbits,
                           self.fmask, self.value, self.flen, **self.kargs)
        s.__dict__.update(self.kargs)
        if hasattr(self, 'expr'):
            s.expr = self.expr
        return s

    def __hash__(self):
        # The extra keyword arguments (kargs) are deliberately left out of
        # the hash.
        return hash((self.strbits, self.l, self.cls,
                     self.fname, self.order, self.lmask, self.fbits,
                     self.fmask, self.value))


class bs_divert(object):
    """Base class of the fields that expand a candidate list (see divert()).

    The keyword arguments given at construction are stored in ``args`` and
    can also be read as attributes.
    """
    prio = default_prio

    def __init__(self, **kargs):
        self.args = kargs

    def __getattr__(self, item):
        """Resolve attributes missing from the instance through ``args``.

        Only called when the regular attribute lookup fails. ``args`` is read
        straight from the instance dictionary so that a missing ``args``
        (object not initialised yet) raises AttributeError instead of
        recursing into this method.
        """
        args = self.__dict__.get('args', {})
        if item in args:
            return args.get(item)
        raise AttributeError(item)


class bs_name(bs_divert):
    """Diverting field producing one candidate per entry of ``args['name']``."""
    prio = 1

    def divert(self, i, candidates):
        """Expand *candidates* over the (name, value) pairs of ``args['name']``.

        In each new candidate, field *i* is replaced by a fixed bs field
        holding the value on ``args['l']`` bits, and both the candidate name
        and its ``dct['name']`` are set to the new name.
        """
        out = []
        width = self.args['l'] if False else None
        for cls, _, bases, dct, fields in candidates:
            for new_name, value in viewitems(self.args['name']):
                patched_fields = fields[:]
                pattern = int2bin(value, self.args['l'])
                patched_fields[i] = bs(**dict(self.args, strbits=pattern))
                patched_dct = dict(dct)
                patched_dct['name'] = new_name
                out.append((cls, new_name, bases, patched_dct, patched_fields))
        return out


class bs_mod_name(bs_divert):
    """Diverting field appending a suffix from ``args['mn_mod']`` to the name."""
    prio = 2

    def divert(self, i, candidates):
        """Expand *candidates* over the entries of ``args['mn_mod']``.

        ``mn_mod`` is either a list (the position is the field value) or a
        mapping from field value to suffix. In each new candidate, field *i*
        is replaced by a fixed bs field holding the value on ``args['l']``
        bits and ``dct['name']`` is updated through modname().
        """
        out = []
        for cls, _, bases, dct, fields in candidates:
            tab = self.args['mn_mod']
            if isinstance(tab, list):
                # position -> suffix
                tab = dict(enumerate(tab))
            for value, new_name in viewitems(tab):
                patched_fields = fields[:]
                pattern = int2bin(value, self.args['l'])
                patched_fields[i] = bs(**dict(self.args, strbits=pattern))
                patched_dct = dict(dct)
                patched_dct['name'] = self.modname(patched_dct['name'], value)
                out.append((cls, new_name, bases, patched_dct, patched_fields))
        return out

    def modname(self, name, i):
        """Return *name* followed by the suffix stored for value *i*."""
        suffix = self.args['mn_mod'][i]
        return name + suffix


class bs_cond(bsi):
    """bsi subclass that adds no behaviour of its own."""


class bs_swapargs(bs_divert):
    """Diverting field selecting between two orderings of the first two args."""

    def divert(self, i, candidates):
        """Emit two candidates for each input candidate.

        The first keeps ``args_permut`` untouched and fixes field *i* to 0;
        the second swaps the first two entries of ``args_permut`` and fixes
        field *i* to 1.
        """
        out = []
        width = self.args['l'] if False else None
        for cls, name, bases, dct, fields in candidates:
            # args not permuted
            plain_dct = dict(dct)
            plain_fields = fields[:]
            # gen fix field
            plain_fields[i] = gen_bsint(0, self.args['l'], self.args)
            out.append((cls, name, bases, plain_dct, plain_fields))

            # args permuted
            swapped_dct = dict(dct)
            swapped_fields = fields[:]
            remaining = swapped_dct['args_permut'][:]
            first = remaining.pop(0)
            second = remaining.pop(0)
            swapped_dct['args_permut'] = [second, first] + remaining
            # gen fix field
            swapped_fields[i] = gen_bsint(1, self.args['l'], self.args)
            out.append((cls, name, bases, swapped_dct, swapped_fields))
        return out


class m_arg(object):
    """Base class of the arguments that can be parsed from assembly text."""

    def fromstring(self, text, loc_db, parser_result=None):
        """Set ``self.expr`` from *text* and return the (start, stop) span.

        When *parser_result* is given, the (expr, start, stop) entry stored
        for ``self.parser`` is used as is. Otherwise the first match of
        ``self.parser`` in *text* is converted with asm_ast_to_expr().
        Returns (None, None) when the parser finds no match.
        """
        if parser_result:
            expr, start, stop = parser_result[self.parser]
            self.expr = expr
            return start, stop

        first_match = next(self.parser.scanString(text), None)
        if first_match is None:
            return None, None
        parsed, start, stop = first_match
        self.expr = self.asm_ast_to_expr(parsed[0], loc_db)
        return start, stop

    def asm_ast_to_expr(self, arg, loc_db, **kwargs):
        """Convert a parsed AST node to an expression (must be overridden)."""
        raise NotImplementedError("Virtual")


class m_reg(m_arg):
    """Argument bound to the first expression of its ``reg`` register info."""
    prio = default_prio

    @property
    def parser(self):
        """Parser of the associated register info."""
        return self.reg.parser

    def decode(self, v):
        """Set the expression to the first register of ``reg``; return True."""
        first_reg = self.reg.expr[0]
        self.expr = first_reg
        return True

    def encode(self):
        """Tell whether the expression is the first register of ``reg``."""
        first_reg = self.reg.expr[0]
        return self.expr == first_reg


class reg_noarg(object):
    """Register operand encoded as an index in ``reg_info.expr``."""
    reg_info = None
    parser = None

    def fromstring(self, text, loc_db, parser_result=None):
        """Set ``self.expr`` from *text* and return the (start, stop) span.

        When *parser_result* is given, the (expr, start, stop) entry stored
        for ``self.parser`` is used as is. Otherwise the first match of
        ``self.parser`` in *text* is converted with parses_to_expr().
        Returns (None, None) when the parser finds no match.
        """
        if parser_result:
            expr, start, stop = parser_result[self.parser]
            self.expr = expr
            return start, stop

        first_match = next(self.parser.scanString(text), None)
        if first_match is None:
            return None, None
        parsed, start, stop = first_match
        self.expr = self.parses_to_expr(parsed[0], loc_db)
        return start, stop

    def decode(self, v):
        """Select the register whose index is *v* masked with ``lmask``.

        Returns False when the index is past the end of the register list.
        """
        index = v & self.lmask
        registers = self.reg_info.expr
        if index >= len(registers):
            return False
        self.expr = registers[index]
        return True

    def encode(self):
        """Store the index of ``self.expr`` in ``self.value``.

        Returns False (after a debug log) for an unknown register.
        """
        registers = self.reg_info.expr
        if self.expr not in registers:
            log.debug("cannot encode reg %r", self.expr)
            return False
        self.value = registers.index(self.expr)
        return True

    def check_fbits(self, v):
        """Tell whether the fixed bits of *v* match the field pattern."""
        return (v & self.fmask) == self.fbits


class mn_prefix(object):
    """Empty class; it defines no attribute or method of its own."""


def swap16(v):
    """Return the unsigned 16-bit value *v* with its two bytes swapped."""
    packed = struct.pack('>H', v)
    (swapped,) = struct.unpack('<H', packed)
    return swapped


def swap32(v):
    """Return the unsigned 32-bit value *v* with its byte order reversed."""
    packed = struct.pack('>I', v)
    (swapped,) = struct.unpack('<I', packed)
    return swapped


def perm_inv(p):
    """Return the inverse of permutation *p*: ``result[p[i]] == i``."""
    inverse = [None] * len(p)
    for position, target in enumerate(p):
        inverse[target] = position
    return inverse


def gen_bsint(value, l, args):
    """Build a bs field whose pattern is *value* written on *l* bits.

    The other constructor arguments are taken from a copy of *args*.
    """
    pattern = int2bin(value, l)
    bs_args = dict(args)
    bs_args['strbits'] = pattern
    return bs(**bs_args)

# NOTE(review): module-level counter; nothing in the visible part of this
# module reads or updates it -- confirm whether it is still used.
total_scans = 0


def branch2nodes(branch, nodes=None):
    """Collect the (parent key, child key) pairs of a nested dict tree.

    For every key of *branch* whose value is a dict, one pair is appended per
    key of that dict, then the dict is explored recursively. Non-dict values
    are ignored.

    The pairs are appended to *nodes* when it is given (a new list is created
    otherwise) and the list is returned.
    """
    if nodes is None:
        nodes = []
    for key, child in branch.items():
        if not isinstance(child, dict):
            continue
        for child_key in child:
            nodes.append((key, child_key))
        branch2nodes(child, nodes)
    return nodes


def factor_one_bit(tree):
    """Split the leading bit off every multi-bit key of a decoding tree.

    Sets and single-entry trees are returned unchanged. A key
    (l, fmask, fbits, fname, flen) with ``flen is None`` and ``l > 1`` is
    replaced by a one-bit anonymous key leading to a sub-tree keyed by the
    remaining ``l - 1`` bits; the "mn" key and the other keys are kept as
    they are. The sub-trees are then processed recursively.

    Raises NotImplementedError when two keys produce the same split.
    """
    if isinstance(tree, set) or len(tree) == 1:
        return tree

    new_keys = defaultdict(lambda: defaultdict(dict))
    for key, subtree in viewitems(tree):
        if key == "mn":
            new_keys[key] = subtree
            continue
        l, fmask, fbits, fname, flen = key
        if flen is not None or l <= 1:
            new_keys[key] = subtree
            continue
        # Most significant bit on one side, the l - 1 low bits on the other
        low_mask = (1 << (l - 1)) - 1
        head = 1, fmask >> (l - 1), fbits >> (l - 1), None, flen
        tail = l - 1, fmask & low_mask, fbits & low_mask, fname, flen
        if tail in new_keys[head]:
            raise NotImplementedError('not fully functional')
        new_keys[head][tail] = subtree

    for key, subtree in list(viewitems(new_keys)):
        new_keys[key] = factor_one_bit(subtree)
    return new_keys


def factor_fields(tree):
    """Merge a chain of two anonymous fixed-length keys into a single key.

    *tree* is returned unchanged unless it is a one-entry dict whose key and
    whose only child key are both (l, fmask, fbits, fname, flen) tuples with
    ``fname is None`` and ``flen is None``. In that case a new one-entry dict
    is returned: its key concatenates both fields (first key in the high
    bits) and its value is the grand-child.
    """
    if not isinstance(tree, dict) or len(tree) != 1:
        return tree
    # merge
    outer_key, outer_child = next(iter(viewitems(tree)))
    if outer_key == "mn":
        return tree
    l1, fmask1, fbits1, fname1, flen1 = outer_key
    if fname1 is not None or flen1 is not None:
        return tree

    if not isinstance(outer_child, dict) or len(outer_child) != 1:
        return tree
    inner_key, inner_child = next(iter(viewitems(outer_child)))
    if inner_key == "mn":
        return tree
    l2, fmask2, fbits2, fname2, flen2 = inner_key
    if fname2 is not None or flen2 is not None:
        return tree

    merged_key = (
        l1 + l2,
        (fmask1 << l2) | fmask2,
        (fbits1 << l2) | fbits2,
        fname2,
        flen2,
    )
    return {merged_key: inner_child}


def factor_fields_all(tree):
    """Apply factor_fields() to every sub-tree of *tree*, recursively.

    Non-dict values are returned unchanged; dicts are rebuilt into new dicts.
    """
    if not isinstance(tree, dict):
        return tree
    return {
        key: factor_fields_all(factor_fields(subtree))
        for key, subtree in viewitems(tree)
    }


def graph_tree(tree):
    """Dump the edges of *tree* as a Graphviz digraph into 'graph.txt'.

    Nodes are identified by the id() of their key objects; edges leading to
    an 'mn' key are left out. The file is written in the current directory
    and closed before returning.
    """
    nodes = []
    branch2nodes(tree, nodes)

    parts = ["""
          digraph G {
          """]
    for a, b in nodes:
        if b == 'mn':
            continue
        parts.append("%s -> %s;\n" % (id(a), id(b)))
    parts.append("}")
    # The context manager guarantees the file is flushed and closed
    with open('graph.txt', 'w') as fdesc:
        fdesc.write("".join(parts))


def add_candidate_to_tree(tree, c):
    """Register candidate *c* in the nested dict *tree*.

    The path is made of one (l, fmask, fbits, fname, flen) key per field of
    ``c.fields`` with a non-zero length; missing levels are created. *c* is
    added to the set stored under the 'mn' key of the last level.
    """
    branch = tree
    for field in c.fields:
        if field.l == 0:
            continue
        node = field.l, field.fmask, field.fbits, field.fname, field.flen
        if node not in branch:
            branch[node] = {}
        branch = branch[node]
    if 'mn' not in branch:
        branch['mn'] = set()
    branch['mn'].add(c)


def add_candidate(bases, c):
    """Register candidate *c* in the ``bintree`` of the first base class."""
    root = bases[0].bintree
    add_candidate_to_tree(root, c)


def getfieldby_name(fields, fname):
    """Return the only element of *fields* whose ``fname`` equals *fname*.

    Elements without an ``fname`` attribute are ignored.
    Raises ValueError when no field, or more than one field, matches.
    """
    matches = [x for x in fields if hasattr(x, 'fname') and x.fname == fname]
    if not matches:
        raise ValueError('no field with name: %s' % fname)
    if len(matches) > 1:
        raise ValueError('more than one field with name: %s' % fname)
    return matches[0]


def getfieldindexby_name(fields, fname):
    """Return (field, index) for the first field named *fname*, else None.

    Elements without an ``fname`` attribute are skipped.
    """
    for index, field in enumerate(fields):
        if not hasattr(field, 'fname'):
            continue
        if field.fname == fname:
            return field, index
    return None


class metamn(type):
    """Metaclass expanding one mnemonic description into concrete classes.

    Classes named "cls_mn" or starting with "mn_" are created as ordinary
    classes. Any other class is expanded into a list of candidates (see
    gen_modes() and the divert() methods of the bs_divert fields); every
    candidate becomes a class that is registered in the tables of the first
    base class (all_mn, all_mn_mode, all_mn_name, all_mn_inst, bintree).
    """

    def __new__(mcs, name, bases, dct):
        # Base/architecture classes: regular class creation, no expansion
        if name == "cls_mn" or name.startswith('mn_'):
            return type.__new__(mcs, name, bases, dct)
        alias = dct.get('alias', False)

        fields = bases[0].mod_fields(dct['fields'])
        if not 'name' in dct:
            dct["name"] = bases[0].getmn(name)
        if 'args' in dct:
            # special case for permuted arguments
            o = []
            p = []
            for i, a in enumerate(dct['args']):
                o.append((i, a))
                if a in fields:
                    p.append((fields.index(a), a))
            # Arguments sorted by their position in the encoding fields
            p.sort()
            p = [x[1] for x in p]
            p = [dct['args'].index(x) for x in p]
            dct['args_permut'] = perm_inv(p)
        # order fields
        f_ordered = [x for x in enumerate(fields)]
        f_ordered.sort(key=lambda x: (x[1].prio, x[0]))
        candidates = bases[0].gen_modes(mcs, name, bases, dct, fields)
        # Let every diverting field expand the candidate list, in prio order
        for i, fc in f_ordered:
            if isinstance(fc, bs_divert):
                candidates = fc.divert(i, candidates)
        # NOTE: this loop rebinds name, bases, dct and fields
        for cls, name, bases, dct, fields in candidates:
            ndct = dict(dct)
            # Drop the falsy field entries
            fields = [f for f in fields if f]
            ndct['fields'] = fields
            # Total length of the fields, in bits
            ndct['mn_len'] = sum([x.l for x in fields])
            c = type.__new__(cls, name, bases, ndct)
            c.alias = alias
            c.check_mnemo(fields)
            # Sequence number, then registration in the base class tables
            c.num = bases[0].num
            bases[0].num += 1
            bases[0].all_mn.append(c)
            mode = dct['mode']
            bases[0].all_mn_mode[mode].append(c)
            bases[0].all_mn_name[c.name].append(c)
            # One initialised instance is kept for each generated class
            i = c()
            i.init_class()
            bases[0].all_mn_inst[c].append(i)
            add_candidate(bases, c)
            # gen byte lookup
            # NOTE(review): the concatenated pattern `o` is not used after
            # this loop; only the bsi type check has an effect.
            o = ""
            for f in i.fields_order:
                if not isinstance(f, bsi):
                    raise ValueError('f is not bsi')
                if f.l == 0:
                    continue
                o += f.strbits
        # The class created for the last candidate is returned.
        # NOTE(review): `c` is unbound here if `candidates` is empty.
        return c


class instruction(object):
    """One instruction: a name, a mode, arguments and placement information.

    ``offset``, ``l`` and ``b`` start as None and ``delayslot`` as 0; the
    ``data`` slot is declared but not initialised by the constructor.
    Subclasses are expected to provide arg2str(), arg2html() and
    get_symbol_size(), which this class calls but does not define.
    """
    __slots__ = ["name", "mode", "args",
                 "l", "b", "offset", "data",
                 "additional_info", "delayslot"]

    def __init__(self, name, mode, args, additional_info=None):
        self.name = name
        self.mode = mode
        self.args = args
        self.additional_info = additional_info
        self.offset = None
        self.l = None
        self.b = None
        self.delayslot = 0

    def gen_args(self, args):
        """Join the string form of *args* with ', '."""
        return ', '.join(str(arg) for arg in args)

    def __str__(self):
        return self.to_string()

    def to_string(self, loc_db=None):
        """Return the name padded to 10 columns, followed by the arguments.

        Raises ValueError when an argument is not an Expr.
        """
        prefix = "%-10s " % self.name
        rendered = []
        for index, arg in enumerate(self.args):
            if not isinstance(arg, m2_expr.Expr):
                raise ValueError('zarb arg type')
            rendered.append(self.arg2str(arg, index, loc_db))
        return prefix + self.gen_args(rendered)

    def to_html(self, loc_db=None):
        """HTML variant of to_string(): the name is wrapped in a <font> tag.

        Raises ValueError when an argument is not an Expr.
        """
        padded_name = "%-10s " % self.name
        prefix = '<font color="%s">%s</font>' % (utils.COLOR_MNEMO, padded_name)

        rendered = []
        for index, arg in enumerate(self.args):
            if not isinstance(arg, m2_expr.Expr):
                raise ValueError('zarb arg type')
            rendered.append(self.arg2html(arg, index, loc_db))
        return prefix + self.gen_args(rendered)

    def get_asm_offset(self, expr):
        """Return the instruction offset as an ExprInt of the size of *expr*."""
        return m2_expr.ExprInt(self.offset, expr.size)

    def get_asm_next_offset(self, expr):
        """Return offset + length as an ExprInt of the size of *expr*."""
        return m2_expr.ExprInt(self.offset+self.l, expr.size)

    def resolve_args_with_symbols(self, loc_db):
        """Return the arguments with their locations replaced by integers.

        Locations named '$' resolve to the instruction offset, locations
        named '_' to the offset of the next instruction, any other location
        to its offset in *loc_db*. Each argument is then simplified.

        Raises ValueError for a location having neither an offset nor a name,
        or whose offset cannot be determined.
        """
        args_out = []
        for expr in self.args:
            # try to resolve symbols using loc_db (0 for default value)
            loc_keys = m2_expr.get_expr_locs(expr)
            fixed_expr = {}
            for exprloc in loc_keys:
                loc_key = exprloc.loc_key
                names = loc_db.get_location_names(loc_key)
                # special symbols
                if '$' in names:
                    value = self.get_asm_offset(exprloc)
                elif '_' in names:
                    value = self.get_asm_next_offset(exprloc)
                else:
                    arg_int = loc_db.get_location_offset(loc_key)
                    if arg_int is not None:
                        value = m2_expr.ExprInt(arg_int, exprloc.size)
                    else:
                        if not names:
                            raise ValueError('Unresolved symbol: %r' % exprloc)

                        offset = loc_db.get_location_offset(loc_key)
                        if offset is None:
                            raise ValueError(
                                'The offset of loc_key "%s" cannot be determined' % names
                            )
                        # Fix symbol with its offset
                        size = exprloc.size
                        if size is None:
                            size = self.get_symbol_size(exprloc, loc_db)
                        value = m2_expr.ExprInt(offset, size)
                fixed_expr[exprloc] = value

            resolved = expr_simp(expr.replace_expr(fixed_expr))
            args_out.append(resolved)
        return args_out

    def get_info(self, c):
        """Hook doing nothing here; returns None."""
        return


class cls_mn(with_metaclass(metamn, object)):
    args_symb = []
    instruction = instruction
    # Block's offset alignment
    alignment = 1

    @classmethod
    def guess_mnemo(cls, bs, attrib, pre_dis_info, offset):
        """Return the candidate mnemonics whose fixed bits match the stream.

        Walks cls.bintree from byte @offset of @bs, reading each node's bits
        through cls.getbits and keeping the branches where
        (bits & fmask) == fbits.

        @bs: stream the bits are read from
        @attrib: forwarded to cls.getbits and to the fields' length callbacks
        @pre_dis_info: initial mapping of field name -> value
        @offset: byte offset where matching starts
        """
        matching = set()

        # Worklist of (known field values, tree branch, bit offset); branches
        # discovered during the walk are appended and processed in order.
        pending = [
            (dict(pre_dis_info), item, offset * 8)
            for item in list(viewitems(cls.bintree))
        ]
        pos = 0
        while pos < len(pending):
            known, node, bit_off = pending[pos]
            pos += 1
            (width, fmask, fbits, fname, flen), children = node

            # A length callback takes precedence over the static length
            if flen is not None:
                width = flen(attrib, known)
            if width is not None:
                try:
                    raw = cls.getbits(bs, attrib, bit_off, width)
                except IOError:
                    # Raised if offset is out of bound
                    continue
                bit_off += width
                if raw & fmask != fbits:
                    continue
                if fname is not None and fname not in known:
                    known[fname] = raw
            for key, sub in viewitems(children):
                if 'mn' in key:
                    matching.update(sub)
                else:
                    pending.append((dict(known), (key, sub), bit_off))

        return list(matching)

    def reset_class(self):
        """Reset every field value and re-bind named fields as attributes.

        A field's value becomes its binary strbits when isbin accepts them,
        else its binary 'default_val' karg when present, else None.
        """
        for field in self.fields_order:
            if field.strbits and isbin(field.strbits):
                initial = int(field.strbits, 2)
            elif 'default_val' in field.kargs:
                initial = int(field.kargs['default_val'], 2)
            else:
                initial = None
            field.value = initial
            if field.fname:
                setattr(self, field.fname, field)

    def init_class(self):
        """Instantiate the fields and build args, fields_order and to_decode.

        Each entry of self.fields is generated with gen(self) and given its
        bit offset. m_arg fields are collected as arguments (reordered by
        args_permut when that attribute exists). to_decode holds indexes into
        fields_order sorted by (field.order, declaration position).
        """
        args = []
        fields_order = []
        decode_plan = []
        bit_offset = 0
        for position, field_desc in enumerate(self.fields):
            field = field_desc.gen(self)
            field.offset = bit_offset
            bit_offset += field.l
            fields_order.append(field)
            decode_plan.append((position, field))

            if isinstance(field, m_arg):
                args.append(field)
            if field.fname:
                setattr(self, field.fname, field)
        if hasattr(self, 'args_permut'):
            args = [args[self.args_permut[i]]
                    for i in range(len(self.args_permut))]
        decode_plan = sorted(
            decode_plan, key=lambda entry: (entry[1].order, entry[0])
        )
        to_decode = [fields_order.index(entry[1]) for entry in decode_plan]
        self.args = args
        self.fields_order = fields_order
        self.to_decode = to_decode

    def add_pre_dis_info(self, prefix=None):
        """Accept the pre-disassembly info @prefix; always True here.

        dis() skips a candidate when this returns a false value.
        """
        return True

    @classmethod
    def getbits(cls, bs, attrib, offset_b, l):
        """Read @l bits of @bs at bit offset @offset_b (@attrib is unused)."""
        bits = bs.getbits(offset_b, l)
        return bits

    @classmethod
    def getbytes(cls, bs, offset, l):
        """Read @l bytes of @bs at byte offset @offset."""
        data = bs.getbytes(offset, l)
        return data

    @classmethod
    def pre_dis(cls, v_o, attrib, offset):
        """Pre-disassembly hook; this default consumes nothing.

        Returns the tuple that dis() unpacks as
        (pre_dis_info, stream, mode, offset, prefix_len): an empty dict, the
        inputs unchanged, and a prefix length of 0.
        """
        return ({}, v_o, attrib, offset, 0)

    def post_dis(self):
        """Post-disassembly hook; returns self unchanged.

        dis() drops the candidate when this returns None.
        """
        return self

    @classmethod
    def check_mnemo(cls, fields):
        """Hook receiving @fields; this default performs no check."""
        return None

    @classmethod
    def mod_fields(cls, fields):
        """Hook receiving @fields; this default returns them unchanged."""
        return fields

    @classmethod
    def dis(cls, bs_o, mode_o = None, offset=0):
        """Disassemble one instruction of @bs_o located at @offset.

        @bs_o: bin_stream instance; anything else is wrapped in a
               bin_stream_str
        @mode_o: mode forwarded to pre_dis
        @offset: byte offset of the instruction in the stream

        Returns the instruction built through cls.instruction. When several
        candidates decode, the first one whose mnemonic object has `alias`
        set is returned.
        Raises Disasm_Exception when no candidate matches or decodes, and
        NotImplementedError when several candidates decode and none of them
        is an alias.
        """
        if not isinstance(bs_o, bin_stream):
            bs_o = bin_stream_str(bs_o)

        bs_o.enter_atomic_mode()

        offset_o = offset
        out = []
        out_c = []
        alias = False
        # try/finally: atomic mode must be left exactly once, whatever
        # happens while guessing or decoding the candidates.
        try:
            pre_dis_info, bs, mode, offset, prefix_len = cls.pre_dis(
                bs_o, mode_o, offset)
            candidates = cls.guess_mnemo(bs, mode, pre_dis_info, offset)
            if not candidates:
                raise Disasm_Exception('cannot disasm (guess) at %X' % offset)

            if hasattr(bs, 'getlen'):
                bs_l = bs.getlen()
            else:
                bs_l = len(bs)

            for c in candidates:
                log.debug("%s %s %s", "*" * 40, mode, c.mode)
                log.debug(c.fields)

                c = cls.all_mn_inst[c][0]

                c.reset_class()
                c.mode = mode

                if not c.add_pre_dis_info(pre_dis_info):
                    continue

                todo = {}
                getok = True
                fname_values = dict(pre_dis_info)
                offset_b = offset * 8

                # First pass: read the raw bits of every present field
                total_l = 0
                for i, f in enumerate(c.fields_order):
                    if f.flen is not None:
                        l = f.flen(mode, fname_values)
                    else:
                        l = f.l
                    if l is not None:
                        total_l += l
                        f.l = l
                        f.is_present = True
                        log.debug("FIELD %s %s %s %s", f.__class__, f.fname,
                                  offset_b, l)
                        if bs_l * 8 - offset_b < l:
                            # Not enough bits left in the stream
                            getok = False
                            break
                        bv = cls.getbits(bs, mode, offset_b, l)
                        offset_b += l
                        if f.fname not in fname_values:
                            fname_values[f.fname] = bv
                        todo[i] = bv
                    else:
                        f.is_present = False
                        todo[i] = None

                if not getok:
                    continue

                # Second pass: decode the fields following to_decode order
                c.l = prefix_len + total_l // 8
                # Start from True so a candidate without any present field
                # does not reuse the result of a previous candidate.
                ret = True
                for i in c.to_decode:
                    f = c.fields_order[i]
                    if f.is_present:
                        ret = f.decode(todo[i])
                        if not ret:
                            log.debug("cannot decode %r", f)
                            break

                if not ret:
                    continue
                for a in c.args:
                    a.expr = expr_simp(a.expr)

                c.b = cls.getbytes(bs, offset_o, c.l)
                c.offset = offset_o
                c = c.post_dis()
                if c is None:
                    continue
                c_args = [a.expr for a in c.args]
                instr = cls.instruction(c.name, mode, c_args,
                                        additional_info=c.additional_info())
                instr.l = prefix_len + total_l // 8
                instr.b = cls.getbytes(bs, offset_o, instr.l)
                instr.offset = offset_o
                instr.get_info(c)
                if c.alias:
                    alias = True
                out.append(instr)
                out_c.append(c)
        finally:
            bs_o.leave_atomic_mode()

        if not out:
            raise Disasm_Exception('cannot disasm at %X' % offset_o)
        if len(out) != 1:
            if not alias:
                log.warning('dis multiple args ret default')

            for i, o in enumerate(out_c):
                if o.alias:
                    return out[i]
            raise NotImplementedError(
                'Multiple disas: \n' +
                "\n".join(str(x) for x in out)
            )
        return out[0]

    @classmethod
    def fromstring(cls, text, loc_db, mode = None):
        """Parse the assembly line @text and return the matching instruction.

        The first whitespace-delimited token is the mnemonic name; the rest
        of the line is parsed argument by argument with the parsers of each
        candidate registered under that name. The first candidate able to
        parse the whole line is used.

        @text: assembly text (mnemonic followed by comma separated arguments)
        @loc_db: location database forwarded to the argument parsers
        @mode: mode of the generated instruction

        Increments the module-level total_scans counter for each parser scan.
        Raises ValueError if @text holds no mnemonic, if the mnemonic is
        unknown or if no candidate can parse the arguments.
        """
        global total_scans
        match = re.search(r'(\S+)', text)
        if match is None:
            # Blank line: no token at all
            raise ValueError('cannot find name', text)
        name = match.group(1)

        if name not in cls.all_mn_name:
            raise ValueError('unknown name', name)
        clist = list(cls.all_mn_name[name])
        out = []
        out_args = []
        # Parser results cached by (argument index, position in the line)
        parsers = defaultdict(dict)
        # Arguments start right after the mnemonic, even when the line
        # begins with whitespace.
        raw_args = text[match.end():]

        for cc in clist:
            for c in cls.get_cls_instance(cc, mode):
                args_expr = []
                args_str = raw_args.strip(' ')

                start = 0
                cannot_parse = False
                len_o = len(args_str)

                for i, f in enumerate(c.args):
                    start_i = len_o - len(args_str)
                    if isinstance(f.parser, tuple):
                        parser = f.parser
                    else:
                        parser = (f.parser,)
                    for p in parser:
                        if p in parsers[(i, start_i)]:
                            continue
                        try:
                            total_scans += 1
                            v, start, stop = next(p.scanString(args_str))
                        except StopIteration:
                            v, start, stop = [None], None, None
                        # Only a match anchored at the current position counts
                        if start != 0:
                            v, start, stop = [None], None, None
                        if v != [None]:
                            v = f.asm_ast_to_expr(v[0], loc_db)
                        if v is None:
                            v, start, stop = [None], None, None
                        parsers[(i, start_i)][p] = v, start, stop
                    start, stop = f.fromstring(args_str, loc_db, parsers[(i, start_i)])
                    if start != 0:
                        log.debug("cannot fromstring %r", args_str)
                        cannot_parse = True
                        break
                    if f.expr is None:
                        raise NotImplementedError('not fully functional')
                    f.expr = expr_simp(f.expr)
                    args_expr.append(f.expr)
                    # Drop the consumed argument and its separator
                    args_str = args_str[stop:].strip(' ')
                    if args_str.startswith(','):
                        args_str = args_str[1:]
                    args_str = args_str.strip(' ')
                # Leftover text means this candidate does not match
                if args_str:
                    cannot_parse = True
                if cannot_parse:
                    continue

                out.append(c)
                out_args.append(args_expr)
                break

        if not out:
            raise ValueError('cannot fromstring %r' % text)
        if len(out) != 1:
            log.debug('fromstring multiple args ret default')
        c = out[0]
        c_args = out_args[0]

        instr = cls.instruction(c.name, mode, c_args,
                                additional_info=c.additional_info())
        return instr

    def dup_info(self, infos):
        """Hook receiving @infos from get_cls_instance; does nothing here."""
        return None

    @classmethod
    def get_cls_instance(cls, cc, mode, infos=None):
        """Yield the first registered instance of @cc, reset and configured.

        The instance taken from cls.all_mn_inst[cc] is reset, given @infos
        through dup_info, and its mode is set to @mode before being yielded.
        """
        inst = cls.all_mn_inst[cc][0]

        inst.reset_class()
        inst.add_pre_dis_info()
        inst.dup_info(infos)
        inst.mode = mode

        yield inst

    @classmethod
    def asm(cls, instr, loc_db=None):
        """
        Re asm instruction by searching mnemo using name and args. We then
        can modify args and get the hex of a modified instruction

        @instr: instruction to encode
        @loc_db: location database used to resolve the symbols of the args
        Returns the encodings selected by filter_asm_candidates.
        Raises ValueError if no candidate can encode the instruction.
        """
        clist = list(cls.all_mn_name[instr.name])
        vals = []
        candidates = []
        args = instr.resolve_args_with_symbols(loc_db)

        for cc in clist:
            for c in cls.get_cls_instance(
                    cc, instr.mode, instr.additional_info):
                if len(c.args) != len(instr.args):
                    continue

                # only fix args expr
                for i, slot in enumerate(c.args):
                    slot.expr = args[i]

                encoded = c.value(instr.mode)
                if not encoded:
                    log.debug("cannot encode %r", c)
                    continue
                vals.extend(encoded)
                candidates.append((c, encoded))

        if not vals:
            raise ValueError(
                'cannot asm %r %r' %
                (instr.name, [str(x) for x in instr.args])
            )
        if len(vals) != 1:
            log.debug('asm multiple args ret default')

        return cls.filter_asm_candidates(instr, candidates)

    @classmethod
    def filter_asm_candidates(cls, instr, candidates):
        """Flatten the encodings of @candidates, shortest first.

        @instr: instruction being assembled (unused here)
        @candidates: iterable of (mnemonic object, list of encodings)
        """
        flat = [enc for _, encodings in candidates for enc in encodings]
        flat.sort(key=len)
        return flat

    def value(self, mode):
        """Encode the current field/argument state into byte strings.

        Fields are encoded following self.to_decode in reverse order. A
        worklist is used because a field's encode() may return something
        other than True/False; that result is then iterated and one new
        state (a clone of every field) is queued per iteration.

        @mode: not used by this implementation
        Returns the list produced by decoded2bytes for every state whose
        fields could all be encoded.
        Raises NotImplementedError if a state about to be queued is already
        pending or was already processed.
        """
        # Each state is (index of the next field to encode, length encoded so
        # far, list of (position in fields_order, field)).
        todo = [(0, 0, [(x, self.fields_order[x]) for x in self.to_decode[::-1]])]

        result = []
        # States already processed, as (index, field values)
        done = []

        while todo:
            index, cur_len, to_decode = todo.pop()
            # TEST XXX
            # Re-bind this state's fields (possibly clones) as attributes
            for _, f in to_decode:
                setattr(self, f.fname, f)
            if (index, [x[1].value for x in to_decode]) in done:
                continue
            done.append((index, [x[1].value for x in to_decode]))

            can_encode = True
            for i, f in to_decode[index:]:
                # Expose the length encoded so far through the field's parent
                f.parent.l = cur_len
                ret = f.encode()
                if not ret:
                    log.debug('cannot encode %r', f)
                    can_encode = False
                    break

                if f.value is not None and f.l:
                    # The value must fit in the field's bit mask
                    if f.value > f.lmask:
                        log.debug('cannot encode %r', f)
                        can_encode = False
                        break
                    cur_len += f.l
                index += 1
                if ret is True:
                    continue

                # encode() returned an iterable: presumably each step sets up
                # an alternative encoding on the fields -- TODO confirm.
                # Snapshot the fields after each step and queue that state.
                for _ in ret:
                    o = []
                    if ((index, cur_len, [xx[1].value for xx in to_decode]) in todo or
                        (index, cur_len, [xx[1].value for xx in to_decode]) in done):
                        raise NotImplementedError('not fully functional')

                    for p, f in to_decode:
                        fnew = f.clone()
                        o.append((p, fnew))
                    todo.append((index, cur_len, o))
                # The current state is abandoned in favour of the queued ones
                can_encode = False

                break
            if not can_encode:
                continue
            result.append(to_decode)

        return self.decoded2bytes(result)

    def encodefields(self, decoded):
        """Pack the values of the @decoded fields into a byte string.

        @decoded: iterable of (position, field); each field is also re-bound
                  as an attribute, and fields whose value is None are skipped
        """
        stream = bitobj()
        for _, field in decoded:
            setattr(self, field.fname, field)
            if field.value is not None:
                stream.putbits(field.value, field.l)

        return stream.tostring()

    def decoded2bytes(self, result):
        """Encode each field list of @result and return the unique encodings.

        @result: list of lists of (position, field); each inner list is
                 sorted in place before being encoded
        """
        if not result:
            return []

        encodings = []
        for decoded in result:
            decoded.sort()
            encoded = self.encodefields(decoded)
            if encoded is not None:
                encodings.append(encoded)
        return list(set(encodings))

    def gen_args(self, args):
        """Return the str() of each element of @args joined with ', '."""
        return ', '.join(str(arg) for arg in args)

    def args2str(self):
        """Return the list of str() of every argument.

        Raises ValueError if an argument neither is an expression nor holds
        one in its `expr` attribute.
        """
        rendered = []
        for arg in self.args:
            # XXX todo test
            if (not isinstance(arg, m2_expr.Expr) and
                    not isinstance(arg.expr, m2_expr.Expr)):
                raise ValueError('zarb arg type')
            rendered.append(str(arg))
        return rendered

    def __str__(self):
        """Return the name padded to 10 columns followed by the arguments.

        Raises ValueError if an argument neither is an expression nor holds
        one in its `expr` attribute.
        """
        prefix = "%-10s " % self.name
        rendered = []
        for arg in self.args:
            # XXX todo test
            if (not isinstance(arg, m2_expr.Expr) and
                    not isinstance(arg.expr, m2_expr.Expr)):
                raise ValueError('zarb arg type')
            rendered.append(str(arg))

        return prefix + self.gen_args(rendered)

    def parse_prefix(self, v):
        """Prefix parsing hook for @v; this default always returns 0."""
        return 0

    def set_dst_symbol(self, loc_db):
        """Store in args_symb the destinations with integers made symbolic.

        Each ExprInt returned by getdstflow is replaced by an ExprId named
        after the location that @loc_db returns for that offset; other
        destinations are kept as they are.
        """
        symbolic = []
        for dst in self.getdstflow(loc_db):
            if isinstance(dst, m2_expr.ExprInt):
                loc = loc_db.get_or_create_offset_location(int(dst))

                dst = m2_expr.ExprId(loc.name, dst.size)
            symbolic.append(dst)
        self.args_symb = symbolic

    def getdstflow(self, loc_db):
        """Return a one-element list holding the first argument's expression.

        @loc_db: unused here
        """
        first = self.args[0]
        return [first.expr]


class imm_noarg(object):
    """Helper converting between a field's integer and an ExprInt.

    intsize is the size of the generated ExprInt; values with bits outside
    intmask are rejected. decodeval/encodeval are identity hooks applied to
    the integer when decoding/encoding.
    """
    intsize = 32
    intmask = (1 << intsize) - 1

    def int2expr(self, v):
        """Return @v as an ExprInt of intsize bits, None if it does not fit."""
        if (v & ~self.intmask) == 0:
            return m2_expr.ExprInt(v, self.intsize)
        return None

    def expr2int(self, e):
        """Return the integer of ExprInt @e, None if not an int or too big."""
        if isinstance(e, m2_expr.ExprInt):
            value = int(e)
            if (value & ~self.intmask) == 0:
                return value
        return None

    def fromstring(self, text, loc_db, parser_result=None):
        """Set self.expr from @text, return (start, stop) of the match.

        @parser_result: optional mapping parser -> (expr, start, stop) used
                        instead of scanning @text with self.parser
        Returns (None, None) when nothing could be parsed.
        """
        failure = (None, None)
        if parser_result:
            e, start, stop = parser_result[self.parser]
        else:
            scanner = self.parser.scanString(text)
            try:
                e, start, stop = next(scanner)
            except StopIteration:
                return failure
        if e == [None]:
            return failure

        assert(m2_expr.is_expr(e))
        self.expr = e
        if self.expr is None:
            log.debug('cannot fromstring int %r', text)
            return failure
        return start, stop

    def decodeval(self, v):
        """Hook applied to the masked field value when decoding."""
        return v

    def encodeval(self, v):
        """Hook applied to the integer when encoding."""
        return v

    def decode(self, v):
        """Set self.expr from field value @v; False if it cannot be built."""
        expr = self.int2expr(self.decodeval(v & self.lmask))
        if not expr:
            return False
        self.expr = expr
        return True

    def encode(self):
        """Set self.value from self.expr; False if it cannot be encoded."""
        raw = self.expr2int(self.expr)
        if raw is None:
            return False
        encoded = self.encodeval(raw)
        if encoded is False:
            return False
        self.value = encoded
        return True


class imm08_noarg(object):
    """Provides an int2expr building 8-bit ExprInt."""

    def int2expr(self, x):
        """Return @x as an 8-bit ExprInt."""
        return m2_expr.ExprInt(x, 8)


class imm16_noarg(object):
    """Provides an int2expr building 16-bit ExprInt."""

    def int2expr(self, x):
        """Return @x as a 16-bit ExprInt."""
        return m2_expr.ExprInt(x, 16)


class imm32_noarg(object):
    """Provides an int2expr building 32-bit ExprInt."""

    def int2expr(self, x):
        """Return @x as a 32-bit ExprInt."""
        return m2_expr.ExprInt(x, 32)


class imm64_noarg(object):
    """Provides an int2expr building 64-bit ExprInt."""

    def int2expr(self, x):
        """Return @x as a 64-bit ExprInt."""
        return m2_expr.ExprInt(x, 64)


class int32_noarg(imm_noarg):
    """imm_noarg variant whose field value is sign extended to intsize."""
    intsize = 32
    intmask = (1 << intsize) - 1

    def decode(self, v):
        """Sign extend @v from the field length, set self.expr; always True."""
        extended = sign_ext(v, self.l, self.intsize)
        self.expr = self.int2expr(self.decodeval(extended))
        return True

    def encode(self):
        """Set self.value from self.expr.

        Returns False if self.expr is not an ExprInt, if its value is not
        the sign extension of its low field bits, or if encodeval returns
        False.
        """
        if not isinstance(self.expr, m2_expr.ExprInt):
            return False
        v = int(self.expr)
        low = v & self.lmask
        # The value must round-trip through the field's width
        if sign_ext(low, self.l, self.intsize) != v:
            return False
        encoded = self.encodeval(low)
        if encoded is False:
            return False
        self.value = encoded & self.lmask
        return True

class bs8(bs):
    """8-bit bs field whose bit string is int2bin(v, 8)."""
    prio = default_prio

    def __init__(self, v, cls=None, fname=None, **kargs):
        bits = int2bin(v, 8)
        super(bs8, self).__init__(bits, 8, cls=cls, fname=fname, **kargs)




def swap_uint(size, i):
    """Reverse the byte order of @i taken as an unsigned @size-bit integer.

    @size: 8, 16, 32 or 64; @i is first truncated to that many bits
    Raises ValueError for any other size.
    """
    if size == 8:
        return i & 0xff
    for bits, code in ((16, 'H'), (32, 'I'), (64, 'Q')):
        if size == bits:
            packed = struct.pack('>' + code, i & ((1 << bits) - 1))
            return struct.unpack('<' + code, packed)[0]
    raise ValueError('unknown int len %r' % size)


def swap_sint(size, i):
    """Reverse the byte order of @i and read the result as a signed integer.

    @size: 8, 16, 32 or 64. For 8, @i is returned untouched; otherwise @i is
           truncated to @size bits before being swapped
    Raises ValueError for any other size.
    """
    if size == 8:
        return i
    for bits, code in ((16, 'H'), (32, 'I'), (64, 'Q')):
        if size == bits:
            # Pack as unsigned big endian, read back as signed little endian
            packed = struct.pack('>' + code, i & ((1 << bits) - 1))
            return struct.unpack('<' + code.lower(), packed)[0]
    raise ValueError('unknown int len %r' % size)


def sign_ext(v, s_in, s_out):
    """Sign extend the low @s_in bits of @v to @s_out bits.

    The result is returned as a non-negative integer: when bit @s_in - 1 is
    set, bits @s_in to @s_out - 1 are set as well.
    """
    assert s_in <= s_out
    low_mask = (1 << s_in) - 1
    v &= low_mask
    if v & (1 << (s_in - 1)):
        # Fill the bits between s_in and s_out with ones
        v |= ((1 << s_out) - 1) ^ low_mask
    return v