python-security/pyt

View on GitHub
pyt/core/transformer.py

Summary

Maintainability
A
0 mins
Test Coverage
import ast


class AsyncTransformer():
    """Mixin that rewrites async constructs as their synchronous equivalents.

    The class does not define ``visit`` itself; it expects to be combined
    with a visitor that does (for example ``ast.NodeTransformer``).
    """

    def _visit_as(self, sync_type, node):
        """Rebuild ``node`` as ``sync_type`` and continue visiting the result.

        Every attribute stored on ``node`` (fields and location info alike)
        is passed to the ``sync_type`` constructor as a keyword argument.
        """
        return self.visit(sync_type(**vars(node)))

    def visit_Await(self, node):
        """Drop the ``await`` keyword and visit the awaited expression."""
        return self.visit(node.value)

    def visit_AsyncFunctionDef(self, node):
        """Treat ``async def`` as a plain ``def``."""
        return self._visit_as(ast.FunctionDef, node)

    def visit_AsyncFor(self, node):
        """Treat ``async for`` as a plain ``for``."""
        return self._visit_as(ast.For, node)

    def visit_AsyncWith(self, node):
        """Treat ``async with`` as a plain ``with``."""
        return self._visit_as(ast.With, node)


class ChainedFunctionTransformer():
    """Mixin that unrolls chained method calls into one statement per call.

    An assignment or return whose value looks like ``b.c().d()`` becomes

        __chain_tmp_1 = b.c()
        <target> = __chain_tmp_1.d()      # or: return __chain_tmp_1.d()

    ``generic_visit`` is not defined here; it must be supplied by another
    class in the hierarchy (for example ``ast.NodeTransformer``).
    """

    def visit_chain(self, node, depth=1):
        """Return the list of statements that replaces ``node``.

        ``node`` is a statement with a ``value`` field (Assign or Return).
        ``depth`` only serves to number the temporary variable; it grows by
        one for every further link of the chain that is peeled off.
        """
        outer = node.value
        is_chained_call = (
            isinstance(outer, ast.Call)
            and isinstance(outer.func, ast.Attribute)
            and isinstance(outer.func.value, ast.Call)
        )
        if not is_chained_call:
            # Nothing to unroll: keep the statement, just visit its children.
            return [self.generic_visit(node)]

        # If we want to handle nested functions in future, depth needs fixing
        tmp_name = '__chain_tmp_{}'.format(depth)

        # The AST nests right to left: for `b.c().d()` the outermost Call is
        # d() and its receiver (outer.func.value) is the inner Call b.c().
        # Store that receiver in a temporary first.
        receiver_assign = ast.copy_location(
            ast.Assign(
                targets=[ast.Name(id=tmp_name, ctx=ast.Store())],
                value=outer.func.value,
            ),
            node,
        )
        # The receiver may itself be a chain, so unroll it recursively.
        preceding = self.visit_chain(receiver_assign, depth + 1)
        for statement in preceding:
            ast.copy_location(statement, node)

        # Rebuild the original statement type with the call re-targeted at
        # the temporary; every field except `value` (e.g. targets) is reused.
        rebuilt_fields = {
            'value': ast.Call(
                func=ast.Attribute(
                    value=ast.Name(id=tmp_name, ctx=ast.Load()),
                    attr=outer.func.attr,
                    ctx=ast.Load(),
                ),
                args=outer.args,
                keywords=outer.keywords,
            ),
        }
        rebuilt_fields.update(
            (name, content)
            for name, content in ast.iter_fields(node)
            if name != 'value'
        )
        final_statement = self.generic_visit(type(node)(**rebuilt_fields))
        for fresh in (
            final_statement,
            final_statement.value,
            final_statement.value.func,
        ):
            ast.copy_location(fresh, node)
        return [*preceding, final_statement]

    def visit_Assign(self, node):
        """Unroll a chained call on the right-hand side of an assignment."""
        return self.visit_chain(node)

    def visit_Return(self, node):
        """Unroll a chained call in a return value."""
        return self.visit_chain(node)


