best-doctor/mr_proper

View on GitHub
mr_proper/utils/ast_pure.py

Summary

Maintainability
B
5 hrs
Test Coverage
C
77%
import ast
import inspect
import importlib.util
import sys
from typing import Optional, List, cast

from mr_proper.common_types import AnyFuncdef
from mr_proper.utils.ast import BUILTINS_LIST, get_nodes_from_funcdef_body, get_ast_tree

if sys.version_info >= (3, 8):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict


class EntityImportInfo(TypedDict):
    """Describe where a name used in analysed code is expected to be defined.

    Instances are built by ``get_name_import_path`` and
    ``extract_import_info_from_import_node`` and read by ``get_funcdef_by``.
    """

    # Path of the file expected to contain the definition; ``None`` when no
    # usable path is known (for example when the 'built-in' marker was seen).
    file_path: Optional[str]
    # Module named in the ``from <module> import ...`` statement; ``None`` for
    # a function found in the analysed file itself, or when the statement
    # names no module.
    import_path: Optional[str]
    # Identifier exactly as written at the usage site.
    name: str


def get_not_pure_internal_calls(
    funcdef_node: AnyFuncdef,
    file_ast_tree: ast.Module,
    pyfilepath: str,
) -> List[str]:
    """Return the names of plain-function calls in ``funcdef_node`` that are not pure.

    Only calls whose callee is a bare name are inspected; method and attribute
    calls are skipped. A callee is reported only when its definition could be
    located and ``is_function_pure`` explicitly returned ``False`` for it.

    Args:
        funcdef_node: Function definition whose body is scanned for calls.
        file_ast_tree: AST of the module being analysed; forwarded unchanged
            to ``is_function_pure``.
        pyfilepath: Path of the file that contains ``funcdef_node``.

    Returns:
        Callee names, in the order their calls were visited. A name appears
        once per offending call.
    """
    # Imported at call time rather than module level (presumably to avoid a
    # circular import with ``mr_proper.public_api`` -- TODO confirm).
    from mr_proper.public_api import is_function_pure

    impure_names: List[str] = []
    for call in get_nodes_from_funcdef_body(funcdef_node, [ast.Call]):
        callee = call.func
        if not isinstance(callee, ast.Name):
            # Methods / attribute calls are not followed.
            continue

        callee_name = callee.id
        # Names containing uppercase characters (e.g. class constructors) and
        # builtins are never checked.
        if callee_name != callee_name.lower() or callee_name in BUILTINS_LIST:
            continue

        import_info = get_name_import_path(callee, pyfilepath)
        if not import_info:
            continue
        callee_funcdef = get_funcdef_by(import_info)
        if not callee_funcdef:
            continue

        # NOTE(review): ``file_ast_tree`` is the caller's tree while
        # ``pyfilepath`` is the callee's file -- confirm this is intended.
        verdict: Optional[bool] = is_function_pure(
            callee_funcdef,
            file_ast_tree=file_ast_tree,
            recursive=False,
            pyfilepath=cast(str, import_info['file_path']),
            with_errors=False,
        )
        if verdict is False:
            impure_names.append(callee_name)
    return impure_names


def get_name_import_path(name_node: ast.Name, pyfilepath: str) -> Optional[EntityImportInfo]:
    """Locate the definition or ``from ... import`` that introduces ``name_node``.

    Starting at the parent of ``name_node``, each enclosing node's direct
    children are scanned for either a ``from ... import`` statement listing
    the name or a (possibly async) function definition with that name. The
    first hit wins; the search then moves one level up until the root is
    reached.

    The tree must have been annotated with ``parent`` attributes beforehand.
    The root node may either lack the attribute or carry ``parent = None``;
    both end the search.

    Note: aliases are compared by their original imported name only, so
    ``from m import f as g`` is matched for ``f`` and not for ``g``.

    Args:
        name_node: The ``ast.Name`` being resolved (needs a ``parent`` attribute).
        pyfilepath: Path of the file being analysed; reported as ``file_path``
            for locally defined functions unless it equals ``'built-in'``.

    Returns:
        Import information for the name, or ``None`` when neither an import
        nor a definition was found in any enclosing scope.
    """
    current_node = name_node.parent  # type: ignore
    while current_node is not None:
        for child in ast.iter_child_nodes(current_node):
            # TODO: check for Import, not only ImportFrom
            if isinstance(child, ast.ImportFrom):
                if any(alias.name == name_node.id for alias in child.names):
                    return extract_import_info_from_import_node(child, name_node)
            elif (
                isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
                and child.name == name_node.id
            ):
                return {
                    'file_path': pyfilepath if pyfilepath != 'built-in' else None,
                    'import_path': None,
                    'name': name_node.id,
                }
        # The root usually has no ``parent`` attribute at all; treat that the
        # same as an explicit ``None`` instead of raising AttributeError.
        current_node = getattr(current_node, 'parent', None)
    return None


