Maroc-OS/decompiler

View on GitHub
src/callconv.py

Summary

Maintainability
B
4 hrs
Test Coverage
""" Abstracts the logic behind figuring out the arguments to a function call.

http://en.wikipedia.org/wiki/X86_calling_conventions
"""

from expressions import *
import ir.intel
import ssa

# Registry of every calling-convention class, keyed by the class's __name__.
__conventions__ = dict()

def add_calling_convention(cls):
  """ Class decorator: record `cls` in __conventions__ under its class name.

  A later class with the same name replaces the earlier entry. The class
  itself is returned unchanged so the decorator is transparent.
  """
  name = cls.__name__
  __conventions__[name] = cls
  return cls

class call_iterator_t(ssa.ssa_contextual_iterator_t):
  """ Walk a function with the SSA contextual iterator and collect call sites.

  Iterating an instance yields (definitions, stmt) pairs. `stmt` is a
  statement whose expression is an assign_t with a call_t on the right-hand
  side. `definitions` is the flattened list of locations defined in the
  statement's context chain, with the outermost context first.
  """

  def __init__(self, function):
    ssa.ssa_contextual_iterator_t.__init__(self, function, self.is_correct_step)
    # (definitions, statement) pairs appended by statement() during traversal.
    self.contexts = []

  def is_correct_step(self, loc):
    """ Step filter handed to the base iterator: accept assignable_t only. """
    return isinstance(loc, assignable_t)

  def is_call(self, expr):
    """ Return True when `expr` is an assignment whose value is a call_t. """
    if not isinstance(expr, assign_t):
      return False
    return isinstance(expr.op2, call_t)

  def copy_recursive_context(self, context):
    """ Flatten `context` and all of its parents into one new list.

    Definitions of the outermost ancestor come first. Within each context,
    the order of its `defined` sequence is preserved.
    """
    chain = []
    node = context
    while node:
      chain.append(node)
      node = node.parent

    flattened = []
    for node in reversed(chain):
      flattened.extend(node.defined)
    return flattened

  def statement(self, context, stmt):
    """ Record call statements, then defer to the base iterator. """
    if self.is_call(stmt.expr):
      # Snapshot the visible definitions before the base class handles stmt.
      self.contexts.append((self.copy_recursive_context(context), stmt))
    ssa.ssa_contextual_iterator_t.statement(self, context, stmt)

  def __iter__(self):
    """ Traverse from the entry block, then yield every collected pair.

    NOTE(review): self.contexts is never cleared, so iterating the same
    instance twice traverses again and appends to the earlier results.
    Whether that yields duplicates depends on the base class's traverse().
    """
    self.traverse(ssa.ssa_context_t(self.function.entry_block))
    for pair in self.contexts:
      yield pair

class convention_t(object):
  """ Base class for calling conventions.

  Stores the function under analysis. Subclasses such as live_locations use
  self.function when walking the function's call sites.
  """

  def __init__(self, function):
    self.function = function

