# tconbeer/sqlfmt
#
# View on GitHub
# src/sqlfmt/report.py
#
# Summary
#
# Maintainability: A (0 mins)
# Test Coverage: A (100%)
import difflib
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import click

from sqlfmt.exception import SqlfmtError
from sqlfmt.mode import Mode

# Sentinel path standing in for stdin: a result whose source_path equals this
# value has its formatted string echoed to stdout by maybe_print_to_stdout().
STDIN_PATH = Path("-")


def style_output(
    msg: str, fg: Optional[str] = None, bg: Optional[str] = None, bold: bool = False
) -> str:
    """
    Delegates to click.style and returns the styled text.

    msg is the text to style; fg, bg and bold are forwarded unchanged to
    click.style as keyword arguments of the same names.

    See https://click.palletsprojects.com/en/8.0.x/api/?highlight=style#click.style
    """
    # The annotated local pins the return type to str for type checkers.
    styled: str = click.style(msg, fg=fg, bg=bg, bold=bold)
    return styled


def unstyle_output(msg: str) -> str:
    """
    Delegates to click.unstyle and returns its result for msg.
    """
    # The annotated local pins the return type to str for type checkers.
    plain: str = click.unstyle(msg)
    return plain


def display_output(msg: str, err: bool = True) -> None:
    """
    Echoes msg via click.echo. The err flag is forwarded to click.echo and
    defaults to True, so output goes to stderr unless the caller passes
    err=False.
    """
    click.echo(message=msg, err=err)


@dataclass
class SqlFormatResult:
    """
    A SqlFormatResult is a summary of what sqlfmt did with a single source:
    the original string, the formatted string, and an optional exception.

    After construction, display_path holds source_path made relative to the
    current working directory when that is possible, and source_path itself
    otherwise.
    """

    # Where the query came from; compared against STDIN_PATH to detect stdin.
    source_path: Path
    # The query text before and after formatting.
    source_string: str
    formatted_string: str
    # Stored as given; not read by any method of this class.
    encoding: str
    utf_bom: str
    # Set when formatting this source failed; None otherwise.
    exception: Optional[SqlfmtError] = None
    from_cache: bool = False

    def __post_init__(self) -> None:
        # relative_to raises ValueError when source_path is not located under
        # the working directory; fall back to the path exactly as given.
        try:
            shown = self.source_path.relative_to(Path.cwd())
        except ValueError:
            shown = self.source_path
        self.display_path = shown

    def maybe_print_to_stdout(self) -> None:
        """
        Echo the formatted string to stdout, but only when this result's
        source_path is the stdin sentinel; otherwise do nothing.
        """
        if self.source_path != STDIN_PATH:
            return
        display_output(self.formatted_string, err=False)

    @property
    def has_changed(self) -> bool:
        """True when the formatted string differs from the source string."""
        return self.formatted_string != self.source_string

    @property
    def has_error(self) -> bool:
        """True when an exception was recorded on this result."""
        recorded = self.exception
        return recorded is not None


