miasm/analysis/data_flow.py
"""Data flow analysis based on miasm intermediate representation"""
from builtins import range
from collections import namedtuple, Counter, deque
from pprint import pprint as pp
from future.utils import viewitems, viewvalues
from miasm.core.utils import encode_hex
from miasm.core.graph import DiGraph
from miasm.ir.ir import AssignBlock, IRBlock
from miasm.expression.expression import ExprLoc, ExprMem, ExprId, ExprInt,\
ExprAssign, ExprOp, ExprWalk, ExprSlice, \
is_function_call, ExprVisitorCallbackBottomToTop
from miasm.expression.simplifications import expr_simp, expr_simp_explicit
from miasm.core.interval import interval
from miasm.expression.expression_helper import possible_values
from miasm.analysis.ssa import get_phi_sources_parent_block, \
irblock_has_phi
from miasm.ir.symbexec import get_expr_base_offset
class ReachingDefinitions(dict):
"""
Computes for each assignblock the set of reaching definitions.
Example:
IR block:
lbl0:
0 A = 1
B = 3
1 B = 2
2 A = A + B + 4
Reaching definitions of lbl0:
(lbl0, 0) => {}
(lbl0, 1) => {A: {(lbl0, 0)}, B: {(lbl0, 0)}}
(lbl0, 2) => {A: {(lbl0, 0)}, B: {(lbl0, 1)}}
(lbl0, 3) => {A: {(lbl0, 2)}, B: {(lbl0, 1)}}
Source set 'REACHES' in: Kennedy, K. (1979).
A survey of data flow analysis techniques.
IBM Thomas J. Watson Research Division, Algorithm MK
This class is usable as a dictionary whose structure is
{ (block, index): { lvalue: set((block, index)) } }
"""
ircfg = None
def __init__(self, ircfg):
super(ReachingDefinitions, self).__init__()
self.ircfg = ircfg
self.compute()
def get_definitions(self, block_lbl, assignblk_index):
"""Returns the dict { lvalue: set((def_block_lbl, def_index)) }
associated with self.ircfg.@block.assignblks[@assignblk_index]
or {} if it is not yet computed
"""
return self.get((block_lbl, assignblk_index), {})
def compute(self):
"""This is the main fixpoint"""
modified = True
while modified:
modified = False
for block in viewvalues(self.ircfg.blocks):
modified |= self.process_block(block)
def process_block(self, block):
"""
Fetch the reaching definitions from the predecessors of @block and
propagate them through the assignblocks of @block.
"""
predecessor_state = {}
for pred_lbl in self.ircfg.predecessors(block.loc_key):
if pred_lbl not in self.ircfg.blocks:
continue
pred = self.ircfg.blocks[pred_lbl]
for lval, definitions in viewitems(self.get_definitions(pred_lbl, len(pred))):
predecessor_state.setdefault(lval, set()).update(definitions)
modified = self.get((block.loc_key, 0)) != predecessor_state
if not modified:
return False
self[(block.loc_key, 0)] = predecessor_state
for index in range(len(block)):
modified |= self.process_assignblock(block, index)
return modified
def process_assignblock(self, block, assignblk_index):
"""
Update the reaching definitions with the values defined at
assignblock @assignblk_index in block @block.
NB: the effect of assignblock @assignblk_index is stored at index
(@block, @assignblk_index + 1).
"""
assignblk = block[assignblk_index]
defs = self.get_definitions(block.loc_key, assignblk_index).copy()
for lval in assignblk:
defs.update({lval: set([(block.loc_key, assignblk_index)])})
modified = self.get((block.loc_key, assignblk_index + 1)) != defs
if modified:
self[(block.loc_key, assignblk_index + 1)] = defs
return modified
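# Usage sketch (illustrative only, not part of the module): assuming `ircfg`
# is an IRCFG instance already built by a lifter, the reaching definitions
# before a given assignblock can be queried as follows:
#
#     reaching_defs = ReachingDefinitions(ircfg)
#     defs = reaching_defs.get_definitions(loc_key, assignblk_index)
#     for lvalue, def_sites in defs.items():
#         # def_sites is a set of (def_block_loc_key, def_assignblk_index)
#         print(lvalue, def_sites)
#
# `loc_key` and `assignblk_index` stand for an existing block location and
# assignment index of `ircfg`.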
ATTR_DEP = {"color" : "black",
"_type" : "data"}
AssignblkNode = namedtuple('AssignblkNode', ['label', 'index', 'var'])
class DiGraphDefUse(DiGraph):
"""Representation of a Use-Definition graph as defined by
Kennedy, K. (1979). A survey of data flow analysis techniques.
IBM Thomas J. Watson Research Division.
Example:
IR block:
lbl0:
0 A = 1
B = 3
1 B = 2
2 A = A + B + 4
Def use analysis:
(lbl0, 0, A) => {(lbl0, 2, A)}
(lbl0, 0, B) => {}
(lbl0, 1, B) => {(lbl0, 2, A)}
(lbl0, 2, A) => {}
"""
def __init__(self, reaching_defs,
deref_mem=False, apply_simp=False, *args, **kwargs):
"""Instantiate a DiGraph
@blocks: IR blocks
"""
self._edge_attr = {}
# For dot display
self._filter_node = None
self._dot_offset = None
self._blocks = reaching_defs.ircfg.blocks
super(DiGraphDefUse, self).__init__(*args, **kwargs)
self._compute_def_use(reaching_defs,
deref_mem=deref_mem,
apply_simp=apply_simp)
def edge_attr(self, src, dst):
"""
Return a dictionary of attributes for the edge between @src and @dst
@src: the source node of the edge
@dst: the destination node of the edge
"""
return self._edge_attr[(src, dst)]
def _compute_def_use(self, reaching_defs,
deref_mem=False, apply_simp=False):
for block in viewvalues(self._blocks):
self._compute_def_use_block(block,
reaching_defs,
deref_mem=deref_mem,
apply_simp=apply_simp)
def _compute_def_use_block(self, block, reaching_defs, deref_mem=False, apply_simp=False):
for index, assignblk in enumerate(block):
assignblk_reaching_defs = reaching_defs.get_definitions(block.loc_key, index)
for lval, expr in viewitems(assignblk):
self.add_node(AssignblkNode(block.loc_key, index, lval))
expr = expr_simp_explicit(expr) if apply_simp else expr
read_vars = expr.get_r(mem_read=deref_mem)
if deref_mem and lval.is_mem():
read_vars.update(lval.ptr.get_r(mem_read=deref_mem))
for read_var in read_vars:
for reach in assignblk_reaching_defs.get(read_var, set()):
self.add_data_edge(AssignblkNode(reach[0], reach[1], read_var),
AssignblkNode(block.loc_key, index, lval))
def del_edge(self, src, dst):
super(DiGraphDefUse, self).del_edge(src, dst)
del self._edge_attr[(src, dst)]
def add_uniq_labeled_edge(self, src, dst, edge_label):
"""Adds the edge (@src, @dst) with label @edge_label.
if edge (@src, @dst) already exists, the previous label is overridden
"""
self.add_uniq_edge(src, dst)
self._edge_attr[(src, dst)] = edge_label
def add_data_edge(self, src, dst):
"""Adds an edge representing a data dependency
and sets the label accordingly"""
self.add_uniq_labeled_edge(src, dst, ATTR_DEP)
def node2lines(self, node):
lbl, index, reg = node
yield self.DotCellDescription(text="%s (%s)" % (lbl, index),
attr={'align': 'center',
'colspan': 2,
'bgcolor': 'grey'})
src = self._blocks[lbl][index][reg]
line = "%s = %s" % (reg, src)
yield self.DotCellDescription(text=line, attr={})
yield self.DotCellDescription(text="", attr={})
class DeadRemoval(object):
"""
Remove dead (useless) assignments from an IRCFG
"""
def __init__(self, lifter, expr_to_original_expr=None):
self.lifter = lifter
if expr_to_original_expr is None:
expr_to_original_expr = {}
self.expr_to_original_expr = expr_to_original_expr
def add_expr_to_original_expr(self, expr_to_original_expr):
self.expr_to_original_expr.update(expr_to_original_expr)
def is_unkillable_destination(self, lval, rval):
if (
lval.is_mem() or
self.lifter.IRDst == lval or
lval.is_id("exception_flags") or
is_function_call(rval)
):
return True
return False
def get_block_useful_destinations(self, block):
"""
Return the set of destinations of @block which must always be kept
(memory writes, IRDst, exception flags, function calls)
@block: IRBlock instance
"""
useful = set()
for index, assignblk in enumerate(block):
for lval, rval in viewitems(assignblk):
if self.is_unkillable_destination(lval, rval):
useful.add(AssignblkNode(block.loc_key, index, lval))
return useful
def is_tracked_var(self, lval, variable):
new_lval = self.expr_to_original_expr.get(lval, lval)
return new_lval == variable
def find_definitions_from_worklist(self, worklist, ircfg):
"""
Find the definitions of the variables in @worklist by walking @ircfg backwards
"""
locs_done = set()
defs = set()
while worklist:
found = False
elt = worklist.pop()
if elt in locs_done:
continue
locs_done.add(elt)
variable, loc_key = elt
block = ircfg.get_block(loc_key)
if block is None:
# Consider no sources in incomplete graph
continue
for index, assignblk in reversed(list(enumerate(block))):
for dst, src in viewitems(assignblk):
if self.is_tracked_var(dst, variable):
defs.add(AssignblkNode(loc_key, index, dst))
found = True
break
if found:
break
if not found:
for predecessor in ircfg.predecessors(loc_key):
worklist.add((variable, predecessor))
return defs
def find_out_regs_definitions_from_block(self, block, ircfg):
"""
Find definitions of out regs starting from @block
"""
worklist = set()
for reg in self.lifter.get_out_regs(block):
worklist.add((reg, block.loc_key))
ret = self.find_definitions_from_worklist(worklist, ircfg)
return ret
def add_def_for_incomplete_leaf(self, block, ircfg, reaching_defs):
"""
Collect the definitions reaching the end of @block, plus the out regs definitions
"""
valid_definitions = reaching_defs.get_definitions(
block.loc_key,
len(block)
)
worklist = set()
for lval, definitions in viewitems(valid_definitions):
for definition in definitions:
new_lval = self.expr_to_original_expr.get(lval, lval)
worklist.add((new_lval, block.loc_key))
ret = self.find_definitions_from_worklist(worklist, ircfg)
useful = ret
useful.update(self.find_out_regs_definitions_from_block(block, ircfg))
return useful
def get_useful_assignments(self, ircfg, defuse, reaching_defs):
"""
Mark useful statements using the reaching definitions analysis and the
def-use graph.
Yield AssignblkNode triplets (block loc_key, assignblk index, lvalue)
corresponding to useful definitions.
"""
useful = set()
for block_lbl, block in viewitems(ircfg.blocks):
block = ircfg.get_block(block_lbl)
if block is None:
# skip unknown blocks: won't generate dependencies
continue
block_useful = self.get_block_useful_destinations(block)
useful.update(block_useful)
successors = ircfg.successors(block_lbl)
for successor in successors:
if successor not in ircfg.blocks:
keep_all_definitions = True
break
else:
keep_all_definitions = False
if keep_all_definitions:
useful.update(self.add_def_for_incomplete_leaf(block, ircfg, reaching_defs))
continue
if len(successors) == 0:
useful.update(self.find_out_regs_definitions_from_block(block, ircfg))
else:
continue
# Useful nodes dependencies
for node in useful:
for parent in defuse.reachable_parents(node):
yield parent
def do_dead_removal(self, ircfg):
"""
Remove useless assignments.
This function is meant to analyse a *complete function*: the blocks under
study must represent the full function graph.
Source: Kennedy, K. (1979). A survey of data flow analysis techniques.
IBM Thomas J. Watson Research Division, page 43
@ircfg: IRCFG instance
"""
modified = False
reaching_defs = ReachingDefinitions(ircfg)
defuse = DiGraphDefUse(reaching_defs, deref_mem=True)
useful = self.get_useful_assignments(ircfg, defuse, reaching_defs)
useful = set(useful)
for block in list(viewvalues(ircfg.blocks)):
irs = []
for idx, assignblk in enumerate(block):
new_assignblk = dict(assignblk)
for lval in assignblk:
if AssignblkNode(block.loc_key, idx, lval) not in useful:
del new_assignblk[lval]
modified = True
irs.append(AssignBlock(new_assignblk, assignblk.instr))
ircfg.blocks[block.loc_key] = IRBlock(block.loc_db, block.loc_key, irs)
return modified
def __call__(self, ircfg):
ret = self.do_dead_removal(ircfg)
return ret
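# Usage sketch (illustrative only): `lifter` is assumed to be a
# lifter_model_call instance (providing IRDst and get_out_regs) and `ircfg`
# the IRCFG of a *complete* function.
#
#     dead_removal = DeadRemoval(lifter)
#     while dead_removal(ircfg):
#         # iterate until a fixpoint: removing an assignment can make other
#         # assignments dead in turn
#         pass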
def _test_merge_next_block(ircfg, loc_key):
"""
Test if the irblock at @loc_key can be merged with its son
@ircfg: IRCFG instance
@loc_key: LocKey instance of the candidate parent irblock
"""
if loc_key not in ircfg.blocks:
return None
sons = ircfg.successors(loc_key)
if len(sons) != 1:
return None
son = list(sons)[0]
if ircfg.predecessors(son) != [loc_key]:
return None
if son not in ircfg.blocks:
return None
return son
def _do_merge_blocks(ircfg, loc_key, son_loc_key):
"""
Merge two irblocks at @loc_key and @son_loc_key
@ircfg: IRCFG instance
@loc_key: LocKey instance of the parent IRBlock
@son_loc_key: LocKey instance of the son IRBlock
"""
assignblks = []
for assignblk in ircfg.blocks[loc_key]:
if ircfg.IRDst not in assignblk:
assignblks.append(assignblk)
continue
affs = {}
for dst, src in viewitems(assignblk):
if dst != ircfg.IRDst:
affs[dst] = src
if affs:
assignblks.append(AssignBlock(affs, assignblk.instr))
assignblks += ircfg.blocks[son_loc_key].assignblks
new_block = IRBlock(ircfg.loc_db, loc_key, assignblks)
ircfg.discard_edge(loc_key, son_loc_key)
for son_successor in ircfg.successors(son_loc_key):
ircfg.add_uniq_edge(loc_key, son_successor)
ircfg.discard_edge(son_loc_key, son_successor)
del ircfg.blocks[son_loc_key]
ircfg.del_node(son_loc_key)
ircfg.blocks[loc_key] = new_block
def _test_jmp_only(ircfg, loc_key, heads):
"""
If the irblock at @loc_key only sets IRDst to an ExprLoc, return the
corresponding target loc_key; return None otherwise.
Avoid creating predecessors for head LocKeys.
@ircfg: IRCFG instance
@loc_key: LocKey instance of the candidate irblock
@heads: LocKey heads of the graph
"""
if loc_key not in ircfg.blocks:
return None
irblock = ircfg.blocks[loc_key]
if len(irblock.assignblks) != 1:
return None
items = list(viewitems(dict(irblock.assignblks[0])))
if len(items) != 1:
return None
if len(ircfg.successors(loc_key)) != 1:
return None
# Don't create predecessors on heads
dst, src = items[0]
assert dst.is_id("IRDst")
if not src.is_loc():
return None
dst = src.loc_key
if loc_key in heads:
predecessors = set(ircfg.predecessors(dst))
predecessors.difference_update(set([loc_key]))
if predecessors:
return None
return dst
def _relink_block_node(ircfg, loc_key, son_loc_key, replace_dct):
"""
Link the parents of @loc_key directly to @son_loc_key
"""
for parent in set(ircfg.predecessors(loc_key)):
parent_block = ircfg.blocks.get(parent, None)
if parent_block is None:
continue
new_block = parent_block.modify_exprs(
lambda expr:expr.replace_expr(replace_dct),
lambda expr:expr.replace_expr(replace_dct)
)
# Link parent to new dst
ircfg.add_uniq_edge(parent, son_loc_key)
# Unlink block
ircfg.blocks[new_block.loc_key] = new_block
ircfg.del_node(loc_key)
def _remove_to_son(ircfg, loc_key, son_loc_key):
"""
Merge irblocks; the final block keeps the @son_loc_key loc_key.
Update references.
Conditions:
- the irblock at @loc_key is a pure jump block
- @loc_key is not an entry point (so it can be removed)
@ircfg: IRCFG instance
@loc_key: LocKey instance of the parent irblock
@son_loc_key: LocKey instance of the son irblock
"""
# Ircfg loop => don't mess
if loc_key == son_loc_key:
return False
# Unlink block destinations
ircfg.del_edge(loc_key, son_loc_key)
replace_dct = {
ExprLoc(loc_key, ircfg.IRDst.size):ExprLoc(son_loc_key, ircfg.IRDst.size)
}
_relink_block_node(ircfg, loc_key, son_loc_key, replace_dct)
ircfg.del_node(loc_key)
del ircfg.blocks[loc_key]
return True
def _remove_to_parent(ircfg, loc_key, son_loc_key):
"""
Merge irblocks; the final block keeps the @loc_key loc_key.
Update references.
Conditions:
- the irblock at @loc_key is a pure jump block
- @son_loc_key is not an entry point (so it can be removed)
@ircfg: IRCFG instance
@loc_key: LocKey instance of the parent irblock
@son_loc_key: LocKey instance of the son irblock
"""
# Ircfg loop => don't mess
if loc_key == son_loc_key:
return False
# Unlink block destinations
ircfg.del_edge(loc_key, son_loc_key)
old_irblock = ircfg.blocks[son_loc_key]
new_irblock = IRBlock(ircfg.loc_db, loc_key, old_irblock.assignblks)
ircfg.blocks[son_loc_key] = new_irblock
ircfg.add_irblock(new_irblock)
replace_dct = {
ExprLoc(son_loc_key, ircfg.IRDst.size):ExprLoc(loc_key, ircfg.IRDst.size)
}
_relink_block_node(ircfg, son_loc_key, loc_key, replace_dct)
ircfg.del_node(son_loc_key)
del ircfg.blocks[son_loc_key]
return True
def merge_blocks(ircfg, heads):
"""
This function modifies @ircfg to apply the following transformations:
- group an irblock with its son if the irblock has one and only one son and
this son has one and only one parent (spaghetti code);
- if an irblock only assigns a given label to IRDst, drop it and update its
parents' destination targets. The irblock must have a parent (to avoid
deleting the function head);
- if an irblock is a head of the graph and only assigns a given label to
IRDst, drop it and make its son the new head. References are fixed.
This function avoids creating predecessors on heads.
Return True if at least one irblock has been modified
@ircfg: IRCFG instance
@heads: loc_keys to keep
"""
modified = False
todo = set(ircfg.nodes())
while todo:
loc_key = todo.pop()
# Test merge block
son = _test_merge_next_block(ircfg, loc_key)
if son is not None and son not in heads:
_do_merge_blocks(ircfg, loc_key, son)
todo.add(loc_key)
modified = True
continue
# Test jmp only block
son = _test_jmp_only(ircfg, loc_key, heads)
if son is not None and loc_key not in heads:
ret = _remove_to_son(ircfg, loc_key, son)
modified |= ret
if ret:
todo.add(loc_key)
continue
# Test head jmp only block
if (son is not None and
son not in heads and
son in ircfg.blocks):
# jmp only test done previously
ret = _remove_to_parent(ircfg, loc_key, son)
modified |= ret
if ret:
todo.add(loc_key)
continue
return modified
def remove_empty_assignblks(ircfg):
"""
Remove empty assignblks in irblocks of @ircfg
Return True if at least one irblock has been modified
@ircfg: IRCFG instance
"""
modified = False
for loc_key, block in list(viewitems(ircfg.blocks)):
irs = []
block_modified = False
for assignblk in block:
if len(assignblk):
irs.append(assignblk)
else:
block_modified = True
if block_modified:
new_irblock = IRBlock(ircfg.loc_db, loc_key, irs)
ircfg.blocks[loc_key] = new_irblock
modified = True
return modified
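# Usage sketch (illustrative only): a typical simplification loop combining
# merge_blocks and remove_empty_assignblks; `ircfg` is an IRCFG instance and
# `head` the LocKey of the function entry point.
#
#     modified = True
#     while modified:
#         modified = remove_empty_assignblks(ircfg)
#         modified |= merge_blocks(ircfg, set([head]))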
class SSADefUse(DiGraph):
"""
Generate DefUse information from SSA transformation
Links are not valid for ExprMem.
"""
def add_var_def(self, node, src):
index2dst = self._links.setdefault(node.label, {})
dst2src = index2dst.setdefault(node.index, {})
dst2src[node.var] = src
def add_def_node(self, def_nodes, node, src):
if node.var.is_id():
def_nodes[node.var] = node
def add_use_node(self, use_nodes, node, src):
sources = set()
if node.var.is_mem():
sources.update(node.var.ptr.get_r(mem_read=True))
sources.update(src.get_r(mem_read=True))
for source in sources:
if not source.is_mem():
use_nodes.setdefault(source, set()).add(node)
def get_node_target(self, node):
return self._links[node.label][node.index][node.var]
def set_node_target(self, node, src):
self._links[node.label][node.index][node.var] = src
@classmethod
def from_ssa(cls, ssa):
"""
Return a DefUse DiGraph from a SSA graph
@ssa: SSADiGraph instance
"""
graph = cls()
# First pass
# Link line to its use and def
def_nodes = {}
use_nodes = {}
graph._links = {}
for lbl in ssa.graph.nodes():
block = ssa.graph.blocks.get(lbl, None)
if block is None:
continue
for index, assignblk in enumerate(block):
for dst, src in viewitems(assignblk):
node = AssignblkNode(lbl, index, dst)
graph.add_var_def(node, src)
graph.add_def_node(def_nodes, node, src)
graph.add_use_node(use_nodes, node, src)
for dst, node in viewitems(def_nodes):
graph.add_node(node)
if dst not in use_nodes:
continue
for use in use_nodes[dst]:
graph.add_uniq_edge(node, use)
return graph
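# Usage sketch (illustrative only): `ssa` is assumed to be an SSADiGraph
# instance.
#
#     defuse = SSADefUse.from_ssa(ssa)
#     for node in defuse.nodes():
#         src = defuse.get_node_target(node)    # definition of node.var
#         users = defuse.successors(node)       # assignments reading node.var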
def expr_has_mem(expr):
"""
Return True if expr contains at least one memory access
@expr: Expr instance
"""
def has_mem(self):
return self.is_mem()
visitor = ExprWalk(has_mem)
return visitor.visit(expr)
def stack_to_reg(lifter, expr):
"""
Translate a stack based memory access @expr into a STACK.X pseudo register
@lifter: lifter_model_call instance (provides the stack pointer)
@expr: Expression instance
"""
if expr.is_mem():
ptr = expr.ptr
SP = lifter.sp
if ptr == SP:
return ExprId("STACK.0", expr.size)
elif (ptr.is_op('+') and
len(ptr.args) == 2 and
ptr.args[0] == SP and
ptr.args[1].is_int()):
diff = int(ptr.args[1])
assert diff % 4 == 0
diff = (0 - diff) & 0xFFFFFFFF
return ExprId("STACK.%d" % (diff // 4), expr.size)
return False
def is_stack_access(lifter, expr):
if not expr.is_mem():
return False
ptr = expr.ptr
diff = expr_simp(ptr - lifter.sp)
if not diff.is_int():
return False
return expr
def visitor_get_stack_accesses(lifter, expr, stack_vars):
if is_stack_access(lifter, expr):
stack_vars.add(expr)
return expr
def get_stack_accesses(lifter, expr):
result = set()
def get_stack(expr_to_test):
visitor_get_stack_accesses(lifter, expr_to_test, result)
return None
visitor = ExprWalk(get_stack)
visitor.visit(expr)
return result
def get_interval_length(interval_in):
length = 0
for start, stop in interval_in.intervals:
length += stop + 1 - start
return length
def check_expr_below_stack(lifter, expr):
"""
Return True if the pointer of @expr is strictly below the original stack
pointer (or cannot be statically compared to it), False otherwise
@lifter: lifter_model_call instance
@expr: Expression instance
"""
ptr = expr.ptr
diff = expr_simp(ptr - lifter.sp)
if not diff.is_int():
return True
if int(diff) == 0 or int(expr_simp(diff.msb())) == 0:
return False
return True
def retrieve_stack_accesses(lifter, ircfg):
"""
Walk the IR graph and find stack based variables.
Return a dictionary linking stack base address to its size/name
@lifter: lifter_model_call instance
@ircfg: IRCFG instance
"""
stack_vars = set()
for block in viewvalues(ircfg.blocks):
for assignblk in block:
for dst, src in viewitems(assignblk):
stack_vars.update(get_stack_accesses(lifter, dst))
stack_vars.update(get_stack_accesses(lifter, src))
stack_vars = [expr for expr in stack_vars if check_expr_below_stack(lifter, expr)]
base_to_var = {}
for var in stack_vars:
base_to_var.setdefault(var.ptr, set()).add(var)
base_to_interval = {}
for addr, vars in viewitems(base_to_var):
var_interval = interval()
for var in vars:
offset = expr_simp(addr - lifter.sp)
if not offset.is_int():
# skip non linear stack offset
continue
start = int(offset)
stop = int(expr_simp(offset + ExprInt(var.size // 8, offset.size)))
mem = interval([(start, stop-1)])
var_interval += mem
base_to_interval[addr] = var_interval
if not base_to_interval:
return {}
# Check that the intervals do not overlap
_, tmp = base_to_interval.popitem()
while base_to_interval:
addr, mem = base_to_interval.popitem()
assert (tmp & mem).empty
tmp += mem
base_to_info = {}
for addr, vars in viewitems(base_to_var):
name = "var_%d" % (len(base_to_info))
size = max([var.size for var in vars])
base_to_info[addr] = size, name
return base_to_info
def fix_stack_vars(expr, base_to_info):
"""
Replace local stack accesses in expr using information in @base_to_info
@expr: Expression instance
@base_to_info: dictionary linking stack base address to its size/name
"""
if not expr.is_mem():
return expr
ptr = expr.ptr
if ptr not in base_to_info:
return expr
size, name = base_to_info[ptr]
var = ExprId(name, size)
if size == expr.size:
return var
assert expr.size < size
return var[:expr.size]
def replace_mem_stack_vars(expr, base_to_info):
return expr.visit(lambda expr:fix_stack_vars(expr, base_to_info))
def replace_stack_vars(lifter, ircfg):
"""
Try to replace stack based memory accesses by variables.
Hypothesis: the input ircfg must have all of its stack accesses explicitly
done through the stack register, i.e. every alias on those variables is
resolved.
WARNING: may fail
@lifter: lifter_model_call instance
@ircfg: IRCFG instance
"""
base_to_info = retrieve_stack_accesses(lifter, ircfg)
modified = False
for block in list(viewvalues(ircfg.blocks)):
assignblks = []
for assignblk in block:
out = {}
for dst, src in viewitems(assignblk):
new_dst = dst.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info))
new_src = src.visit(lambda expr:replace_mem_stack_vars(expr, base_to_info))
if new_dst != dst or new_src != src:
modified |= True
out[new_dst] = new_src
out = AssignBlock(out, assignblk.instr)
assignblks.append(out)
new_block = IRBlock(block.loc_db, block.loc_key, assignblks)
ircfg.blocks[block.loc_key] = new_block
return modified
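# Usage sketch (illustrative only): `lifter` is a lifter_model_call instance
# and `ircfg` an IRCFG whose stack accesses all go through lifter.sp.
#
#     if replace_stack_vars(lifter, ircfg):
#         # stack slots such as @32[ESP + 0x8] have been renamed to var_N ids
#         pass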
def memlookup_test(expr, bs, is_addr_ro_variable, result):
if expr.is_mem() and expr.ptr.is_int():
ptr = int(expr.ptr)
if is_addr_ro_variable(bs, ptr, expr.size):
result.add(expr)
return False
return True
def memlookup_visit(expr, bs, is_addr_ro_variable):
result = set()
def retrieve_memlookup(expr_to_test):
memlookup_test(expr_to_test, bs, is_addr_ro_variable, result)
return None
visitor = ExprWalk(retrieve_memlookup)
visitor.visit(expr)
return result
def get_memlookup(expr, bs, is_addr_ro_variable):
return memlookup_visit(expr, bs, is_addr_ro_variable)
def read_mem(bs, expr):
ptr = int(expr.ptr)
var_bytes = bs.getbytes(ptr, expr.size // 8)[::-1]
try:
value = int(encode_hex(var_bytes), 16)
except ValueError:
return expr
return ExprInt(value, expr.size)
def load_from_int(ircfg, bs, is_addr_ro_variable):
"""
Replace memory reads at constant addresses with their static values
@ircfg: IRCFG instance
@bs: bin_stream instance
@is_addr_ro_variable: callback(bs, addr, size) to test a memory candidate
"""
modified = False
for block in list(viewvalues(ircfg.blocks)):
assignblks = list()
for assignblk in block:
out = {}
for dst, src in viewitems(assignblk):
# Test src
mems = get_memlookup(src, bs, is_addr_ro_variable)
src_new = src
if mems:
replace = {}
for mem in mems:
value = read_mem(bs, mem)
replace[mem] = value
src_new = src.replace_expr(replace)
if src_new != src:
modified = True
# Test dst pointer if dst is mem
if dst.is_mem():
ptr = dst.ptr
mems = get_memlookup(ptr, bs, is_addr_ro_variable)
if mems:
replace = {}
for mem in mems:
value = read_mem(bs, mem)
replace[mem] = value
ptr_new = ptr.replace_expr(replace)
if ptr_new != ptr:
modified = True
dst = ExprMem(ptr_new, dst.size)
out[dst] = src_new
out = AssignBlock(out, assignblk.instr)
assignblks.append(out)
block = IRBlock(block.loc_db, block.loc_key, assignblks)
ircfg.blocks[block.loc_key] = block
return modified
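# Usage sketch (illustrative only): `bs` is assumed to be the bin_stream of
# the analysed binary and `is_addr_ro` a user supplied callback; both names
# are hypothetical.
#
#     def is_addr_ro(bs, addr, size):
#         # return True if [addr, addr + size // 8) lies in a read-only section
#         return False
#
#     modified = load_from_int(ircfg, bs, is_addr_ro)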
class AssignBlockLivenessInfos(object):
"""
Description of live in / live out of an AssignBlock
"""
__slots__ = ["gen", "kill", "var_in", "var_out", "live", "assignblk"]
def __init__(self, assignblk, gen, kill):
self.gen = gen
self.kill = kill
self.var_in = set()
self.var_out = set()
self.live = set()
self.assignblk = assignblk
def __str__(self):
out = []
out.append("\tVarIn:" + ", ".join(str(x) for x in self.var_in))
out.append("\tGen:" + ", ".join(str(x) for x in self.gen))
out.append("\tKill:" + ", ".join(str(x) for x in self.kill))
out.append(
'\n'.join(
"\t%s = %s" % (dst, src)
for (dst, src) in viewitems(self.assignblk)
)
)
out.append("\tVarOut:" + ", ".join(str(x) for x in self.var_out))
return '\n'.join(out)
class IRBlockLivenessInfos(object):
"""
Description of live in / live out of an IRBlock
"""
__slots__ = ["loc_key", "infos", "assignblks"]
def __init__(self, irblock):
self.loc_key = irblock.loc_key
self.infos = []
self.assignblks = []
for assignblk in irblock:
gens, kills = set(), set()
for dst, src in viewitems(assignblk):
expr = ExprAssign(dst, src)
read = expr.get_r(mem_read=True)
write = expr.get_w()
gens.update(read)
kills.update(write)
self.infos.append(AssignBlockLivenessInfos(assignblk, gens, kills))
self.assignblks.append(assignblk)
def __getitem__(self, index):
"""Getitem on assignblks"""
return self.assignblks.__getitem__(index)
def __str__(self):
out = []
out.append("%s:" % self.loc_key)
for info in self.infos:
out.append(str(info))
out.append('')
return "\n".join(out)
class DiGraphLiveness(DiGraph):
"""
DiGraph representing variable liveness
"""
def __init__(self, ircfg):
super(DiGraphLiveness, self).__init__()
self.ircfg = ircfg
self.loc_db = ircfg.loc_db
self._blocks = {}
# Add irblocks gen/kill
for node in ircfg.nodes():
irblock = ircfg.blocks.get(node, None)
if irblock is None:
continue
irblockinfos = IRBlockLivenessInfos(irblock)
self.add_node(irblockinfos.loc_key)
self.blocks[irblockinfos.loc_key] = irblockinfos
for succ in ircfg.successors(node):
self.add_uniq_edge(node, succ)
for pred in ircfg.predecessors(node):
self.add_uniq_edge(pred, node)
@property
def blocks(self):
return self._blocks
def init_var_info(self):
"""Add ircfg out regs"""
raise NotImplementedError("Abstract method")
def node2lines(self, node):
"""
Output liveness information in dot format
"""
names = self.loc_db.get_location_names(node)
if not names:
node_name = self.loc_db.pretty_str(node)
else:
node_name = "".join("%s:\n" % name for name in names)
yield self.DotCellDescription(
text="%s" % node_name,
attr={
'align': 'center',
'colspan': 2,
'bgcolor': 'grey',
}
)
if node not in self._blocks:
yield [self.DotCellDescription(text="NOT PRESENT", attr={})]
return
for i, info in enumerate(self._blocks[node].infos):
var_in = "VarIn:" + ", ".join(str(x) for x in info.var_in)
var_out = "VarOut:" + ", ".join(str(x) for x in info.var_out)
assignmnts = ["%s = %s" % (dst, src) for (dst, src) in viewitems(info.assignblk)]
if i == 0:
yield self.DotCellDescription(
text=var_in,
attr={
'bgcolor': 'green',
}
)
for assign in assignmnts:
yield self.DotCellDescription(text=assign, attr={})
yield self.DotCellDescription(
text=var_out,
attr={
'bgcolor': 'green',
}
)
yield self.DotCellDescription(text="", attr={})
def back_propagate_compute(self, block):
"""
Compute the liveness information inside @block.
@block: IRBlockLivenessInfos instance
"""
infos = block.infos
modified = False
for i in reversed(range(len(infos))):
new_vars = set(infos[i].gen.union(infos[i].var_out.difference(infos[i].kill)))
if infos[i].var_in != new_vars:
modified = True
infos[i].var_in = new_vars
if i > 0 and infos[i - 1].var_out != set(infos[i].var_in):
modified = True
infos[i - 1].var_out = set(infos[i].var_in)
return modified
def back_propagate_to_parent(self, todo, node, parent):
"""
Back propagate the liveness information from @node to @parent.
@node: loc_key of the source node
@parent: loc_key of the node to update
"""
parent_block = self.blocks[parent]
cur_block = self.blocks[node]
if cur_block.infos[0].var_in == parent_block.infos[-1].var_out:
return
var_info = cur_block.infos[0].var_in.union(parent_block.infos[-1].var_out)
parent_block.infos[-1].var_out = var_info
todo.add(parent)
def compute_liveness(self):
"""
Compute the liveness information for the digraph.
"""
todo = set(self.leaves())
while todo:
node = todo.pop()
cur_block = self.blocks.get(node, None)
if cur_block is None:
continue
modified = self.back_propagate_compute(cur_block)
if not modified:
continue
# The var_in of this block changed, propagate to its parents
for pred in self.predecessors(node):
self.back_propagate_to_parent(todo, node, pred)
return True
class DiGraphLivenessIRA(DiGraphLiveness):
"""
DiGraph representing variable liveness for IRA
"""
def init_var_info(self, lifter):
"""Add ircfg out regs"""
for node in self.leaves():
irblock = self.ircfg.blocks.get(node, None)
if irblock is None:
continue
var_out = lifter.get_out_regs(irblock)
irblock_liveness = self.blocks[node]
irblock_liveness.infos[-1].var_out = var_out
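# Usage sketch (illustrative only): computing liveness on an IRCFG; `lifter`
# provides get_out_regs for the leaves of the graph, and `loc_key` stands for
# an existing block location.
#
#     liveness = DiGraphLivenessIRA(ircfg)
#     liveness.init_var_info(lifter)
#     liveness.compute_liveness()
#     last_info = liveness.blocks[loc_key].infos[-1]
#     print(last_info.var_out)   # variables live at the end of the block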
def discard_phi_sources(ircfg, deleted_vars):
"""
Remove phi sources in @ircfg belonging to @deleted_vars set
@ircfg: IRCFG instance in ssa form
@deleted_vars: unused phi sources
"""
for block in list(viewvalues(ircfg.blocks)):
if not block.assignblks:
continue
assignblk = block[0]
todo = {}
modified = False
for dst, src in viewitems(assignblk):
if not src.is_op('Phi'):
todo[dst] = src
continue
srcs = set(expr for expr in src.args if expr not in deleted_vars)
assert(srcs)
if len(srcs) > 1:
todo[dst] = ExprOp('Phi', *srcs)
continue
todo[dst] = srcs.pop()
modified = True
if not modified:
continue
assignblks = list(block)
assignblk = dict(assignblk)
assignblk.update(todo)
assignblk = AssignBlock(assignblk, assignblks[0].instr)
assignblks[0] = assignblk
new_irblock = IRBlock(block.loc_db, block.loc_key, assignblks)
ircfg.blocks[block.loc_key] = new_irblock
return True
def get_unreachable_nodes(ircfg, edges_to_del, heads):
"""
Return the unreachable nodes starting from heads and the associated edges to
be deleted.
@ircfg: IRCFG instance
@edges_to_del: edges already marked as deleted
@heads: locations of the graph heads
"""
todo = set(heads)
visited_nodes = set()
new_edges_to_del = set()
while todo:
node = todo.pop()
if node in visited_nodes:
continue
visited_nodes.add(node)
for successor in ircfg.successors(node):
if (node, successor) not in edges_to_del:
todo.add(successor)
all_nodes = set(ircfg.nodes())
nodes_to_del = all_nodes.difference(visited_nodes)
for node in nodes_to_del:
for successor in ircfg.successors(node):
if successor not in nodes_to_del:
# Frontier: link from a deleted node to a living node
new_edges_to_del.add((node, successor))
return nodes_to_del, new_edges_to_del
def update_phi_with_deleted_edges(ircfg, edges_to_del):
"""
Update phis which have a source coming through an edge of @edges_to_del
@ircfg: IRCFG instance in SSA form
@edges_to_del: edges to delete
"""
phi_locs_to_srcs = {}
for loc_src, loc_dst in edges_to_del:
phi_locs_to_srcs.setdefault(loc_dst, set()).add(loc_src)
modified = False
blocks = dict(ircfg.blocks)
for loc_dst, loc_srcs in viewitems(phi_locs_to_srcs):
if loc_dst not in ircfg.blocks:
continue
block = ircfg.blocks[loc_dst]
if not irblock_has_phi(block):
continue
assignblks = list(block)
assignblk = assignblks[0]
out = {}
for dst, phi_sources in viewitems(assignblk):
if not phi_sources.is_op('Phi'):
out[dst] = phi_sources
continue
var_to_parents = get_phi_sources_parent_block(
ircfg,
loc_dst,
phi_sources.args
)
to_keep = set(phi_sources.args)
for src in phi_sources.args:
parents = var_to_parents[src]
remaining = parents.difference(loc_srcs)
if not remaining:
to_keep.discard(src)
modified = True
assert to_keep
if len(to_keep) == 1:
out[dst] = to_keep.pop()
else:
out[dst] = ExprOp('Phi', *to_keep)
assignblk = AssignBlock(out, assignblks[0].instr)
assignblks[0] = assignblk
new_irblock = IRBlock(block.loc_db, loc_dst, assignblks)
blocks[block.loc_key] = new_irblock
for loc_key, block in viewitems(blocks):
ircfg.blocks[loc_key] = block
return modified
def del_unused_edges(ircfg, heads):
"""
Delete never-taken edges in the @ircfg graph.
@ircfg: IRCFG instance in ssa form
@heads: location of the heads of the graph
"""
deleted_vars = set()
modified = False
edges_to_del_1 = set()
for node in ircfg.nodes():
successors = set(ircfg.successors(node))
block = ircfg.blocks.get(node, None)
if block is None:
continue
dst = block.dst
possible_dsts = set(solution.value for solution in possible_values(dst))
if not all(dst.is_loc() for dst in possible_dsts):
continue
possible_dsts = set(dst.loc_key for dst in possible_dsts)
if len(possible_dsts) == len(successors):
continue
dsts_to_del = successors.difference(possible_dsts)
for dst in dsts_to_del:
edges_to_del_1.add((node, dst))
# Remove edges and update phi accordingly
# Two cases here:
# - edge is directly linked to a phi node
# - edge is indirectly linked to a phi node
nodes_to_del, edges_to_del_2 = get_unreachable_nodes(ircfg, edges_to_del_1, heads)
modified |= update_phi_with_deleted_edges(ircfg, edges_to_del_1.union(edges_to_del_2))
for src, dst in edges_to_del_1.union(edges_to_del_2):
ircfg.del_edge(src, dst)
for node in nodes_to_del:
if node not in ircfg.blocks:
continue
block = ircfg.blocks[node]
ircfg.del_node(node)
del ircfg.blocks[node]
for assignblock in block:
for dst in assignblock:
deleted_vars.add(dst)
if deleted_vars:
modified |= discard_phi_sources(ircfg, deleted_vars)
return modified
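# Usage sketch (illustrative only): pruning never-taken edges of an SSA form
# IRCFG rooted at `head`.
#
#     if del_unused_edges(ircfg, set([head])):
#         # some edges/blocks were removed and phis were updated accordingly
#         pass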
class DiGraphLivenessSSA(DiGraphLivenessIRA):
"""
DiGraph representing variable liveness in an SSA graph
"""
def __init__(self, ircfg):
super(DiGraphLivenessSSA, self).__init__(ircfg)
self.loc_key_to_phi_parents = {}
for irblock in viewvalues(self.blocks):
if not irblock_has_phi(irblock):
continue
out = {}
for sources in viewvalues(irblock[0]):
if not sources.is_op('Phi'):
# Some phi sources may have already been resolved to an
# expression
continue
var_to_parents = get_phi_sources_parent_block(self, irblock.loc_key, sources.args)
for var, var_parents in viewitems(var_to_parents):
out.setdefault(var, set()).update(var_parents)
self.loc_key_to_phi_parents[irblock.loc_key] = out
def back_propagate_to_parent(self, todo, node, parent):
if parent not in self.blocks:
return
parent_block = self.blocks[parent]
cur_block = self.blocks[node]
irblock = self.ircfg.blocks[node]
if cur_block.infos[0].var_in == parent_block.infos[-1].var_out:
return
var_info = cur_block.infos[0].var_in.union(parent_block.infos[-1].var_out)
if irblock_has_phi(irblock):
# Remove phi special case
out = set()
phi_sources = self.loc_key_to_phi_parents[irblock.loc_key]
for var in var_info:
if var not in phi_sources:
out.add(var)
continue
if parent in phi_sources[var]:
out.add(var)
var_info = out
parent_block.infos[-1].var_out = var_info
todo.add(parent)
def get_phi_sources(phi_src, phi_dsts, ids_to_src):
"""
Return False if @phi_src has more than one non-phi source, True if it has
none; else, return its unique source
@ids_to_src: Dictionary linking each phi source to its definition
"""
true_values = set()
for src in phi_src.args:
if src in phi_dsts:
# Source is phi dst => skip
continue
true_src = ids_to_src[src]
if true_src in phi_dsts:
# Source is phi dst => skip
continue
# Check if src is not also a phi
if true_src.is_op('Phi'):
phi_dsts.add(src)
true_src = get_phi_sources(true_src, phi_dsts, ids_to_src)
if true_src is False:
return False
if true_src is True:
continue
true_values.add(true_src)
if not true_values:
# Every source is itself a phi destination
return True
if len(true_values) != 1:
return False
true_value = true_values.pop()
return true_value
class DelDummyPhi(object):
"""
Delete dummy phi nodes.
Find nodes which are in the same equivalence class and replace phi nodes
with the class representative.
"""
def src_gen_phi_node_srcs(self, equivalence_graph):
for node in equivalence_graph.nodes():
if not node.is_op("Phi"):
continue
phi_successors = equivalence_graph.successors(node)
for head in phi_successors:
# Walk from head to find if we have a phi merging node
known = set([node])
todo = set([head])
done = set()
while todo:
node = todo.pop()
if node in done:
continue
known.add(node)
is_ok = True
for parent in equivalence_graph.predecessors(node):
if parent not in known:
is_ok = False
break
if not is_ok:
continue
if node.is_op("Phi"):
successors = equivalence_graph.successors(node)
phi_node = successors.pop()
return set([phi_node]), phi_node, head, equivalence_graph
done.add(node)
for successor in equivalence_graph.successors(node):
todo.add(successor)
return None
def get_equivalence_class(self, node, ids_to_src):
todo = set([node])
done = set()
defined = set()
equivalence = set()
src_to_dst = {}
equivalence_graph = DiGraph()
while todo:
dst = todo.pop()
if dst in done:
continue
done.add(dst)
equivalence.add(dst)
src = ids_to_src.get(dst)
if src is None:
# Node is not defined
continue
src_to_dst[src] = dst
defined.add(dst)
if src.is_id():
equivalence_graph.add_uniq_edge(src, dst)
todo.add(src)
elif src.is_op('Phi'):
equivalence_graph.add_uniq_edge(src, dst)
for arg in src.args:
assert arg.is_id()
equivalence_graph.add_uniq_edge(arg, src)
todo.add(arg)
else:
if src.is_mem() or (src.is_op() and src.op.startswith("call")):
if src in equivalence_graph.nodes():
return None
equivalence_graph.add_uniq_edge(src, dst)
equivalence.add(src)
if len(equivalence_graph.heads()) == 0:
raise RuntimeError("Inconsistent graph")
elif len(equivalence_graph.heads()) == 1:
# Every node in the equivalence graph may be equivalent to the root
head = equivalence_graph.heads().pop()
successors = equivalence_graph.successors(head)
if len(successors) == 1:
# If successor is an id
successor = successors.pop()
if successor.is_id():
nodes = equivalence_graph.nodes()
nodes.discard(head)
nodes.discard(successor)
nodes = [node for node in nodes if node.is_id()]
return nodes, successor, head, equivalence_graph
else:
# Walk from head to find if we have a phi merging node
known = set()
todo = set([head])
done = set()
while todo:
node = todo.pop()
if node in done:
continue
known.add(node)
is_ok = True
for parent in equivalence_graph.predecessors(node):
if parent not in known:
is_ok = False
break
if not is_ok:
continue
if node.is_op("Phi"):
successors = equivalence_graph.successors(node)
assert len(successors) == 1
phi_node = successors.pop()
return set([phi_node]), phi_node, head, equivalence_graph
done.add(node)
for successor in equivalence_graph.successors(node):
todo.add(successor)
return self.src_gen_phi_node_srcs(equivalence_graph)
def del_dummy_phi(self, ssa, head):
ids_to_src = {}
def_to_loc = {}
for block in viewvalues(ssa.graph.blocks):
for index, assignblock in enumerate(block):
for dst, src in viewitems(assignblock):
if not dst.is_id():
continue
ids_to_src[dst] = src
def_to_loc[dst] = block.loc_key
modified = False
for loc_key in ssa.graph.blocks.keys():
block = ssa.graph.blocks[loc_key]
if not irblock_has_phi(block):
continue
assignblk = block[0]
for dst, phi_src in viewitems(assignblk):
assert phi_src.is_op('Phi')
result = self.get_equivalence_class(dst, ids_to_src)
if result is None:
continue
defined, node, true_value, equivalence_graph = result
if expr_has_mem(true_value):
# Don't propagate ExprMem
continue
if true_value.is_op() and true_value.op.startswith("call"):
# Don't propagate call
continue
# We have an equivalence of nodes
to_del = set(defined)
# Remove all implicated phis
for dst in to_del:
loc_key = def_to_loc[dst]
block = ssa.graph.blocks[loc_key]
assignblk = block[0]
fixed_phis = {}
for old_dst, old_phi_src in viewitems(assignblk):
if old_dst in defined:
continue
fixed_phis[old_dst] = old_phi_src
assignblks = list(block)
assignblks[0] = AssignBlock(fixed_phis, assignblk.instr)
assignblks[1:1] = [AssignBlock({dst: true_value}, assignblk.instr)]
new_irblock = IRBlock(block.loc_db, block.loc_key, assignblks)
ssa.graph.blocks[loc_key] = new_irblock
modified = True
return modified
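# Usage sketch (illustrative only): `ssa` is an SSADiGraph and `head` the
# LocKey of its entry point.
#
#     modified = DelDummyPhi().del_dummy_phi(ssa, head)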
def replace_expr_from_bottom(expr_orig, dct):
def replace(expr):
if expr in dct:
return dct[expr]
return expr
visitor = ExprVisitorCallbackBottomToTop(lambda expr:replace(expr))
return visitor.visit(expr_orig)
def is_mem_sub_part(needle, mem):
"""
If @needle is a sub part of @mem, return the offset of @needle in @mem
Else, return False
@needle: ExprMem
@mem: ExprMem
"""
ptr_base_a, ptr_offset_a = get_expr_base_offset(needle.ptr)
ptr_base_b, ptr_offset_b = get_expr_base_offset(mem.ptr)
if ptr_base_a != ptr_base_b:
return False
# Test if sub part starts after mem
if not (ptr_offset_b <= ptr_offset_a < ptr_offset_b + mem.size // 8):
return False
# Test if sub part ends before mem
if not (ptr_offset_a + needle.size // 8 <= ptr_offset_b + mem.size // 8):
return False
return ptr_offset_a - ptr_offset_b
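# Example (illustrative only), using expressions from this module's imports:
# @8[ESP + 4] is a sub part of @64[ESP] starting at byte offset 4.
#
#     esp = ExprId("ESP", 32)
#     needle = ExprMem(esp + ExprInt(4, 32), 8)    # @8[ESP + 0x4]
#     mem = ExprMem(esp, 64)                       # @64[ESP]
#     assert is_mem_sub_part(needle, mem) == 4
#     assert is_mem_sub_part(ExprMem(esp + ExprInt(8, 32), 8), mem) is False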
class UnionFind(object):
"""
Implementation of a UnionFind structure
__classes: a list of sets of equivalent elements
node_to_class: dictionary linking an element to its equivalence class
order: dictionary linking an element to its weight
The order attribute is used to select a representative element of an
equivalence class
"""
def __init__(self):
self.index = 0
self.__classes = []
self.node_to_class = {}
self.order = dict()
def copy(self):
"""
Return a copy of the object
"""
unionfind = UnionFind()
unionfind.index = self.index
unionfind.__classes = [set(known_class) for known_class in self.__classes]
node_to_class = {}
for class_eq in unionfind.__classes:
for node in class_eq:
node_to_class[node] = class_eq
unionfind.node_to_class = node_to_class
unionfind.order = dict(self.order)
return unionfind
def replace_node(self, old_node, new_node):
"""
Replace the @old_node by the @new_node
"""
classes = self.get_classes()
new_classes = []
replace_dct = {old_node:new_node}
for eq_class in classes:
new_class = set()
for node in eq_class:
new_class.add(replace_expr_from_bottom(node, replace_dct))
new_classes.append(new_class)
node_to_class = {}
for class_eq in new_classes:
for node in class_eq:
node_to_class[node] = class_eq
self.__classes = new_classes
self.node_to_class = node_to_class
new_order = dict()
for node,index in self.order.items():
new_node = replace_expr_from_bottom(node, replace_dct)
new_order[new_node] = index
self.order = new_order
def get_classes(self):
"""
Return a list of the equivalence classes
"""
classes = []
for class_tmp in self.__classes:
classes.append(set(class_tmp))
return classes
def nodes(self):
for known_class in self.__classes:
for node in known_class:
yield node
def __eq__(self, other):
if self is other:
return True
if self.__class__ is not other.__class__:
return False
return Counter(frozenset(known_class) for known_class in self.__classes) == Counter(frozenset(known_class) for known_class in other.__classes)
def __ne__(self, other):
# required Python 2.7.14
return not self == other
def __str__(self):
components = self.__classes
out = ['UnionFind<']
for component in components:
out.append("\t" + (", ".join([str(node) for node in component])))
out.append('>')
return "\n".join(out)
def add_equivalence(self, node_a, node_b):
"""
Add the new equivalence @node_a == @node_b
@node_a is equivalent to @node_b, but @node_b is more representative
than @node_a
"""
if node_b not in self.order:
self.order[node_b] = self.index
self.index += 1
# As node_a is destination, we always replace its index
self.order[node_a] = self.index
self.index += 1
if node_a not in self.node_to_class and node_b not in self.node_to_class:
new_class = set([node_a, node_b])
self.node_to_class[node_a] = new_class
self.node_to_class[node_b] = new_class
self.__classes.append(new_class)
elif node_a in self.node_to_class and node_b not in self.node_to_class:
known_class = self.node_to_class[node_a]
known_class.add(node_b)
self.node_to_class[node_b] = known_class
elif node_a not in self.node_to_class and node_b in self.node_to_class:
known_class = self.node_to_class[node_b]
known_class.add(node_a)
self.node_to_class[node_a] = known_class
else:
raise RuntimeError("Two nodes cannot be in two classes")
def _get_master(self, node):
if node not in self.node_to_class:
return None
known_class = self.node_to_class[node]
best_node = node
for node in known_class:
if self.order[node] < self.order[best_node]:
best_node = node
return best_node
def get_master(self, node):
"""
Return the representative element of the equivalence class containing
@node
@node: ExprMem or ExprId
"""
if not node.is_mem():
return self._get_master(node)
if node in self.node_to_class:
# Full expr mem is known
return self._get_master(node)
# Test if mem is sub part of known node
for expr in self.node_to_class:
if not expr.is_mem():
continue
ret = is_mem_sub_part(node, expr)
if ret is False:
continue
master = self._get_master(expr)
master = master[ret * 8 : ret * 8 + node.size]
return master
return self._get_master(node)
def del_element(self, node):
"""
Remove @node from the equivalence classes
"""
assert node in self.node_to_class
known_class = self.node_to_class[node]
known_class.discard(node)
del(self.node_to_class[node])
del(self.order[node])
def del_get_new_master(self, node):
"""
Remove @node from the equivalence classes and return its representative
equivalent element
@node: Element to delete
"""
if node not in self.node_to_class:
return None
known_class = self.node_to_class[node]
known_class.discard(node)
del(self.node_to_class[node])
del(self.order[node])
if not known_class:
return None
best_node = list(known_class)[0]
for node in known_class:
if self.order[node] < self.order[best_node]:
best_node = node
return best_node
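# Example (illustrative only): `b` is registered as more representative than
# `a` and `c`, so it becomes the master of their equivalence class.
#
#     a, b, c = ExprId("a", 32), ExprId("b", 32), ExprId("c", 32)
#     uf = UnionFind()
#     uf.add_equivalence(a, b)
#     uf.add_equivalence(c, b)
#     assert uf.get_master(a) == b
#     assert uf.get_master(c) == b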
class ExprToGraph(ExprWalk):
"""
Walk an Expression and add its sub-expressions as linked nodes to an existing graph
"""
def __init__(self, graph):
super(ExprToGraph, self).__init__(self.link_nodes)
self.graph = graph
def link_nodes(self, expr, *args, **kwargs):
"""
Add @expr and the edges to its direct sub-expressions to the current
graph
@expr: Expression
"""
if expr in self.graph.nodes():
return None
self.graph.add_node(expr)
if expr.is_mem():
self.graph.add_uniq_edge(expr, expr.ptr)
elif expr.is_slice():
self.graph.add_uniq_edge(expr, expr.arg)
elif expr.is_cond():
self.graph.add_uniq_edge(expr, expr.cond)
self.graph.add_uniq_edge(expr, expr.src1)
self.graph.add_uniq_edge(expr, expr.src2)
elif expr.is_compose():
for arg in expr.args:
self.graph.add_uniq_edge(expr, arg)
elif expr.is_op():
for arg in expr.args:
self.graph.add_uniq_edge(expr, arg)
return None
class State(object):
"""
Object representing the state of a program at a given point
The state is represented using equivalence classes
Each assignment can create/destroy equivalence classes. Interference
between expressions is computed using the `may_interfer` function
"""
def __init__(self):
self.equivalence_classes = UnionFind()
self.undefined = set()
def __str__(self):
return "{0.equivalence_classes}\n{0.undefined}".format(self)
def copy(self):
state = self.__class__()
state.equivalence_classes = self.equivalence_classes.copy()
state.undefined = self.undefined.copy()
return state
def __eq__(self, other):
if self is other:
return True
if self.__class__ is not other.__class__:
return False
return (
set(self.equivalence_classes.nodes()) == set(other.equivalence_classes.nodes()) and
sorted(self.equivalence_classes.edges()) == sorted(other.equivalence_classes.edges()) and
self.undefined == other.undefined
)
def __ne__(self, other):
# required Python 2.7.14
return not self == other
def may_interfer(self, dsts, src):
"""
Return True if @src may interfere with expressions in @dsts
@dsts: Set of Expressions
@src: expression to test
"""
srcs = src.get_r()
for src in srcs:
for dst in dsts:
if dst in src:
return True
if dst.is_mem() and src.is_mem():
dst_base, dst_offset = get_expr_base_offset(dst.ptr)
src_base, src_offset = get_expr_base_offset(src.ptr)
if dst_base != src_base:
return True
dst_size = dst.size // 8
src_size = src.size // 8
# Special case:
# @32[ESP + 0xFFFFFFFE], @32[ESP]
# Both memories alias
if dst_offset + dst_size <= int(dst_base.mask) + 1:
# @32[ESP + 0xFFFFFFFC] => [0xFFFFFFFC, 0xFFFFFFFF]
interval1 = interval([(dst_offset, dst_offset + dst.size // 8 - 1)])
else:
# @32[ESP + 0xFFFFFFFE] => [0x0, 0x1] U [0xFFFFFFFE, 0xFFFFFFFF]
interval1 = interval([(dst_offset, int(dst_base.mask))])
interval1 += interval([(0, dst_size - (int(dst_base.mask) + 1 - dst_offset) - 1 )])
if src_offset + src_size <= int(src_base.mask) + 1:
# @32[ESP + 0xFFFFFFFC] => [0xFFFFFFFC, 0xFFFFFFFF]
interval2 = interval([(src_offset, src_offset + src.size // 8 - 1)])
else:
# @32[ESP + 0xFFFFFFFE] => [0x0, 0x1] U [0xFFFFFFFE, 0xFFFFFFFF]
interval2 = interval([(src_offset, int(src_base.mask))])
interval2 += interval([(0, src_size - (int(src_base.mask) + 1 - src_offset) - 1)])
if (interval1 & interval2).empty:
continue
return True
return False
def _get_representative_expr(self, expr):
representative = self.equivalence_classes.get_master(expr)
if representative is None:
return expr
return representative
def get_representative_expr(self, expr):
"""
Replace each sub expression of @expr by its representative element
@expr: Expression to analyse
"""
new_expr = expr.visit(self._get_representative_expr)
return new_expr
def propagation_allowed(self, expr):
"""
Return True if @expr can be propagated
Don't propagate:
- Phi nodes
- call_func_ret / call_func_stack operands
"""
if (
expr.is_op('Phi') or
(expr.is_op() and expr.op.startswith("call_func"))
):
return False
return True
def eval_assignblock(self, assignblock):
"""
Evaluate the @assignblock on the current state
@assignblock: AssignBlock instance
"""
out = dict(assignblock.items())
new_out = dict()
# Replace sub-expressions by their equivalence class representative
for dst, src in out.items():
if src.is_op('Phi'):
# Don't replace in phi
new_src = src
else:
new_src = self.get_representative_expr(src)
if dst.is_mem():
new_ptr = self.get_representative_expr(dst.ptr)
new_dst = ExprMem(new_ptr, dst.size)
else:
new_dst = dst
new_dst = expr_simp(new_dst)
new_src = expr_simp(new_src)
new_out[new_dst] = new_src
# For each destination, update (or delete) dependent nodes according to
# equivalence classes
classes = self.equivalence_classes
for dst in new_out:
replacement = classes.del_get_new_master(dst)
if replacement is None:
to_del = set([dst])
to_replace = {}
else:
to_del = set()
to_replace = {dst:replacement}
graph = DiGraph()
# Build an expression graph linking all classes
has_parents = False
for node in classes.nodes():
if dst in node:
# Only dependent nodes are interesting here
has_parents = True
expr_to_graph = ExprToGraph(graph)
expr_to_graph.visit(node)
if not has_parents:
continue
todo = graph.leaves()
done = set()
while todo:
node = todo.pop(0)
if node in done:
continue
# If at least one son is not done, redo this node later
if [son for son in graph.successors(node) if son not in done]:
todo.append(node)
continue
done.add(node)
# If at least one son cannot be replaced (deleted), our last
# chance is to have an equivalence
if any(son in to_del for son in graph.successors(node)):
# One son has been deleted!
# Try to find a replacement of the whole expression
replacement = classes.del_get_new_master(node)
if replacement is None:
to_del.add(node)
for predecessor in graph.predecessors(node):
if predecessor not in todo:
todo.append(predecessor)
continue
else:
to_replace[node] = replacement
# Continue with replacement
# Every son is live or has been replaced
new_node = node.replace_expr(to_replace)
if new_node == node:
# If node is not touched (Ex: leaf node)
for predecessor in graph.predecessors(node):
if predecessor not in todo:
todo.append(predecessor)
continue
# Node has been modified, update equivalence classes
classes.replace_node(node, new_node)
to_replace[node] = new_node
for predecessor in graph.predecessors(node):
if predecessor not in todo:
todo.append(predecessor)
continue
new_assignblk = AssignBlock(new_out, assignblock.instr)
dsts = new_out.keys()
# Remove interfering known classes
to_del = set()
for node in list(classes.nodes()):
if self.may_interfer(dsts, node):
# Interfere with known equivalence class
self.equivalence_classes.del_element(node)
if node.is_id() or node.is_mem():
self.undefined.add(node)
# Update equivalence classes
for dst, src in new_out.items():
# Delete equivalence class interfering with dst
to_del = set()
classes = self.equivalence_classes
for node in classes.nodes():
if dst in node:
to_del.add(node)
for node in to_del:
self.equivalence_classes.del_element(node)
if node.is_id() or node.is_mem():
self.undefined.add(node)
# Don't create equivalence if self interfer
if self.may_interfer(dsts, src):
if dst in self.equivalence_classes.nodes():
self.equivalence_classes.del_element(dst)
if dst.is_id() or dst.is_mem():
self.undefined.add(dst)
continue
if not self.propagation_allowed(src):
continue
self.undefined.discard(dst)
if dst in self.equivalence_classes.nodes():
self.equivalence_classes.del_element(dst)
self.equivalence_classes.add_equivalence(dst, src)
return new_assignblk
def merge(self, other):
"""
Merge the current state with @other
Merge rules:
- if two nodes are equivalent in both states => keep them in an equivalence class
- if a node's value differs or is missing in the other state => undefined
@other: State instance
"""
classes1 = self.equivalence_classes
classes2 = other.equivalence_classes
undefined = set(node for node in self.undefined if node.is_id() or node.is_mem())
undefined.update(set(node for node in other.undefined if node.is_id() or node.is_mem()))
# Should we compute interference between srcs and undefined?
# No => they should already interfere in the other state
components1 = classes1.get_classes()
components2 = classes2.get_classes()
node_to_component2 = {}
for component in components2:
for node in component:
node_to_component2[node] = component
# Compute intersection of equivalence classes of states
out = []
nodes_ok = set()
while components1:
component1 = components1.pop()
for node in component1:
if node in undefined:
continue
component2 = node_to_component2.get(node)
if component2 is None:
if node.is_id() or node.is_mem():
assert(node not in nodes_ok)
undefined.add(node)
continue
if node not in component2:
continue
# Found two classes containing node
common = component1.intersection(component2)
if len(common) == 1:
# Intersection contains only one node => undefine node
if node.is_id() or node.is_mem():
assert(node not in nodes_ok)
undefined.add(node)
component2.discard(common.pop())
continue
if common:
# Intersection contains multiple nodes
# Here, common nodes don't interfere with any undefined
nodes_ok.update(common)
out.append(common)
diff = component1.difference(common)
if diff:
components1.append(diff)
component2.difference_update(common)
break
# Discard remaining components2 elements
for component in components2:
for node in component:
if node.is_id() or node.is_mem():
assert(node not in nodes_ok)
undefined.add(node)
all_nodes = set()
for common in out:
all_nodes.update(common)
new_order = dict(
(node, index) for (node, index) in classes1.order.items()
if node in all_nodes
)
unionfind = UnionFind()
new_classes = []
global_max_index = 0
for common in out:
min_index = None
master = None
for node in common:
index = new_order[node]
global_max_index = max(index, global_max_index)
if min_index is None or min_index > index:
min_index = index
master = node
for node in common:
if node == master:
continue
unionfind.add_equivalence(node, master)
unionfind.index = global_max_index
unionfind.order = new_order
state = self.__class__()
state.equivalence_classes = unionfind
state.undefined = undefined
return state
class PropagateExpressions(object):
"""
Propagate expressions
The algorithm propagates equivalence class expressions from the entry
point. During the analysis, source nodes are replaced by their equivalence
class representative. Equivalence classes can be modified during the
analysis due to memory aliasing.
For example:
B = A+1
C = A
A = 6
D = [B]
Will result in:
B = A+1
C = A
A = 6
D = [C+1]
"""
@staticmethod
def new_state():
return State()
def merge_prev_states(self, ircfg, states, loc_key):
"""
Merge predecessors states of irblock at location @loc_key
@ircfg: IRCfg instance
@states: Dictionary linking locations to state
@loc_key: location of the current irblock
"""
prev_states = []
for predecessor in ircfg.predecessors(loc_key):
prev_states.append((predecessor, states[predecessor]))
filtered_prev_states = []
for (_, prev_state) in prev_states:
if prev_state is not None:
filtered_prev_states.append(prev_state)
prev_states = filtered_prev_states
if not prev_states:
state = self.new_state()
elif len(prev_states) == 1:
state = prev_states[0].copy()
else:
while prev_states:
state = prev_states.pop()
if state is not None:
break
for prev_state in prev_states:
state = state.merge(prev_state)
return state
def update_state(self, irblock, state):
"""
Propagate the @state through the @irblock
@irblock: IRBlock instance
@state: State instance
"""
new_assignblocks = []
modified = False
for assignblock in irblock:
if not assignblock.items():
continue
new_assignblk = state.eval_assignblock(assignblock)
new_assignblocks.append(new_assignblk)
if new_assignblk != assignblock:
modified = True
new_irblock = IRBlock(irblock.loc_db, irblock.loc_key, new_assignblocks)
return new_irblock, modified
def propagate(self, ssa, head, max_expr_depth=None):
"""
Apply the algorithm to the @ssa graph
@ssa: SSADiGraph instance
@head: loc_key of the graph entry point
"""
ircfg = ssa.ircfg
self.loc_db = ircfg.loc_db
irblocks = ssa.ircfg.blocks
states = {}
for loc_key, irblock in irblocks.items():
states[loc_key] = None
todo = deque([head])
while todo:
loc_key = todo.popleft()
irblock = irblocks.get(loc_key)
if irblock is None:
continue
state_orig = states[loc_key]
state = self.merge_prev_states(ircfg, states, loc_key)
state = state.copy()
new_irblock, modified_irblock = self.update_state(irblock, state)
if state_orig is not None:
# Merge current and previous state
state = state.merge(state_orig)
if (state.equivalence_classes == state_orig.equivalence_classes and
state.undefined == state_orig.undefined
):
continue
states[loc_key] = state
# Propagate to sons
for successor in ircfg.successors(loc_key):
todo.append(successor)
# Update blocks
todo = set(loc_key for loc_key in irblocks)
modified = False
while todo:
loc_key = todo.pop()
irblock = irblocks.get(loc_key)
if irblock is None:
continue
state = self.merge_prev_states(ircfg, states, loc_key)
new_irblock, modified_irblock = self.update_state(irblock, state)
modified |= modified_irblock
irblocks[new_irblock.loc_key] = new_irblock
return modified
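# Usage sketch (illustrative only): running expression propagation on an SSA
# form; `ssa` is an SSADiGraph and `head` the LocKey of the entry point.
#
#     propagation = PropagateExpressions()
#     while propagation.propagate(ssa, head):
#         # re-run until no further propagation occurs; interleaving with
#         # dead code removal usually exposes more simplifications
#         pass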