python-security/pyt

View on GitHub
pyt/cfg/expr_visitor.py

Summary

Maintainability
D
2 days
Test Coverage
import ast
import logging

from .alias_helper import handle_aliases_in_calls
from ..core.ast_helper import (
    Arguments,
    get_call_names_as_string
)
from ..core.module_definitions import ModuleDefinitions
from ..core.node_types import (
    AssignmentCallNode,
    AssignmentNode,
    BBorBInode,
    ConnectToExitNode,
    EntryOrExitNode,
    IgnoredNode,
    Node,
    RestoreNode,
    ReturnNode,
    YieldNode
)
from .expr_visitor_helper import (
    BUILTINS,
    MUTATORS,
    return_connection_handler,
    SavedVariable
)
from ..helper_visitors import (
    LabelVisitor,
    RHSVisitor
)
from .stmt_visitor import StmtVisitor
from .stmt_visitor_helper import CALL_IDENTIFIER

log = logging.getLogger(__name__)


class ExprVisitor(StmtVisitor):
    def __init__(
        self,
        node,
        project_modules,
        local_modules,
        filename,
        module_definitions=None,
        allow_local_directory_imports=True
    ):
        """Create an empty CFG."""
        super().__init__(allow_local_directory_imports=allow_local_directory_imports)
        self.project_modules = project_modules
        self.local_modules = local_modules if self._allow_local_modules else []
        self.filenames = [filename]
        self.blackbox_assignments = set()
        self.nodes = list()
        self.function_call_index = 0
        self.undecided = False
        self.function_names = list()
        self.function_return_stack = list()
        self.function_definition_stack = list()  # used to avoid recursion
        self.module_definitions_stack = list()
        self.prev_nodes_to_avoid = list()
        self.last_control_flow_nodes = list()
        self._within_mutating_call = False

        # Are we already in a module?
        if module_definitions:
            self.init_function_cfg(node, module_definitions)
        else:
            self.init_cfg(node)

    def init_cfg(self, node):
        self.module_definitions_stack.append(ModuleDefinitions(filename=self.filenames[-1]))

        entry_node = self.append_node(EntryOrExitNode('Entry module'))

        module_statements = self.visit(node)

        if not module_statements:
            raise Exception('Empty module. It seems that your file is empty,' +
                            'there is nothing to analyse.')

        exit_node = self.append_node(EntryOrExitNode('Exit module'))

        if isinstance(module_statements, IgnoredNode):
            entry_node.connect(exit_node)
            return

        first_node = module_statements.first_statement

        if CALL_IDENTIFIER not in first_node.label:
            entry_node.connect(first_node)

        last_nodes = module_statements.last_statements
        exit_node.connect_predecessors(last_nodes)

    def init_function_cfg(self, node, module_definitions):
        self.module_definitions_stack.append(module_definitions)

        self.function_names.append(node.name)
        self.function_return_stack.append(node.name)

        entry_node = self.append_node(EntryOrExitNode('Entry function'))

        module_statements = self.stmt_star_handler(node.body)
        exit_node = self.append_node(EntryOrExitNode('Exit function'))

        if isinstance(module_statements, IgnoredNode):
            entry_node.connect(exit_node)
            return

        first_node = module_statements.first_statement

        if CALL_IDENTIFIER not in first_node.label:
            entry_node.connect(first_node)

        last_nodes = module_statements.last_statements
        exit_node.connect_predecessors(last_nodes)

    def visit_Yield(self, node):
        label = LabelVisitor()
        label.visit(node)

        if node.value is None:
            rhs_visitor_result = []
        else:
            rhs_visitor_result = RHSVisitor.result_for_node(node.value)

        # Yield is a bit like augmented assignment to a return value
        this_function_name = self.function_return_stack[-1]
        LHS = 'yld_' + this_function_name
        return self.append_node(YieldNode(
            LHS + ' += ' + label.result,
            LHS,
            node,
            rhs_visitor_result + [LHS],
            path=self.filenames[-1])
        )

    def visit_YieldFrom(self, node):
        return self.visit_Yield(node)

    def visit_Attribute(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Name(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_NameConstant(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Str(self, node):
        return IgnoredNode()

    def visit_Subscript(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Tuple(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def connect_if_allowed(
        self,
        previous_node,
        node_to_connect_to
    ):
        # e.g.
        # while x != 10:
        #     if x > 0:
        #         print(x)
        #         break
        #     else:
        #         print('hest')
        # print('next')  # self.nodes[-1] is print('hest')
        #
        # So we connect to `while x!= 10` instead
        if self.last_control_flow_nodes[-1]:
            self.last_control_flow_nodes[-1].connect(node_to_connect_to)
            self.last_control_flow_nodes[-1] = None
            return

        # Except in this case:
        #
        # if not image_name:
        #     return 404
        # print('foo')  # We do not want to connect this line with `return 404`
        if previous_node is not self.prev_nodes_to_avoid[-1] and not isinstance(previous_node, ReturnNode):
            previous_node.connect(node_to_connect_to)

    def save_local_scope(
        self,
        line_number,
        saved_function_call_index
    ):
        """Save the local scope before entering a function call by saving all the LHS's of assignments so far.

        Args:
            line_number(int): Of the def of the function call about to be entered into.
            saved_function_call_index(int): Unique number for each call.

        Returns:
            saved_variables(list[SavedVariable])
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
        """
        saved_variables = list()
        saved_variables_so_far = set()
        first_node = None

        # Make e.g. save_N_LHS = assignment.LHS for each AssignmentNode
        for assignment in [node for node in self.nodes
                           if (type(node) == AssignmentNode or
                               type(node) == AssignmentCallNode or
                               type(Node) == BBorBInode)]:  # type() is used on purpose here
            if assignment.left_hand_side in saved_variables_so_far:
                continue
            saved_variables_so_far.add(assignment.left_hand_side)
            save_name = 'save_{}_{}'.format(saved_function_call_index, assignment.left_hand_side)

            previous_node = self.nodes[-1]

            saved_scope_node = RestoreNode(
                save_name + ' = ' + assignment.left_hand_side,
                save_name,
                [assignment.left_hand_side],
                line_number=line_number,
                path=self.filenames[-1]
            )
            if not first_node:
                first_node = saved_scope_node

            self.nodes.append(saved_scope_node)
            # Save LHS
            saved_variables.append(SavedVariable(LHS=save_name,
                                                 RHS=assignment.left_hand_side))
            self.connect_if_allowed(previous_node, saved_scope_node)

        return (saved_variables, first_node)

    def save_def_args_in_temp(
        self,
        call_args,
        def_args,
        line_number,
        saved_function_call_index,
        first_node
    ):
        """Save the arguments of the definition being called. Visit the arguments if they're calls.

        Args:
            call_args(list[ast.Name]): Of the call being made.
            def_args(ast_helper.Arguments): Of the definition being called.
            line_number(int): Of the call being made.
            saved_function_call_index(int): Unique number for each call.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.

        Returns:
            args_mapping(dict): A mapping of call argument to definition argument.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
        """
        args_mapping = dict()
        last_return_value_of_nested_call = None

        # Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument
        for i, call_arg in enumerate(call_args):
            # If this results in an IndexError it is invalid Python
            def_arg_temp_name = 'temp_' + str(saved_function_call_index) + '_' + def_args[i]

            return_value_of_nested_call = None
            if isinstance(call_arg, ast.Call):
                return_value_of_nested_call = self.visit(call_arg)
                restore_node = RestoreNode(
                    def_arg_temp_name + ' = ' + return_value_of_nested_call.left_hand_side,
                    def_arg_temp_name,
                    [return_value_of_nested_call.left_hand_side],
                    line_number=line_number,
                    path=self.filenames[-1]
                )
                if return_value_of_nested_call in self.blackbox_assignments:
                    self.blackbox_assignments.add(restore_node)
            else:
                call_arg_label_visitor = LabelVisitor()
                call_arg_label_visitor.visit(call_arg)
                call_arg_rhs_visitor = RHSVisitor()
                call_arg_rhs_visitor.visit(call_arg)
                restore_node = RestoreNode(
                    def_arg_temp_name + ' = ' + call_arg_label_visitor.result,
                    def_arg_temp_name,
                    call_arg_rhs_visitor.result,
                    line_number=line_number,
                    path=self.filenames[-1]
                )

            # If there are no saved variables, then this is the first node
            if not first_node:
                first_node = restore_node

            if isinstance(call_arg, ast.Call):
                if last_return_value_of_nested_call:
                    # connect inner to other_inner in e.g. `outer(inner(image_name), other_inner(image_name))`
                    if isinstance(return_value_of_nested_call, BBorBInode):
                        last_return_value_of_nested_call.connect(return_value_of_nested_call)
                    else:
                        last_return_value_of_nested_call.connect(return_value_of_nested_call.first_node)
                else:
                    # I should only set this once per loop, inner in e.g. `outer(inner(image_name), other_inner(image_name))`
                    # (inner_most_call is used when predecessor is a ControlFlowNode in connect_control_flow_node)
                    if isinstance(return_value_of_nested_call, BBorBInode):
                        first_node.inner_most_call = return_value_of_nested_call
                    else:
                        first_node.inner_most_call = return_value_of_nested_call.first_node
                # We purposefully should not set this as the first_node of return_value_of_nested_call, last makes sense
                last_return_value_of_nested_call = return_value_of_nested_call
            self.connect_if_allowed(self.nodes[-1], restore_node)
            self.nodes.append(restore_node)

            if isinstance(call_arg, ast.Call):
                args_mapping[return_value_of_nested_call.left_hand_side] = def_args[i]
            else:
                args_mapping[def_args[i]] = call_arg_label_visitor.result

        return (args_mapping, first_node)

    def create_local_scope_from_def_args(
        self,
        call_args,
        def_args,
        line_number,
        saved_function_call_index
    ):
        """Create the local scope before entering the body of a function call.

        Args:
            call_args(list[ast.Name]): Of the call being made.
            def_args(ast_helper.Arguments): Of the definition being called.
            line_number(int): Of the def of the function call about to be entered into.
            saved_function_call_index(int): Unique number for each call.

        Note: We do not need a connect_if_allowed because of the
              preceding call to save_def_args_in_temp.
        """
        # Create e.g. def_arg1 = temp_N_def_arg1 for each argument
        for i in range(len(call_args)):
            def_arg_local_name = def_args[i]
            def_arg_temp_name = 'temp_' + str(saved_function_call_index) + '_' + def_args[i]
            local_scope_node = RestoreNode(
                def_arg_local_name + ' = ' + def_arg_temp_name,
                def_arg_local_name,
                [def_arg_temp_name],
                line_number=line_number,
                path=self.filenames[-1]
            )
            # Chain the local scope nodes together
            self.nodes[-1].connect(local_scope_node)
            self.nodes.append(local_scope_node)

    def visit_and_get_function_nodes(
        self,
        definition,
        first_node
    ):
        """Visits the nodes of a user defined function.

        Args:
            definition(LocalModuleDefinition): Definition of the function being added.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.

        Returns:
            the_new_nodes(list[Node]): The nodes added while visiting the function.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
        """
        len_before_visiting_func = len(self.nodes)
        previous_node = self.nodes[-1]
        entry_node = self.append_node(EntryOrExitNode('Function Entry ' +
                                                      definition.name))
        if not first_node:
            first_node = entry_node
        self.connect_if_allowed(previous_node, entry_node)

        function_body_connect_statements = self.stmt_star_handler(definition.node.body)
        entry_node.connect(function_body_connect_statements.first_statement)

        exit_node = self.append_node(EntryOrExitNode('Exit ' + definition.name))
        exit_node.connect_predecessors(function_body_connect_statements.last_statements)

        the_new_nodes = self.nodes[len_before_visiting_func:]
        return_connection_handler(the_new_nodes, exit_node)

        return (the_new_nodes, first_node)

    def restore_saved_local_scope(
        self,
        saved_variables,
        args_mapping,
        line_number
    ):
        """Restore the previously saved variables to their original values.

        Args:
           saved_variables(list[SavedVariable])
           args_mapping(dict): A mapping of call argument to definition argument.
           line_number(int): Of the def of the function call about to be entered into.

        Note: We do not need connect_if_allowed because of the
              preceding call to save_local_scope.
        """
        restore_nodes = list()
        for var in saved_variables:
            # Is var.RHS a call argument?
            if var.RHS in args_mapping:
                # If so, use the corresponding definition argument for the RHS of the label.
                restore_nodes.append(RestoreNode(
                    var.RHS + ' = ' + args_mapping[var.RHS],
                    var.RHS,
                    [var.LHS],
                    line_number=line_number,
                    path=self.filenames[-1]
                ))
            else:
                # Create a node for e.g. foo = save_1_foo
                restore_nodes.append(RestoreNode(
                    var.RHS + ' = ' + var.LHS,
                    var.RHS,
                    [var.LHS],
                    line_number=line_number,
                    path=self.filenames[-1]
                ))

        # Chain the restore nodes
        for node, successor in zip(restore_nodes, restore_nodes[1:]):
            node.connect(successor)

        if restore_nodes:
            # Connect the last node to the first restore node
            self.nodes[-1].connect(restore_nodes[0])
            self.nodes.extend(restore_nodes)

        return restore_nodes

    def return_handler(
        self,
        call_node,
        function_nodes,
        saved_function_call_index,
        first_node
    ):
        """Handle the return from a function during a function call.

        Args:
            call_node(ast.Call) : The node that calls the definition.
            function_nodes(list[Node]): List of nodes of the function being called.
            saved_function_call_index(int): Unique number for each call.
            first_node(EntryOrExitNode or RestoreNode): Used to connect previous statements to this function.
        """
        if any(isinstance(node, YieldNode) for node in function_nodes):
            # Presence of a `YieldNode` means that the function is a generator
            rhs_prefix = 'yld_'
        elif any(isinstance(node, ConnectToExitNode) for node in function_nodes):
            # Only `Return`s and `Raise`s can be of type ConnectToExitNode
            rhs_prefix = 'ret_'
        else:
            return  # No return value

        # Create e.g. ~call_1 = ret_func_foo RestoreNode
        LHS = CALL_IDENTIFIER + 'call_' + str(saved_function_call_index)
        RHS = rhs_prefix + get_call_names_as_string(call_node.func)
        return_node = RestoreNode(
            LHS + ' = ' + RHS,
            LHS,
            [RHS],
            line_number=call_node.lineno,
            path=self.filenames[-1]
        )
        return_node.first_node = first_node
        self.nodes[-1].connect(return_node)
        self.nodes.append(return_node)

    def process_function(self, call_node, definition):
        """Processes a user defined function when it is called.

        Increments self.function_call_index each time it is called, we can refer to it as N in the comments.
        Make e.g. save_N_LHS = assignment.LHS for each AssignmentNode. (save_local_scope)
        Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument.
            Visit the arguments if they're calls. (save_def_args_in_temp)
        Create e.g. def_arg1 = temp_N_def_arg1 for each argument. (create_local_scope_from_def_args)
        Visit and get function nodes. (visit_and_get_function_nodes)
        Loop through each save_N_LHS node and create an e.g.
            foo = save_1_foo or, if foo was a call arg, foo = arg_mapping[foo]. (restore_saved_local_scope)
        Create e.g. ~call_1 = ret_func_foo RestoreNode. (return_handler)

        Notes:
            Page 31 in the original thesis, but changed a little.
            We don't have to return the ~call_1 = ret_func_foo RestoreNode made in return_handler,
                because it's the last node anyway, that we return in this function.
            e.g. ret_func_foo gets assigned to visit_Return.

        Args:
            call_node(ast.Call) : The node that calls the definition.
            definition(LocalModuleDefinition): Definition of the function being called.

        Returns:
            Last node in self.nodes, probably the return of the function appended to self.nodes in return_handler.
        """
        self.function_call_index += 1
        saved_function_call_index = self.function_call_index

        def_node = definition.node

        saved_variables, first_node = self.save_local_scope(
            def_node.lineno,
            saved_function_call_index
        )

        args_mapping, first_node = self.save_def_args_in_temp(
            call_node.args,
            Arguments(def_node.args),
            call_node.lineno,
            saved_function_call_index,
            first_node
        )
        self.filenames.append(definition.path)
        self.create_local_scope_from_def_args(
            call_node.args,
            Arguments(def_node.args),
            def_node.lineno,
            saved_function_call_index
        )
        function_nodes, first_node = self.visit_and_get_function_nodes(
            definition,
            first_node
        )
        self.filenames.pop()  # Should really probably move after restore_saved_local_scope!!!
        self.restore_saved_local_scope(
            saved_variables,
            args_mapping,
            def_node.lineno
        )
        self.return_handler(
            call_node,
            function_nodes,
            saved_function_call_index,
            first_node
        )
        self.function_return_stack.pop()
        self.function_definition_stack.pop()

        return self.nodes[-1]

    def visit_Call(self, node):
        _id = get_call_names_as_string(node.func)
        local_definitions = self.module_definitions_stack[-1]

        alias = handle_aliases_in_calls(_id, local_definitions.import_alias_mapping)
        if alias:
            definition = local_definitions.get_definition(alias)
        else:
            definition = local_definitions.get_definition(_id)

        # e.g. "request.args.get" -> "get"
        last_attribute = _id.rpartition('.')[-1]

        if definition:
            if definition in self.function_definition_stack:
                log.debug("Recursion encountered in function %s", _id)
                return self.add_blackbox_or_builtin_call(node, blackbox=True)
            if isinstance(definition.node, ast.ClassDef):
                self.add_blackbox_or_builtin_call(node, blackbox=False)
            elif isinstance(definition.node, ast.FunctionDef):
                self.undecided = False
                self.function_return_stack.append(_id)
                self.function_definition_stack.append(definition)
                return self.process_function(node, definition)
            else:
                raise Exception('Definition was neither FunctionDef or ' +
                                'ClassDef, cannot add the function ')
        elif (
            not self._within_mutating_call and
            last_attribute in MUTATORS
            and isinstance(node.func, ast.Attribute)
        ):
            # Change list.append(x) ---> list += list.append(x)
            # This does in fact propagate as we don't know that append returns None
            fake_aug_assign = ast.AugAssign(
                target=node.func.value,
                op=ast.Add,
                value=node,
            )
            ast.copy_location(fake_aug_assign, node)
            self._within_mutating_call = True  # Don't do this recursively
            result = self.visit(fake_aug_assign)
            self._within_mutating_call = False
            return result
        elif last_attribute not in BUILTINS:
            # Mark the call as a blackbox because we don't have the definition
            return self.add_blackbox_or_builtin_call(node, blackbox=True)
        return self.add_blackbox_or_builtin_call(node, blackbox=False)