@add_calling_convention
class live_locations(convention_t):
  """ Convention that treats live locations at a call site as its arguments.

  Two sources are used:
  - consecutive live stack slots starting at the call's recorded stack
    expression;
  - every plain register definition visible in the call's context.
  """

  # Byte distance between consecutive stack argument slots. 4 is the value
  # this class always used (presumably 32-bit x86 slots); subclasses may
  # override it for other word sizes.
  stack_slot_size = 4

  def process_live_stack_locations(self, context, call):
    """ Find all live stack locations at the top of the stack in this context.

    `context` is the flattened definition list produced by call_iterator_t.
    `call` is the call_t whose `stack` expression marks the top of the stack.

    Returns the definitions matching successive slots, in the order visited.
    Stops at the first slot with no matching definition. Returns [] when the
    stack expression is not a sub_t, because it cannot be walked.
    """
    # top of stack
    tos = call.stack.copy()
    if not isinstance(tos, sub_t):
      # weird stack?
      return []

    args = []
    while True:
      # The slot expression only changes when tos changes, so build it once
      # per slot instead of once per definition.
      slot = deref_t(tos.copy())
      found = None
      for _def in context:
        # Keep the last match in context order (presumably the most recent
        # definition of that slot).
        if _def.no_index_eq(slot):
          found = _def
      # Compare against None explicitly: definitions are expression objects
      # and must not be judged by truthiness.
      if found is None:
        break
      args.append(found)
      # Step to the next slot by shrinking the sub_t's constant operand.
      tos.op2.value -= self.stack_slot_size

    return args

  def process_live_registers(self, context, stmt):
    """ Find all register definitions visible in this context.

    Definitions whose parent statement is `stmt` itself are skipped, for
    example the location the call's result is assigned to. Only objects whose
    exact type is regloc_t are kept; subclasses are excluded by the `type(...)
    is` check.
    """
    return [defined for defined in context
            if defined.parent_statement is not stmt
            and type(defined) is regloc_t]

  def process(self):
    """ Attach live stack slots and live registers to every call as params.

    Each argument is appended to the call's `params` as a copy. The copy's
    `definition` points back at the live definition it was made from.
    """
    for ctx, stmt in call_iterator_t(self.function):
      call = stmt.expr.op2

      args = self.process_live_stack_locations(ctx, call)
      args += self.process_live_registers(ctx, stmt)

      for arg in args:
        param = arg.copy(with_definition=True)
        param.definition = arg
        call.params.append(param)
    return

@add_calling_convention
class systemv_x64_abi_t(convention_t):
  """ SystemV AMD64 ABI

  The following registers are used to pass arguments:
      RDI, RSI, RDX, RCX, R8, R9, XMM0-7

  Only the six integer registers are examined below; XMM0-7 are not handled.

  NOTE(review): process() takes (function, ssa_tagger, block, stmt, call),
  unlike the argument-less process() of the other conventions in this
  module. It also calls self.insert_phi, which neither this class nor
  convention_t defines. It appears to target an older SSA tagger API;
  confirm before relying on it. It also stores a comma_t chain (or None) in
  call.params, where live_locations appends to call.params as a list.
  """

  def process(self, function, ssa_tagger, block, stmt, call):
    """ Collect integer argument registers for `stmt` into `call.params`.

    Registers are probed in ABI order. A register counts as an argument when
    ssa_tagger reports an internal definition for it, or when it has a
    contextual definition (a phi is then inserted). Collection stops at the
    first register with neither. `block` is not used.
    """
    order = (ir.intel.RDI, ir.intel.RSI, ir.intel.RDX,
             ir.intel.RCX, ir.intel.R8, ir.intel.R9)
    regs = []
    for regnum in order:
      loc = regloc_t(regnum, function.arch.address_size)
      found = ssa_tagger.has_internal_definition(stmt, loc)
      if not found:
        if not ssa_tagger.has_contextual_definition(stmt, loc):
          break
        found = self.insert_phi(stmt, loc)
      regs.append(found.copy())

    call.params = self.make_call_arguments(regs)

  def make_call_arguments(self, regs):
    """ Fold `regs` into a right-nested comma_t chain.

    [a, b, c] becomes comma_t(a, comma_t(b, c)). A single register is
    returned as-is, and an empty list yields None. `regs` is not modified.
    """
    if len(regs) == 0:
      return None

    arglist = regs[-1]
    for reg in reversed(regs[:-1]):
      arglist = comma_t(reg, arglist)
    return arglist

@add_calling_convention
class cdecl(live_locations):
  """ Convention that takes call arguments from live stack slots only.

  Unlike live_locations.process, register definitions are not considered.
  """

  def process(self):
    """ Append each live stack slot at every call site to the call's params.

    Every appended param is a copy whose `definition` refers back to the live
    stack definition it was made from.
    """
    for ctx, stmt in call_iterator_t(self.function):
      call = stmt.expr.op2
      for live in self.process_live_stack_locations(ctx, call):
        param = live.copy(with_definition=True)
        param.definition = live
        call.params.append(param)