@dataclass
class Report:
    """
    An abstraction for a summary of results generated by a sqlfmt run.
    Can be printed to stderr using display_report()

    Every result belongs to exactly one group:
      - errored: an exception was recorded (whether or not the strings differ)
      - changed: no exception, and the formatted string differs from the source
      - unchanged: no exception, and the formatted string equals the source
    """

    results: List[SqlFormatResult]
    mode: Mode

    def __str__(self) -> str:
        """
        Returns the full contents of the Report: summary counts first, then
        per-file error, changed, and (in verbose mode) unchanged lines.
        """
        # Compute each group once; every property access re-filters and
        # re-sorts self.results.
        errored = self.errored_results
        changed = self.changed_results
        unchanged = self.unchanged_results

        # In check/diff mode the wording describes a check outcome instead of
        # an action taken on the file.
        checking = self.mode.check or self.mode.diff
        formatted_label = "failed formatting check" if checking else "formatted"
        unchanged_label = "passed formatting check" if checking else "left unchanged"

        report = []
        if errored:
            error_msg = (
                f"{self._pluralize_file(len(errored))} had errors while formatting."
            )
            report.append(style_output(error_msg, fg="red", bold=True))
        if changed:
            changed_msg = f"{self._pluralize_file(len(changed))} {formatted_label}."
            report.append(style_output(changed_msg, bold=True))
        report.append(f"{self._pluralize_file(len(unchanged))} {unchanged_label}.")

        # Only the first 50 errors are listed individually.
        for res in errored[0:50]:
            err = style_output(str(res.exception), fg="red")
            report.append(f"{res.display_path}\n    {err}")

        # Diff mode lists changed files even when quiet is set.
        if not self.mode.quiet or self.mode.diff:
            for res in changed:
                report.append(f"{res.display_path} {formatted_label}.")
                if self.mode.diff:
                    report.append(self._generate_diff(res))
        if self.mode.verbose:
            for res in unchanged:
                report.append(f"{res.display_path} {unchanged_label}.")

        msg = "\n".join(report)
        if self.mode.color is False:
            msg = unstyle_output(msg)
        return msg

    @staticmethod
    def _pluralize_file(n: int) -> str:
        """
        Returns either "1 file" or "n files", depending on n
        """
        suffix = "s" if n != 1 else ""
        return f"{n} file{suffix}"

    @classmethod
    def _generate_diff(cls, result: SqlFormatResult) -> str:
        """
        Returns a unified diff of the source and formatted strings in the
        SqlFormatResult. Each diff line is passed through _style_diff_line,
        so the returned text is styled; __str__ strips the styling again
        when color is disabled.
        """
        cleaned_lines = []
        # Work around https://bugs.python.org/issue2142: a diff line that
        # lacks a trailing newline gets one appended, followed by the
        # conventional "No newline at end of file" marker.
        for line in difflib.unified_diff(
            result.source_string.splitlines(keepends=True),
            result.formatted_string.splitlines(keepends=True),
            fromfile="source_query",
            tofile="formatted_query",
        ):
            if line.endswith("\n"):
                cleaned_lines.append(cls._style_diff_line(line))
            else:
                cleaned_lines.append(cls._style_diff_line(line + "\n"))
                cleaned_lines.append(
                    cls._style_diff_line("\\ No newline at end of file\n")
                )

        return "".join(cleaned_lines)

    @staticmethod
    def _style_diff_line(line: str) -> str:
        """
        Colorizes a single line of the diff created by _generate_diff:
        cyan for "@@" lines, green for lines starting with "+", red for lines
        starting with "-"; any other line is returned unchanged.
        """
        if line.startswith("@@"):
            styled = style_output(line, fg="cyan")
        elif line.startswith("+"):
            styled = style_output(line, fg="green")
        elif line.startswith("-"):
            styled = style_output(line, fg="red")
        else:
            styled = line
        return styled

    def display_report(self) -> None:
        """
        If sqlfmt received a query via stdin, print the formatted string to stdout
        (skipped in check and diff modes).

        Then print the report that summarizes all results to stderr
        """
        if not self.mode.check and not self.mode.diff:
            for res in self.results:
                res.maybe_print_to_stdout()
        display_output(str(self), err=True)

    @property
    def changed_results(self) -> List[SqlFormatResult]:
        """Results without an error whose formatted string differs."""
        return self._filtered_results(has_changed=True, has_error=False)

    @property
    def unchanged_results(self) -> List[SqlFormatResult]:
        """Results without an error whose formatted string is identical."""
        return self._filtered_results(has_changed=False, has_error=False)

    @property
    def errored_results(self) -> List[SqlFormatResult]:
        """
        All results that recorded an error, regardless of has_changed.

        Filtering on has_changed here would drop an errored result whose
        source and formatted strings happen to be equal from every group.
        """
        return self._filtered_results(has_changed=None, has_error=True)

    def _filtered_results(
        self, has_changed: Optional[bool] = True, has_error: bool = False
    ) -> List[SqlFormatResult]:
        """
        Returns the results matching the given flags, sorted by source_path.

        has_changed=None matches results irrespective of their has_changed
        value; True/False require an exact match. has_error always requires
        an exact match.
        """
        filtered = [
            r
            for r in self.results
            if r.has_error == has_error
            and (has_changed is None or r.has_changed == has_changed)
        ]
        return sorted(filtered, key=lambda res: res.source_path)

    @property
    def number_changed(self) -> int:
        """Count of changed_results."""
        return len(self.changed_results)

    @property
    def number_unchanged(self) -> int:
        """Count of unchanged_results."""
        return len(self.unchanged_results)

    @property
    def number_errored(self) -> int:
        """Count of errored_results."""
        return len(self.errored_results)