petr-muller/pyff

View on GitHub
pyff/imports.py

Summary

Maintainability
B
4 hrs
Test Coverage
"""This module contains code that handles comparing imports"""

import collections.abc
import types
from typing import Set, Dict, Union, Optional, FrozenSet, Mapping, cast
import ast
import logging
from pyff.kitchensink import hl, hlistify, pluralize

# Type alias covering both AST node kinds that represent an import statement:
# `import X` (ast.Import) and `from X import Y` (ast.ImportFrom).
ImportNode = Union[ast.Import, ast.ImportFrom]  # pylint: disable=invalid-name

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)


class ImportedName:
    """Represents a single imported name.

    Attributes:
        name: Name under which the import is referred to by this object
            (this is what `str()` returns).
        node: The `import ...` or `from ... import ...` statement node that
            introduced the name.
        alias: The `ast.alias` entry of `node` describing this particular name.
    """

    def __init__(self, name: str, node: ImportNode, alias: ast.alias) -> None:
        self.name: str = name
        self.node: ImportNode = node
        self.alias: ast.alias = alias

    def __repr__(self):  # pragma: no cover
        return f"ImportedName(name={self.name} node={self.node} alias={self.alias})"

    def __str__(self):
        return self.name

    def is_import(self) -> bool:
        """Returns True if name was imported with `import X` statement"""
        return isinstance(self.node, ast.Import)

    def is_import_from(self) -> bool:
        """Returns True if name was imported with `from Y import X` statement"""
        return isinstance(self.node, ast.ImportFrom)

    @property
    def canonical_name(self) -> str:
        """Returns whole name.

        Example: For 'join' imported by 'from os.path import join', returns 'os.path.join'

        For a relative import without a module part (`from . import x`) there
        is no module to prepend, so only the imported name itself is returned.
        The relative level (number of leading dots) is never part of the result.
        """
        if isinstance(self.node, ast.Import):
            return self.alias.name

        if isinstance(self.node, ast.ImportFrom):
            if self.node.module is None:
                # `from . import x`: avoid producing the bogus name "None.x"
                return self.alias.name
            return f"{self.node.module}.{self.alias.name}"

        raise Exception("Node should always be one of {Import, ImportFrom}")  # pragma: no cover

    @property
    def canonical_ast(self) -> Union[ast.Name, ast.Attribute]:
        """Returns AST node for the full name

        Example: For 'join' imported by 'from os.path import join', returns AST of 'os.path.join'

        Raises:
            Exception: if the name comes from a `from . import x` statement,
                which has no module part to build the expression from.
        """
        if isinstance(self.node, ast.Import):
            parts = self.alias.name.split(".")
        elif isinstance(self.node, ast.ImportFrom):
            if self.node.module is None:
                raise Exception(
                    "ast.ImportFrom has module attribute set to None"
                )  # pragma: no cover
            parts = self.node.module.split(".") + [self.alias.name]
        else:
            raise Exception(
                "Node should always be one of {Import, ImportFrom}"
            )  # pragma: no cover

        # Build the left-nested attribute chain: a.b.c -> Attribute(Attribute(Name(a), b), c)
        first, *rest = parts
        node: Union[ast.Name, ast.Attribute] = ast.Name(id=first, ctx=ast.Load())
        for attr in rest:
            node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
        return node


