Maroc-OS/decompiler

View on GitHub
src/filters/simplify_expressions.py

Summary

Maintainability
F
1 wk
Test Coverage
""" This module runs an expression through a series of filters.

When a filter matches, a new expression is created from the old one
and returned to the caller, which should call again until all filters
are exhausted and no simpler expression can be generated.
"""

from expressions import *

__all__ = []

def simplifier(func):
  """ decorator which registers `func` as a simplification filter.

  `func` is appended to the module-level `__all__` list, in definition
  order. The function is then returned unchanged, so the decorated name
  stays bound to the filter instead of being rebound to None.
  """
  __all__.append(func)
  return func

@simplifier
def flags(expr):
  """ transform flags operations into simpler comparisons against zero.

  `x` below is the operand the flags were computed on (in practice
  something like `a - b`). The result always compares `x` to a zero
  literal of the same size.

  unsigned stuff:
  CARRY(x) becomes x < 0
  !CARRY(x) becomes x >= 0

  signed stuff:
  SIGN(x) != OVERFLOW(x) becomes x < 0
  SIGN(x) == OVERFLOW(x) becomes x >= 0

  and for both:
  !(x) || x < 0 becomes x <= 0
  (x) && x >= 0 becomes x > 0
  (x) && x > 0 becomes x > 0

  Returns the new expression, or None when `expr` matches no pattern.
  """

  def sign_vs_overflow(e, cls):
    # SIGN(x) <cls> OVERFLOW(x), with both flags taken from the same x
    return type(e) == cls and \
        type(e.op1) == sign_t and type(e.op2) == overflow_t and \
        e.op1.op == e.op2.op #and type(e.op1.op) == sub_t

  def is_zero(e):
    return type(e) == value_t and e.value == 0

  # signed less-than
  if sign_vs_overflow(expr, neq_t):
    cls, op = lower_t, expr.op1.op

  # signed greater-or-equal (the exact negation of less-than)
  elif sign_vs_overflow(expr, eq_t):
    cls, op = aeq_t, expr.op1.op

  # unsigned lower-than
  elif type(expr) == carry_t:
    cls, op = lower_t, expr.op

  # unsigned above-or-equal (the exact negation of lower-than)
  elif type(expr) == b_not_t and type(expr.op) == carry_t:
    cls, op = aeq_t, expr.op.op

  # less-or-equal: x is zero, or x < 0
  elif type(expr) == b_or_t and type(expr.op1) == b_not_t and \
        type(expr.op2) == lower_t and is_zero(expr.op2.op2) and \
        expr.op1.op == expr.op2.op1:
    cls, op = leq_t, expr.op2.op1

  # strictly above: x is non-zero, and x >= 0 (or already x > 0)
  elif type(expr) == b_and_t and type(expr.op2) in (above_t, aeq_t) and \
        is_zero(expr.op2.op2) and expr.op1 == expr.op2.op1:
    cls, op = above_t, expr.op1

  else:
    return

  # read the size before detaching the operand from the old expression
  size = op.size
  return cls(op.pluck(), value_t(0, size))

@simplifier
def add_sub(expr):
  """ Fold number literals in additions and subtractions.

  (a +/- n1) +/- n2 => (a +/- n3) with n3 = n1 +/- n2
  (a +/- 0) => a
  (n1 +/- n2) => n3

  Returns the simplified expression, or None when nothing applies.
  """

  kind = type(expr)
  if kind not in (add_t, sub_t):
    return

  # every rule below needs a literal on the right-hand side
  if type(expr.op2) != value_t:
    return

  # nested operation whose own right-hand side is a literal:
  # let the inner expression absorb the outer literal.
  if type(expr.op1) in (add_t, sub_t) and type(expr.op1.op2) == value_t:
    inner = expr.op1.pluck()
    if kind == add_t:
      inner.add(expr.op2)
    else:
      inner.sub(expr.op2)
    return inner

  # adding or subtracting zero is a no-op
  if expr.op2.value == 0:
    return expr.op1.pluck()

  # both operands are literals: compute the result right away
  if type(expr.op1) == value_t:
    if kind == add_t:
      folded = expr.op1.value + expr.op2.value
    else:
      folded = expr.op1.value - expr.op2.value
    return value_t(folded, expr.op1.size)

  return

@simplifier
def ref_deref(expr):
  """ drop an address-of and a dereference which cancel each other

  &(*(addr)) => addr
  *(&(addr)) => addr

  Returns the innermost operand, or None when `expr` is not such a pair.
  """

  outer = type(expr)
  if outer == address_t:
    inverse = deref_t
  elif outer == deref_t:
    inverse = address_t
  else:
    return

  if type(expr.op) == inverse:
    return expr.op.op.pluck()

  return

