src/sqlfmt/report.py
import difflib
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import click
from sqlfmt.exception import SqlfmtError
from sqlfmt.mode import Mode
# Sentinel path: a result whose source_path equals this is treated as a query
# received via stdin (see SqlFormatResult.maybe_print_to_stdout).
STDIN_PATH = Path("-")
def style_output(
    msg: str, fg: Optional[str] = None, bg: Optional[str] = None, bold: bool = False
) -> str:
    """
    Style msg for terminal display by delegating to click.style; fg, bg,
    and bold are forwarded unchanged.

    See https://click.palletsprojects.com/en/8.0.x/api/?highlight=style#click.style
    """
    # The annotated local pins the return type to str for type checkers.
    styled: str = click.style(msg, fg=fg, bg=bg, bold=bold)
    return styled
def unstyle_output(msg: str) -> str:
    """
    Return msg after passing it through click.unstyle (the inverse of
    style_output).
    """
    # The annotated local pins the return type to str for type checkers.
    plain: str = click.unstyle(msg)
    return plain
def display_output(msg: str, err: bool = True) -> None:
    """
    Emit msg through click.echo. With the default err=True the message goes
    to stderr; pass err=False to write to stdout instead.
    """
    click.echo(message=msg, err=err)
@dataclass
class SqlFormatResult:
    """
    A SqlFormatResult is a summary of the changes made by sqlfmt to a single
    file: the original text, the formatted text, and any error that was
    recorded along the way.
    """

    # Where the query came from; equals STDIN_PATH when it was read from stdin.
    source_path: Path
    # Text before formatting.
    source_string: str
    # Text after formatting.
    formatted_string: str
    encoding: str
    utf_bom: str
    # Set when an error was recorded for this file; None otherwise.
    exception: Optional[SqlfmtError] = None
    from_cache: bool = False

    def __post_init__(self) -> None:
        # Prefer a path relative to the current working directory for display.
        # relative_to raises ValueError when source_path is not located under
        # it, in which case the path is shown exactly as given.
        working_dir = Path.cwd()
        try:
            shown = self.source_path.relative_to(working_dir)
        except ValueError:
            shown = self.source_path
        self.display_path = shown

    def maybe_print_to_stdout(self) -> None:
        """
        Write the formatted string to stdout, but only when the query was
        received via stdin; do nothing for any other source path.
        """
        if self.source_path != STDIN_PATH:
            return
        display_output(self.formatted_string, err=False)

    @property
    def has_changed(self) -> bool:
        """True when the formatted text differs from the source text."""
        return self.formatted_string != self.source_string

    @property
    def has_error(self) -> bool:
        """True when an exception was recorded for this result."""
        return not (self.exception is None)