class FromImportPyfference:
    """Holds what changed in `from X import Y` statements between two versions.

    Added and removed names are grouped by the module they are imported from;
    modules that newly appear in (or vanish from) such statements are tracked
    in two separate sets.
    """

    def __init__(self):
        self._new: Dict[str, Set[ImportedName]] = {}
        self._removed: Dict[str, Set[ImportedName]] = {}
        self._new_modules: Set[str] = set()
        self._removed_modules: Set[str] = set()

    def __bool__(self):
        # Truthy as soon as any of the four collections holds something
        return any((self._new, self._removed, self._new_modules, self._removed_modules))

    @staticmethod
    def _source_module(node: ImportedName) -> str:
        """Return the module `node` is imported from, rejecting unsuitable nodes"""
        if not node.is_import_from():
            raise ValueError(
                "FromImportPyfference can only handle ImportFrom nodes"
            )  # pragma: no cover

        module = cast(ast.ImportFrom, node.node).module
        if module is None:
            raise Exception("ast.ImportFrom has `module` attribute set to None")  # pragma: no cover

        return module

    @property
    def new(self) -> Mapping[str, Set[ImportedName]]:
        """Read-only view of newly imported names, keyed by source module"""
        return types.MappingProxyType(self._new)

    @property
    def removed(self) -> Mapping[str, Set[ImportedName]]:
        """Read-only view of names no longer imported, keyed by source module"""
        return types.MappingProxyType(self._removed)

    @property
    def new_modules(self) -> FrozenSet[str]:
        """Frozen snapshot of modules that newly appear in `from X import Y` statements"""
        return frozenset(self._new_modules)

    @property
    def removed_modules(self) -> FrozenSet[str]:
        """Frozen snapshot of modules that no longer appear in `from X import Y` statements"""
        return frozenset(self._removed_modules)

    def add_new(self, node: ImportedName) -> None:
        """Record a name newly imported by a `from X import y` statement"""
        source = self._source_module(node)
        self._new.setdefault(source, set()).add(node)

    def add_removed(self, node: ImportedName) -> None:
        """Record a name whose `from X import y` import was removed"""
        source = self._source_module(node)
        self._removed.setdefault(source, set()).add(node)

    def add_new_modules(self, modules: Set[str]) -> None:
        """Record modules that newly appear in `from X import Y` statements"""
        self._new_modules |= set(modules)

    def add_removed_modules(self, modules: Set[str]) -> None:
        """Record modules that no longer appear in `from X import Y` statements"""
        self._removed_modules |= set(modules)

    def delete_new_module(self, module: str) -> None:
        """Forget a new module together with all new names recorded for it"""
        self._new_modules.discard(module)
        self._new.pop(module, None)

    def delete_removed_module(self, module: str) -> None:
        """Forget a removed module together with all removed names recorded for it"""
        self._removed_modules.discard(module)
        self._removed.pop(module, None)