@simplifier
def equality_with_literals(expr):
  """ Move a literal from the left side of a comparison over to the right

  (<1> - n1) <cmp> n2 becomes <1> <cmp> n3 where n3 = n2 + n1
  (<1> + n1) <cmp> n2 becomes <1> <cmp> n3 where n3 = n2 - n1

  Returns a new comparison of the same class, or None if nothing applies.
  """

  comparisons = (eq_t, neq_t, above_t, lower_t, aeq_t, leq_t)
  if type(expr) not in comparisons or type(expr.op2) != value_t:
    return

  lhs = expr.op1
  if type(lhs) not in (sub_t, add_t) or type(lhs.op2) != value_t:
    return

  outer, inner = expr.op2, lhs.op2
  if type(lhs) == sub_t:
    folded = outer.value + inner.value
  else:
    folded = outer.value - inner.value

  # the new literal takes the larger of the two literal sizes
  literal = value_t(folded, max(outer.size, inner.size))
  return expr.__class__(lhs.op1.pluck(), literal)

@simplifier
def negate(expr):
  """ transform negations into simpler, more readable forms

  !(a && b) becomes !a || !b
  !(a || b) becomes !a && !b
  !(a == b) becomes a != b
  !(a != b) becomes a == b
  !(!(expr)) becomes expr
  a == 0 becomes !a

  !(a < b) becomes a >= b
  !(a > b) becomes a <= b
  !(a >= b) becomes a < b
  !(a <= b) becomes a > b

  !(a - b) becomes a == b
  !(a + b) becomes a == -b

  a - b < 0 becomes a < b
  a - b > 0 becomes a > b

  Returns the rewritten expression, or None when no rule applies.
  """

  if type(expr) == b_not_t:
    inner = type(expr.op)

    # !(a && b) becomes !a || !b
    if inner == b_and_t:
      return b_or_t(b_not_t(expr.op.op1.pluck()), b_not_t(expr.op.op2.pluck()))

    # !(a || b) becomes !a && !b
    if inner == b_or_t:
      return b_and_t(b_not_t(expr.op.op1.pluck()), b_not_t(expr.op.op2.pluck()))

    # !(a == b) becomes a != b
    if inner == eq_t:
      return neq_t(expr.op.op1.pluck(), expr.op.op2.pluck())

    # !(a != b) becomes a == b
    if inner == neq_t:
      return eq_t(expr.op.op1.pluck(), expr.op.op2.pluck())

    # !(!(expr)) becomes expr
    if inner == b_not_t:
      return expr.op.op.pluck()

    # !(a < b) becomes a >= b
    if inner == lower_t:
      return aeq_t(expr.op.op1.pluck(), expr.op.op2.pluck())

    # !(a > b) becomes a <= b
    if inner == above_t:
      return leq_t(expr.op.op1.pluck(), expr.op.op2.pluck())

    # !(a >= b) becomes a < b
    if inner == aeq_t:
      return lower_t(expr.op.op1.pluck(), expr.op.op2.pluck())

    # !(a <= b) becomes a > b
    if inner == leq_t:
      return above_t(expr.op.op1.pluck(), expr.op.op2.pluck())

    # !(a - b) becomes a == b
    if inner == sub_t:
      return eq_t(expr.op.op1.pluck(), expr.op.op2.pluck())

    # !(a + b) becomes a == -b
    if inner == add_t:
      return eq_t(expr.op.op1.pluck(), neg_t(expr.op.op2.pluck()))

    return

  # a == 0 becomes !a
  if type(expr) == eq_t and type(expr.op2) == value_t and expr.op2.value == 0:
    return b_not_t(expr.op1.pluck())

  #  a - b < 0 becomes a < b
  #  a - b > 0 becomes a > b
  # Only valid when comparing against a literal zero; the direction of the
  # comparison is kept by reusing the class of the original expression.
  if type(expr) in (lower_t, above_t) and type(expr.op1) == sub_t and \
        type(expr.op2) == value_t and expr.op2.value == 0:
    return expr.__class__(expr.op1.op1.pluck(), expr.op1.op2.pluck())

  return