class IfExpRewriter(ast.NodeTransformer):
    """Splits IfExp ternary expressions containing complex tests into multiple statements

    Will change

    a if b(c) else d

    into

    a if __if_exp_0 else d

    with Assign nodes in assignments [__if_exp_0 = b(c)]

    Function definitions (sync and async) are never descended into: the
    collected assignments are meant to be placed next to the visited
    statement, which would be the wrong scope for code inside a nested
    function body.
    """

    def __init__(self, starting_index=0):
        """Create a rewriter.

        starting_index: number used for the first generated temporary
        (``__if_exp_<starting_index>``); it is incremented for each one.
        """
        self._temporary_variable_index = starting_index
        # Assign nodes generated so far, in the order they must be executed.
        self.assignments = []
        super().__init__()

    def visit_IfExp(self, node):
        """Replace a complex test by a temporary and record its assignment.

        Tests that are a plain Name or Attribute are left as they are and
        only the children of the IfExp are visited.
        """
        if isinstance(node.test, (ast.Name, ast.Attribute)):
            return self.generic_visit(node)

        # Reserve the index before visiting the test so that the outer
        # expression keeps the lower number when IfExps are nested.
        temp_var_id = '__if_exp_{}'.format(self._temporary_variable_index)
        self._temporary_variable_index += 1
        assignment_of_test = ast.Assign(
            targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
            # Visiting the test first means assignments produced by IfExps
            # nested inside it are appended before this one.
            value=self.visit(node.test),
        )
        ast.copy_location(assignment_of_test, node)
        self.assignments.append(assignment_of_test)
        transformed_if_exp = ast.IfExp(
            test=ast.Name(id=temp_var_id, ctx=ast.Load()),
            body=self.visit(node.body),
            orelse=self.visit(node.orelse),
        )
        ast.copy_location(transformed_if_exp, node)
        return transformed_if_exp

    def visit_FunctionDef(self, node):
        """Leave function definitions untouched (their body is another scope)."""
        return node

    def visit_AsyncFunctionDef(self, node):
        """Leave async function definitions untouched, like sync ones.

        Without this, generic_visit would descend into the async body and
        the resulting assignments would be emitted outside the function.
        """
        return node


class IfExpTransformer:
    """Goes through module and function bodies, adding extra Assign nodes due to IfExp expressions.

    This is a mixin: ``generic_visit`` must be provided by another class in
    the hierarchy (for example ``ast.NodeTransformer``).
    """

    def visit_body(self, nodes):
        """Return a new statement list with IfExp tests split out.

        Each statement is run through its own IfExpRewriter; the Assign
        nodes the rewriter collected are placed directly before the
        (possibly rewritten) statement. Temporary variable numbering
        continues across the statements of one body.
        """
        new_nodes = []
        count = 0
        for node in nodes:
            rewriter = IfExpRewriter(count)
            possibly_transformed_node = rewriter.visit(node)
            new_nodes.extend(rewriter.assignments)
            count += len(rewriter.assignments)
            new_nodes.append(possibly_transformed_node)
        return new_nodes

    def _rebuild_with_new_body(self, node_type, node):
        """Copy ``node`` as ``node_type`` with its body processed by visit_body.

        All fields present on ``node`` are carried over (not just a fixed
        subset), so version-dependent fields such as ``type_comment``,
        ``type_params`` or ``type_ignores`` are not lost.
        """
        fields = dict(ast.iter_fields(node))
        fields['body'] = self.visit_body(node.body)
        transformed = node_type(**fields)
        ast.copy_location(transformed, node)
        return self.generic_visit(transformed)

    def visit_FunctionDef(self, node):
        """Rewrite the function body, then visit the new node's children."""
        return self._rebuild_with_new_body(ast.FunctionDef, node)

    def visit_Module(self, node):
        """Rewrite the module body, then visit the new node's children."""
        return self._rebuild_with_new_body(ast.Module, node)


class PytTransformer(
    AsyncTransformer,
    IfExpTransformer,
    ChainedFunctionTransformer,
    ast.NodeTransformer,
):
    """Combines the async, IfExp and chained-call mixins into one transformer.

    ast.NodeTransformer comes last in the bases and supplies the ``visit``
    and ``generic_visit`` methods that the mixins call.
    """