class ImportsPyfference:
    """Represent difference between two ImportedNames.

    Tracks packages added/removed via plain `import X` statements, delegates
    `from X import Y` changes to a FromImportPyfference, and (after `reduce()`)
    keeps special records for modules whose import style changed between
    `import X` and `from X import ...`.
    """

    def __init__(self):
        self._new_imports: Set[ImportedName] = set()
        self._removed_imports: Set[ImportedName] = set()
        self.fromimports: FromImportPyfference = FromImportPyfference()
        # module name -> names now imported via `from module import ...`
        self._changed_to_fromimport: Dict[str, Set[ImportedName]] = {}
        # module name -> names previously imported via `from module import ...`
        self._changed_to_import: Dict[str, Set[ImportedName]] = {}

    def __bool__(self):
        return bool(
            self._new_imports
            or self._removed_imports
            or self.fromimports
            or self._changed_to_fromimport
            or self._changed_to_import
        )

    @property
    def new_imports(self) -> FrozenSet[ImportedName]:
        """Returns a read-only set of new imported names"""
        return frozenset(self._new_imports)

    @property
    def removed_imports(self) -> FrozenSet[ImportedName]:
        """Returns a read-only set of removed imported names"""
        return frozenset(self._removed_imports)

    def simplify(self) -> Optional["ImportsPyfference"]:
        """Returns self when any difference is recorded, None otherwise"""
        return self if self else None

    def new_import(self, node: ImportedName) -> None:
        """Add a new imported name"""
        self._new_imports.add(node)

    def removed_import(self, node: ImportedName) -> None:
        """Add a removed imported name"""
        self._removed_imports.add(node)

    def new_from_import(self, node: ImportedName) -> None:
        """Add a new name imported via `from X import Y` statement

        Names whose statement has no module part (`from . import Y`) are
        silently skipped.

        Raises:
            ValueError: if `node` does not come from a `from X import Y` statement.
        """
        if not node.is_import_from():
            raise ValueError(
                "ImportsPyfference.new_from_import can only handle ImportFrom nodes"
            )  # pragma: no cover

        if cast(ast.ImportFrom, node.node).module:
            self.fromimports.add_new(node)

    def removed_from_import(self, node: ImportedName) -> None:
        """Add a removed name imported via `from X import Y` statement

        Names whose statement has no module part (`from . import Y`) are
        silently skipped.

        Raises:
            ValueError: if `node` does not come from a `from X import Y` statement.
        """
        if not node.is_import_from():
            raise ValueError(
                "ImportsPyfference.removed_from_import can only handle ImportFrom nodes"
            )  # pragma: no cover

        if cast(ast.ImportFrom, node.node).module:
            self.fromimports.add_removed(node)

    def new_fromimport_modules(self, modules: Set[str]) -> None:
        """Add new modules imported via `from X import Y` statement"""
        self.fromimports.add_new_modules(modules)

    def removed_fromimport_modules(self, modules: Set[str]) -> None:
        """Add removed modules imported via `from X import Y` statement"""
        self.fromimports.add_removed_modules(modules)

    def reduce(self) -> None:
        """Find special cases and other reductions in the differences

        (1) Find matching names imported by different import statements and
            create special records for these changes.
            Example: `from os import path` in one version and `import os` in another

        A change record is only created when names are actually recorded for
        the module. A module can be listed as removed/new without any recorded
        name (e.g. old `from os import path`, new `import os` together with
        `from sys import path`: `path` itself never disappeared); such a module
        is left alone and the plain `import` stays reported as new/removed.
        """
        # Iterate over copies: the underlying sets are modified inside the loops
        for name in set(self._new_imports):
            module = name.name
            if module not in self.fromimports.removed_modules:
                continue
            previously_imported = self.fromimports.removed.get(module)
            if previously_imported is None:
                continue

            LOGGER.debug(
                "New module has 'import %s' and old module had 'from %s import ...': "
                "Adding a change record",
                name,
                name,
            )
            self._new_imports.discard(name)
            self._changed_to_import[module] = previously_imported
            self.fromimports.delete_removed_module(module)

        for name in set(self._removed_imports):
            module = name.name
            if module not in self.fromimports.new_modules:
                continue
            newly_imported = self.fromimports.new.get(module)
            if newly_imported is None:
                continue

            LOGGER.debug(
                "Old module had 'import %s' and new module has 'from %s import ...': "
                "Adding a change record",
                name,
                name,
            )
            self._removed_imports.discard(name)
            self._changed_to_fromimport[module] = newly_imported
            self.fromimports.delete_new_module(module)

    def __str__(self):
        lines = []

        removed_imports = sorted(name.name for name in self.removed_imports)
        if removed_imports:
            packages = pluralize("package", removed_imports)
            listed = hlistify(removed_imports)
            lines.append(f"Removed import of {packages} {listed}")

        new_imports = sorted(name.name for name in self.new_imports)
        if new_imports:
            packages = pluralize("package", new_imports)
            listed = hlistify(new_imports)
            lines.append(f"New imported {packages} {listed}")

        for module, names in self.fromimports.removed.items():
            hl_removed_names = hlistify(sorted(str(name) for name in names))
            if module in self.fromimports.removed_modules:
                lines.append(f"Removed import of {hl_removed_names} from removed {hl(module)}")
            else:
                lines.append(f"Removed import of {hl_removed_names} from {hl(module)}")

        for module, names in self.fromimports.new.items():
            new_names = sorted(str(name) for name in names)
            if module in self.fromimports.new_modules:
                lines.append(f"New imported {hlistify(new_names)} from new {hl(module)}")
            else:
                lines.append(f"New imported {hlistify(new_names)} from {hl(module)}")

        for module, names in self._changed_to_fromimport.items():
            new_names = sorted(str(name) for name in names)
            lines.append(
                f"New imported {hlistify(new_names)} from {hl(module)} "
                f"(previously, full {hl(module)} was imported)"
            )

        for module, names in self._changed_to_import.items():
            new_names = sorted(str(name) for name in names)
            was = "was" if len(new_names) == 1 else "were"
            lines.append(
                f"New imported package {hl(module)} "
                f"(previously, only {hlistify(new_names)} "
                f"{was} imported from {hl(module)})"
            )

        return "\n".join(lines)