@dataclass
class Report:
    """
    An abstraction for a summary of results generated by a sqlfmt run.
    Can be printed to stderr using display_report()

    Every result falls into exactly one of three groups: errored (an
    exception was recorded), changed (no error, text differs), or unchanged
    (no error, text identical).
    """

    results: List[SqlFormatResult]
    mode: Mode

    def __str__(self) -> str:
        """
        Returns the contents of the Report: summary counts first, then the
        per-file details selected by the mode's quiet/verbose/diff flags.
        Styling is stripped when mode.color is False.
        """
        check_only = self.mode.check or self.mode.diff
        formatted = "failed formatting check" if check_only else "formatted"
        unchanged = "passed formatting check" if check_only else "left unchanged"

        # Each of these properties filters and sorts self.results on every
        # access, so compute them once up front.
        errored_results = self.errored_results
        changed_results = self.changed_results
        unchanged_results = self.unchanged_results

        report = []
        if errored_results:
            error_msg = (
                f"{self._pluralize_file(len(errored_results))} had errors while "
                f"formatting."
            )
            report.append(style_output(error_msg, fg="red", bold=True))
        if changed_results:
            changed_msg = f"{self._pluralize_file(len(changed_results))} {formatted}."
            report.append(style_output(changed_msg, bold=True))
        report.append(f"{self._pluralize_file(len(unchanged_results))} {unchanged}.")

        # NOTE(review): only the first 50 errors are listed, although the
        # summary line above counts all of them -- confirm the cap is intended.
        for res in errored_results[0:50]:
            err = style_output(str(res.exception), fg="red")
            report.append(f"{res.display_path}\n {err}")

        # Diff mode lists changed files (with their diffs) even when quiet.
        if not self.mode.quiet or self.mode.diff:
            for res in changed_results:
                report.append(f"{res.display_path} {formatted}.")
                if self.mode.diff:
                    report.append(self._generate_diff(res))
        if self.mode.verbose:
            for res in unchanged_results:
                report.append(f"{res.display_path} {unchanged}.")

        msg = "\n".join(report)
        if self.mode.color is False:
            msg = unstyle_output(msg)
        return msg

    @staticmethod
    def _pluralize_file(n: int) -> str:
        """
        Returns either "1 file" or "n files", depending on n
        """
        suffix = "s" if n != 1 else ""
        return f"{n} file{suffix}"

    @classmethod
    def _generate_diff(cls, result: SqlFormatResult) -> str:
        """
        Returns a unified diff of the source and formatted strings in the
        SqlFormatResult. Each line is passed through _style_diff_line, so the
        diff is colorized here; __str__ strips the styling again when
        mode.color is False.
        """
        cleaned_lines = []
        # Work around https://bugs.python.org/issue2142
        # A diff line without a trailing newline gets one appended, followed
        # by the conventional "\ No newline at end of file" marker.
        for line in difflib.unified_diff(
            result.source_string.splitlines(keepends=True),
            result.formatted_string.splitlines(keepends=True),
            fromfile="source_query",
            tofile="formatted_query",
        ):
            if line.endswith("\n"):
                cleaned_lines.append(cls._style_diff_line(line))
            else:
                cleaned_lines.append(cls._style_diff_line(line + "\n"))
                cleaned_lines.append(
                    cls._style_diff_line("\\ No newline at end of file\n")
                )
        return "".join(cleaned_lines)

    @staticmethod
    def _style_diff_line(line: str) -> str:
        """
        Colorizes one line of the diff created by _generate_diff: cyan for
        lines starting with "@@", green for "+", red for "-"; any other line
        is returned unchanged.
        """
        if line.startswith("@@"):
            styled = style_output(line, fg="cyan")
        elif line.startswith("+"):
            styled = style_output(line, fg="green")
        elif line.startswith("-"):
            styled = style_output(line, fg="red")
        else:
            styled = line
        return styled

    def display_report(self) -> None:
        """
        If sqlfmt received a query via stdin, print the formatted string to stdout
        (skipped in check and diff modes).
        Then print the report that summarizes all results to stderr
        """
        if not self.mode.check and not self.mode.diff:
            for res in self.results:
                res.maybe_print_to_stdout()
        display_output(str(self), err=True)

    @property
    def changed_results(self) -> List[SqlFormatResult]:
        """Results without an error whose text was changed, sorted by path."""
        return self._filtered_results(has_changed=True, has_error=False)

    @property
    def unchanged_results(self) -> List[SqlFormatResult]:
        """Results without an error whose text is unchanged, sorted by path."""
        return self._filtered_results(has_changed=False, has_error=False)

    @property
    def errored_results(self) -> List[SqlFormatResult]:
        """
        Results that recorded an error, sorted by path. has_changed is
        deliberately ignored: a failed result whose formatted string happens
        to equal its source must still be reported as an error.
        """
        return self._filtered_results(has_changed=None, has_error=True)

    def _filtered_results(
        self, has_changed: Optional[bool] = True, has_error: bool = False
    ) -> List[SqlFormatResult]:
        """
        Returns the results matching the given flags, sorted by source_path.
        Passing has_changed=None disables the has_changed filter.
        """
        filtered = [
            r
            for r in self.results
            if r.has_error == has_error
            and (has_changed is None or r.has_changed == has_changed)
        ]
        return sorted(filtered, key=lambda res: res.source_path)

    @property
    def number_changed(self) -> int:
        """Number of results in changed_results."""
        return len(self.changed_results)

    @property
    def number_unchanged(self) -> int:
        """Number of results in unchanged_results."""
        return len(self.unchanged_results)

    @property
    def number_errored(self) -> int:
        """Number of results in errored_results."""
        return len(self.errored_results)