def extract_import_info_from_import_node(
    import_from_node: ast.ImportFrom,
    name_node: ast.Name,
) -> EntityImportInfo:
    """Build import information for a name brought in by ``from ... import``.

    Args:
        import_from_node: The ``from <module> import ...`` statement that
            lists the name.
        name_node: The usage of the imported name; its ``id`` is reported as
            ``name`` and used for the file lookup.

    Returns:
        A mapping with the statement's module as ``import_path`` and the
        resolved source file as ``file_path``. ``file_path`` is ``None`` when
        the statement names no module (``from . import x``), when the module
        cannot be resolved, or when the module has no source file (built-in
        or frozen modules).
    """
    import_path = import_from_node.module
    filepath: Optional[str] = None
    if import_path is not None:
        filepath = get_file_path_by(import_path, name_node.id)
        # Module specs of built-in and frozen modules carry the markers
        # 'built-in' / 'frozen' as their origin; neither is a readable path.
        if filepath in ('built-in', 'frozen'):
            filepath = None
    return {'file_path': filepath, 'import_path': import_path, 'name': name_node.id}


def get_funcdef_by(import_info: EntityImportInfo) -> Optional[AnyFuncdef]:
    """Find the function definition described by ``import_info``.

    The file at ``import_info['file_path']`` is parsed with ``get_ast_tree``
    and the whole tree is searched for (async) function definitions whose
    name equals ``import_info['name']``, at any nesting level.

    Returns:
        The first matching definition in ``ast.walk`` order, or ``None`` when
        there is no file path, the file yields no tree, nothing matches, or
        an ``AttributeError`` is raised while searching.
    """
    source_path = import_info['file_path']
    if not source_path:
        return None

    module_tree = get_ast_tree(source_path)
    if not module_tree:
        return None

    matches: List[AnyFuncdef] = []
    try:
        for node in ast.walk(module_tree):
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            if node.name == import_info['name']:
                matches.append(node)
    except AttributeError:
        return None

    return matches[0] if matches else None


def get_file_path_by(
    import_path: str,
    imported_name: str,
    dynamic_resolver: bool = False,
) -> Optional[str]:
    """Resolve the file that provides ``imported_name`` from module ``import_path``.

    Args:
        import_path: Absolute dotted module name, e.g. ``'package.module'``.
        imported_name: Name imported from that module. Only used by the
            dynamic resolver.
        dynamic_resolver: When true, the module is actually imported and the
            file is taken from the imported object via ``inspect.getfile``.
            When false (default), only the module spec is looked up; note
            that this still imports parent packages of a dotted name.

    Returns:
        With the static resolver: the ``origin`` of the module spec, which
        may be a non-path marker such as ``'built-in'``, or ``None`` when the
        module cannot be located. With the dynamic resolver: the file
        reported by ``inspect.getfile``.

    Raises:
        Only with ``dynamic_resolver=True``: whatever ``import_module``,
        ``getattr`` or ``inspect.getfile`` raise (e.g. ``ImportError``,
        ``AttributeError``, ``TypeError`` for built-in objects).
    """
    if dynamic_resolver:
        module = importlib.import_module(import_path)
        return inspect.getfile(getattr(module, imported_name))

    try:
        module_spec = importlib.util.find_spec(import_path)
    except (ImportError, ValueError):
        # ImportError (incl. ModuleNotFoundError): a parent package is missing
        # or failed to import. ValueError: the module is already in
        # ``sys.modules`` but its ``__spec__`` is None. Either way the module
        # cannot be resolved statically.
        return None
    return module_spec.origin if module_spec else None