@simplifier
def equalities(expr):
  """ merge an (in)equality test with an ordering test on the same operands

  a == b || a > b becomes a >= b     (either operand order of the ||)
  a == b || a < b becomes a <= b
  a == b || a >= b becomes a >= b
  a == b || a <= b becomes a <= b

  a != b && a >= b becomes a > b     (either operand order of the &&)
  a != b && a <= b becomes a < b
  a != b && a > b becomes a > b
  a != b && a < b becomes a < b

  Returns the merged comparison, or None when nothing applies.
  """

  def same_operands(first, second):
    return first.op1 == second.op1 and first.op2 == second.op2

  kind = type(expr)
  if kind not in (b_or_t, b_and_t):
    return

  lhs, rhs = expr.op1, expr.op2
  ltype, rtype = type(lhs), type(rhs)

  if kind == b_or_t:
    widen = {lower_t: leq_t, above_t: aeq_t}

    # equality on the left, strict ordering on the right
    if ltype == eq_t and rtype in (lower_t, above_t) and same_operands(lhs, rhs):
      return widen[rtype](lhs.op1.pluck(), lhs.op2.pluck())

    # strict ordering on the left, equality on the right
    if ltype in (lower_t, above_t) and rtype == eq_t and same_operands(lhs, rhs):
      return widen[ltype](lhs.op1.pluck(), lhs.op2.pluck())

    # the ordering already includes equality: keep the ordering side only
    if ltype == eq_t and rtype in (leq_t, aeq_t) and same_operands(lhs, rhs):
      return rhs.pluck()

    if ltype in (leq_t, aeq_t) and rtype == eq_t and same_operands(lhs, rhs):
      return lhs.pluck()

    return

  # kind == b_and_t from here on
  narrow = {leq_t: lower_t, aeq_t: above_t, above_t: above_t, lower_t: lower_t}
  orderings = (leq_t, aeq_t, above_t, lower_t)

  # inequality on the left, ordering on the right
  if ltype == neq_t and rtype in orderings and same_operands(lhs, rhs):
    return narrow[rtype](lhs.op1.pluck(), lhs.op2.pluck())

  # ordering on the left, inequality on the right
  if ltype in orderings and rtype == neq_t and same_operands(lhs, rhs):
    return narrow[ltype](lhs.op1.pluck(), lhs.op2.pluck())

  return

@simplifier
def correct_signs(expr):
  """ swap an addition/subtraction for its inverse when the right-hand
  literal is negative, so that the literal can be shown as positive

  x + -y becomes x - y
  x - -y becomes x + y

  Returns the new expression, or None when nothing applies.
  """

  kind = type(expr)
  if kind == add_t:
    inverse = sub_t
  elif kind == sub_t:
    inverse = add_t
  else:
    return

  literal = expr.op2
  if type(literal) == value_t and literal.value < 0:
    return inverse(expr.op1.pluck(), value_t(abs(literal.value), literal.size))

  return

@simplifier
def special_xor(expr):
  """ replace a xor of two identical operands by a literal zero

  x ^ x becomes 0

  The zero literal gets the size of the first operand. Returns None
  when `expr` is not such a xor.
  """

  if type(expr) != xor_t:
    return

  if expr.op1 == expr.op2:
    return value_t(0, expr.op1.size)

  return

@simplifier
def special_and(expr):
  """ replace an and (&) of two identical operands by the operand itself

  x & x becomes x

  Returns the first operand, or None when `expr` is not such an and.
  """

  if type(expr) != and_t:
    return

  if expr.op1 == expr.op2:
    return expr.op1.pluck()

  return

def once(expr, deep=False):
  """ try each registered filter on `expr` and stop at the first one that
      produces something.

  When a filter matches and `expr` has a parent, the assignable operands
  of `expr` are unlinked and `expr` is replaced by the new expression,
  which is then returned.

  When no filter matches and `deep` is set, the operands of `expr` are
  tried recursively; `expr` itself is returned as soon as one of them was
  simplified. Otherwise None is returned.
  """

  for simplify in __all__:
    replacement = simplify(expr)
    if not replacement:
      continue

    if expr.parent:
      for operand in expr.iteroperands():
        if isinstance(operand, assignable_t):
          operand.unlink()
      expr.replace(replacement)
    return replacement

  if not (deep and isinstance(expr, expr_t)):
    return

  for operand in expr.operands:
    if once(operand, deep):
      # something changed below us, report the (same) parent expression
      return expr

  return

def run(expr, deep=False):
  """ keep applying once() until it has nothing left to simplify.

  Each result of once() becomes the expression for the next round.
  Returns the last expression obtained (the original one if no filter
  ever matched).
  """

  simplified = once(expr, deep=deep)
  while simplified:
    expr = simplified
    simplified = once(expr, deep=deep)

  return expr