cea-sec/miasm

View on GitHub
miasm/core/modint.py

Summary

Maintainability
D
1 day
Test Coverage
#-*- coding:utf-8 -*-

from builtins import range
from functools import total_ordering

@total_ordering
class moduint(object):

    def __init__(self, arg):
        self.arg = int(arg) % self.__class__.limit
        assert(self.arg >= 0 and self.arg < self.__class__.limit)

    def __repr__(self):
        return self.__class__.__name__ + '(' + hex(self.arg) + ')'

    def __hash__(self):
        return hash(self.arg)

    @classmethod
    def maxcast(cls, c2):
        c2 = c2.__class__
        if cls.size > c2.size:
            return cls
        else:
            return c2

    def __eq__(self, y):
        if isinstance(y, moduint):
            return self.arg == y.arg
        return self.arg == y

    def __ne__(self, y):
        # required Python 2.7.14
        return not self == y

    def __lt__(self, y):
        if isinstance(y, moduint):
            return self.arg < y.arg
        return self.arg < y

    def __add__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg + y.arg)
        else:
            return self.__class__(self.arg + y)

    def __and__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg & y.arg)
        else:
            return self.__class__(self.arg & y)

    def __div__(self, y):
        # Python: 8 / -7 == -2 (C-like: -1)
        # int(float) trick cannot be used, due to information loss
        # Examples:
        #
        # 42 / 10 => 4
        # 42 % 10 => 2
        #
        # -42 / 10 => -4
        # -42 % 10 => -2
        #
        # 42 / -10 => -4
        # 42 % -10 => 2
        #
        # -42 / -10 => 4
        # -42 % -10 => -2

        den = int(y)
        num = int(self)
        result_sign = 1 if (den * num) >= 0 else -1
        cls = self.__class__
        if isinstance(y, moduint):
            cls = self.maxcast(y)
        return (abs(num) // abs(den)) * result_sign

    def __floordiv__(self, y):
        return self.__div__(y)

    def __int__(self):
        return int(self.arg)

    def __long__(self):
        return int(self.arg)

    def __index__(self):
        return int(self.arg)

    def __invert__(self):
        return self.__class__(~self.arg)

    def __lshift__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg << y.arg)
        else:
            return self.__class__(self.arg << y)

    def __mod__(self, y):
        # See __div__ for implementation choice
        cls = self.__class__
        if isinstance(y, moduint):
            cls = self.maxcast(y)
        return cls(self.arg - y * (self // y))

    def __mul__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg * y.arg)
        else:
            return self.__class__(self.arg * y)

    def __neg__(self):
        return self.__class__(-self.arg)

    def __or__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg | y.arg)
        else:
            return self.__class__(self.arg | y)

    def __radd__(self, y):
        return self.__add__(y)

    def __rand__(self, y):
        return self.__and__(y)

    def __rdiv__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(y.arg // self.arg)
        else:
            return self.__class__(y // self.arg)

    def __rfloordiv__(self, y):
        return self.__rdiv__(y)

    def __rlshift__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(y.arg << self.arg)
        else:
            return self.__class__(y << self.arg)

    def __rmod__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(y.arg % self.arg)
        else:
            return self.__class__(y % self.arg)

    def __rmul__(self, y):
        return self.__mul__(y)

    def __ror__(self, y):
        return self.__or__(y)

    def __rrshift__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(y.arg >> self.arg)
        else:
            return self.__class__(y >> self.arg)

    def __rshift__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg >> y.arg)
        else:
            return self.__class__(self.arg >> y)

    def __rsub__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(y.arg - self.arg)
        else:
            return self.__class__(y - self.arg)

    def __rxor__(self, y):
        return self.__xor__(y)

    def __sub__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg - y.arg)
        else:
            return self.__class__(self.arg - y)

    def __xor__(self, y):
        if isinstance(y, moduint):
            cls = self.maxcast(y)
            return cls(self.arg ^ y.arg)
        else:
            return self.__class__(self.arg ^ y)

    def __hex__(self):
        return hex(self.arg)

    def __abs__(self):
        return abs(self.arg)

    def __rpow__(self, v):
        return v ** self.arg

    def __pow__(self, v):
        return self.__class__(self.arg ** v)


class modint(moduint):

    def __init__(self, arg):
        if isinstance(arg, moduint):
            arg = arg.arg
        a = arg % self.__class__.limit
        if a >= self.__class__.limit // 2:
            a -= self.__class__.limit
        self.arg = a
        assert(
            self.arg >= -self.__class__.limit // 2 and
            self.arg < self.__class__.limit
        )


def is_modint(a):
    return isinstance(a, moduint)


mod_size2uint = {}
mod_size2int = {}

mod_uint2size = {}
mod_int2size = {}

def define_int(size):
    """Build the 'modint' instance corresponding to size @size"""
    global mod_size2int, mod_int2size

    name = 'int%d' % size
    cls = type(name, (modint,), {"size": size, "limit": 1 << size})
    globals()[name] = cls
    mod_size2int[size] = cls
    mod_int2size[cls] = size
    return cls

def define_uint(size):
    """Build the 'moduint' instance corresponding to size @size"""
    global mod_size2uint, mod_uint2size

    name = 'uint%d' % size
    cls = type(name, (moduint,), {"size": size, "limit": 1 << size})
    globals()[name] = cls
    mod_size2uint[size] = cls
    mod_uint2size[cls] = size
    return cls

def define_common_int():
    "Define common int"
    common_int = range(1, 257)

    for i in common_int:
        define_int(i)

    for i in common_int:
        define_uint(i)

define_common_int()