class ImportedNames(collections.abc.Mapping):  # pylint: disable=too-few-public-methods
    """Dictionary mapping external names to appropriate ImportedName

    Attributes:
        names: Maps each imported name (the `as` alias when one is given) to
            its ImportedName record.
        from_modules: Modules that appear in `from X import Y` statements;
            statements without a module part (`from . import Y`) add nothing.
    """

    @staticmethod
    def extract(code: ast.Module) -> "ImportedNames":
        """Extracts ImportedNames from a Module"""
        import_walker = ImportExtractor()
        import_walker.visit(code)
        return import_walker.names

    @staticmethod
    def compare(old: "ImportedNames", new: "ImportedNames") -> Optional[ImportsPyfference]:
        """Compare two sets of imported names.

        Returns:
            The recorded differences, or None when both sets are equivalent.
        """
        LOGGER.debug("Comparing ImportedNames")
        change = ImportsPyfference()

        for name, node in new.names.items():
            if name in old.names:
                continue
            LOGGER.debug("New name '%s' not present in old names", name)
            if node.is_import():
                change.new_import(node)
            elif node.is_import_from():
                change.new_from_import(node)

        for name, node in old.names.items():
            if name in new.names:
                continue
            LOGGER.debug("Old name '%s' not present in new names", name)
            if node.is_import():
                change.removed_import(node)
            elif node.is_import_from():
                change.removed_from_import(node)

        change.new_fromimport_modules(new.from_modules - old.from_modules)
        # Log the module sets the message talks about, not the per-module name mappings
        LOGGER.debug(
            "New modules from which names were imported: %s", change.fromimports.new_modules
        )
        change.removed_fromimport_modules(old.from_modules - new.from_modules)
        LOGGER.debug(
            "Removed modules from which names were imported: %s",
            change.fromimports.removed_modules,
        )

        # Collapse `import X` <-> `from X import ...` switches into change records
        change.reduce()

        return change or None

    def __init__(self) -> None:
        self.names: Dict[str, ImportedName] = {}
        self.from_modules: Set[str] = set()

    def __getitem__(self, item):
        return self.names[item]

    def __iter__(self):
        yield from self.names

    def __len__(self):
        return len(self.names)

    def __repr__(self):  # pragma: no cover
        return f"ImportedNames(names={self.names}, from_modules={self.from_modules})"

    def _add(self, name: ast.alias, node: ImportNode) -> None:
        # The alias given by `as` (when present) is the name visible in the module
        visible = name.asname if name.asname is not None else name.name
        self.names[visible] = ImportedName(visible, node, alias=name)

    def add_import(self, node: ast.Import) -> None:
        """Add a 'import X, Y' statement"""
        for name in node.names:
            self._add(name, node)

    def add_importfrom(self, node: ast.ImportFrom) -> None:
        """Add a 'from X import Y' statement"""
        for name in node.names:
            self._add(name, node)
        if node.module:
            self.from_modules.add(node.module)


class ImportExtractor(ast.NodeVisitor):
    """AST visitor collecting `import` and `from ... import` statements

    Every matching statement met during a `visit()` is handed over to the
    ImportedNames container available as the `names` attribute.
    """

    def __init__(self) -> None:
        super().__init__()
        self.names = ImportedNames()

    def visit_Import(self, node):  # pylint: disable=invalid-name
        """Record the names brought in by an `import X, Y` statement"""
        self.names.add_import(node)

    def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
        """Record the names brought in by a `from x import y` statement"""
        self.names.add_importfrom(node)


def pyff_imports(old: ast.Module, new: ast.Module) -> Optional[ImportsPyfference]:
    """Compare the import statements of two parsed modules

    Returns the recorded differences, or None when there are none."""
    names_before = ImportedNames.extract(old)
    names_after = ImportedNames.extract(new)

    return ImportedNames.compare(names_before, names_after) or None


def pyff_imports_code(old_code: str, new_code: str) -> Optional[ImportsPyfference]:
    """Compare the import statements of two modules given as source text

    Both sources are parsed with `ast.parse` (old one first) and the resulting
    trees are passed on to `pyff_imports`."""
    return pyff_imports(ast.parse(old_code), ast.parse(new_code))