tconbeer/sqlfmt

View on GitHub
src/sqlfmt/merger.py

Summary

Maintainability
C
1 day
Test Coverage
A
100%
import itertools
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

from sqlfmt.comment import Comment
from sqlfmt.exception import CannotMergeException, SqlfmtSegmentError
from sqlfmt.line import Line
from sqlfmt.mode import Mode
from sqlfmt.node import Node
from sqlfmt.operator_precedence import OperatorPrecedence
from sqlfmt.segment import Segment, create_segments_from_lines


@dataclass
class LineMerger:
    """
    Merges runs of Lines back together wherever the merged result fits within
    the configured line length and no node or comment forbids merging.
    """

    # mode.line_length is the limit a merged line must not exceed
    mode: Mode

    def create_merged_line(self, lines: List[Line]) -> List[Line]:
        """
        Returns a list of lines made up of any leading blank lines in lines,
        a single new line that merges together all nodes in lines, and any
        trailing blank lines in lines. If lines has zero or one element, it
        is returned unchanged.

        Raises CannotMergeException if the merged line would be too long,
        would contain only whitespace/newlines, or if the nodes or comments
        in any of the lines violate the rules in _extract_components and
        _raise_unmergeable.
        """
        if len(lines) <= 1:
            return lines

        nodes, comments = self._extract_components(lines)

        merged_line = Line.from_nodes(
            previous_node=lines[0].previous_node,
            nodes=nodes,
            comments=comments,
        )

        if merged_line.is_too_long(self.mode.line_length):
            raise CannotMergeException("Merged line is too long")

        # the merged line is built only from content nodes, so re-attach any
        # blank lines that surrounded the merged run
        leading_blank_lines = self._extract_leading_blank_lines(lines)
        trailing_blank_lines = list(
            reversed(self._extract_leading_blank_lines(reversed(lines)))
        )

        return leading_blank_lines + [merged_line] + trailing_blank_lines

    def safe_create_merged_line(self, lines: List[Line]) -> List[Line]:
        """
        Like create_merged_line, but returns lines unchanged instead of
        raising CannotMergeException when they cannot be merged.
        """
        try:
            return self.create_merged_line(lines)
        except CannotMergeException:
            return lines

    @classmethod
    def _extract_components(
        cls, lines: Iterable[Line]
    ) -> Tuple[List[Node], List[Comment]]:
        """
        Given a list of lines, return 2 components:
        1. list of all nodes in those lines, with only a single trailing newline
        2. list of all comments in all of those lines

        Raise CannotMergeException if lines contain nodes or comments that
        prevent merging.
        """
        nodes: List[Node] = []
        comments: List[Comment] = []
        final_newline: Optional[Node] = None
        # multiline jinja may only be merged while we are still on the
        # first content line, or on the line right after a standalone operator
        allow_multiline_jinja = True
        # True if the most recent content line contained multiline jinja
        has_multiline_jinja = False
        has_inline_comment_above = False
        for line in lines:
            # only merge lines with comments if it's a standalone comment
            # above the first line or an inline comment after the last
            # line
            if line.comments:
                if has_inline_comment_above:
                    raise CannotMergeException(
                        "Can't merge lines with inline comments and other comments"
                    )
                elif (
                    len(line.comments) == 1
                    and len(line.nodes) > 1
                    and (
                        line.comments[0].is_inline
                        or line.comments[0].previous_node == line.nodes[-2]
                    )
                ):
                    # this is a comment that must be rendered inline,
                    # so it'll probably block merging unless the
                    # next line is just a comma
                    has_inline_comment_above = True
                elif len(nodes) == 1 and nodes[0].is_operator:
                    # if source has standalone operators, we can merge
                    # the operator into the contents, even if there is
                    # a comment in the way
                    pass
                elif nodes:
                    raise CannotMergeException(
                        "Can't merge lines with standalone comments unless the "
                        "comments are above the first line"
                    )
            # make an exception for inline comments followed by
            # a lonely comma (e.g., leading commas with inline comments)
            elif has_inline_comment_above:
                if not (line.is_standalone_comma or line.is_blank_line):
                    raise CannotMergeException(
                        "Can't merge lines with inline comments unless "
                        "the following line is a single standalone comma "
                        "or a blank line"
                    )

            if has_multiline_jinja and not (
                line.starts_with_operator or line.starts_with_comma
            ):
                raise CannotMergeException(
                    "Can't merge lines containing multiline nodes"
                )
            # skip over newline nodes; every other node is checked for
            # mergeability (and raises if it can't be merged)
            content_nodes = [
                cls._raise_unmergeable(node, allow_multiline_jinja)
                for node in line.nodes
                if not node.is_newline
            ]
            if content_nodes:
                final_newline = line.nodes[-1]
                nodes.extend(content_nodes)
                # we can merge a line containing multiline jinja
                # into a preceding line iff:
                # the multiline node is on the second line and follows a
                # standalone operator
                if not (
                    allow_multiline_jinja
                    and len(content_nodes) == 1
                    and content_nodes[0].is_operator
                ):
                    allow_multiline_jinja = False
                # we can merge a line into a preceding line that
                # contains multiline jinja iff:
                # the line starts with an operator or a comma
                has_multiline_jinja = any(
                    node.is_multiline_jinja for node in content_nodes
                )
            comments.extend(line.comments)

        if not nodes or not final_newline:
            raise CannotMergeException("Can't merge only whitespace/newlines")

        nodes.append(final_newline)

        return nodes, comments

    @staticmethod
    def _raise_unmergeable(node: Node, allow_multiline_jinja: bool) -> Node:
        """
        Raises a CannotMergeException if the node cannot be merged. Otherwise
        returns the node
        """
        if node.formatting_disabled:
            raise CannotMergeException(
                "Can't merge lines containing disabled formatting"
            )
        elif node.divides_queries:
            raise CannotMergeException(
                "Can't merge multiple queries onto a single line"
            )
        elif node.is_multiline_jinja and not allow_multiline_jinja:
            raise CannotMergeException("Can't merge lines containing multiline nodes")
        else:
            return node

    @staticmethod
    def _extract_leading_blank_lines(lines: Iterable[Line]) -> List[Line]:
        """
        Returns the run of blank lines at the start of lines, stopping at
        the first non-blank line.
        """
        return list(itertools.takewhile(lambda line: line.is_blank_line, lines))

    def maybe_merge_lines(self, lines: List[Line]) -> List[Line]:
        """
        Tries to merge lines into a single line; if that fails,
        splits lines into segments of equal depth, merges
        runs of operators at that depth, and then recurses into
        each segment

        Returns a new list of Lines (or lines itself, if it is empty or
        every line has formatting disabled)
        """
        if not lines or all(line.formatting_disabled for line in lines):
            return lines

        try:
            merged_lines = self.create_merged_line(lines)
        except CannotMergeException:
            merged_lines = []
            # doesn't fit onto a single line, so split into
            # segments at the depth of lines[0]
            segments = create_segments_from_lines(lines)
            # if a segment starts with a standalone operator,
            # the first two lines of that segment should likely
            # be merged before doing anything else
            segments = self._fix_standalone_operators(segments)
            if len(segments) > 1:
                # merge together segments of equal depth that are
                # joined by operators
                segments = self._maybe_merge_operators(
                    segments, OperatorPrecedence.tiers()
                )
                # some operators really should not be by themselves
                # so if their segments are too long to be merged,
                # we merge just their first line onto the prior segment
                segments = self._maybe_stubbornly_merge(segments)
                # then recurse into each segment and try to merge lines
                # within individual segments
                for segment in segments:
                    merged_lines.extend(self.maybe_merge_lines(segment))
            else:
                # there was only a single segment at the depth of the
                # top line, so we need to move down one line and try again.
                # Because of the structure of a well-split set of lines,
                # in this case moving down one line is guaranteed to move
                # us in one depth.
                # if the final line of the segment matches the top line,
                # we need to strip that off so we only segment the
                # indented lines
                only_segment = segments[0]
                try:
                    _, i = only_segment.head
                except SqlfmtSegmentError:
                    # no non-blank head line: keep the segment as-is
                    merged_lines.extend(only_segment)
                else:
                    merged_lines.extend(only_segment[: i + 1])
                    for segment in only_segment.split_after(i):
                        merged_lines.extend(self.maybe_merge_lines(segment))

        return merged_lines

    def _fix_standalone_operators(self, segments: List[Segment]) -> List[Segment]:
        """
        If the first non-blank line of a segment is a standalone operator,
        try to merge it with the next non-blank line of that segment before
        doing anything else.

        Segments are modified in place and the same list is returned.
        A segment is left unchanged if merging fails, or if segment.head
        raises SqlfmtSegmentError (presumably: no non-blank line to use).
        """
        for segment in segments:
            try:
                head, i = segment.head
                if head.is_standalone_operator:
                    remainder_after_operator = Segment(segment[i + 1 :])
                    # j indexes the next non-blank line relative to the
                    # line right after the operator
                    _, j = remainder_after_operator.head
                    try:
                        merged_lines = self.create_merged_line(segment[: i + j + 2])
                        segment[: i + j + 2] = merged_lines
                    except CannotMergeException:
                        pass
            except SqlfmtSegmentError:
                pass
        return segments

    def _maybe_merge_operators(
        self,
        segments: List[Segment],
        op_tiers: List[OperatorPrecedence],
    ) -> List[Segment]:
        """
        Tries to merge runs of segments that start with operators into previous
        segments. Operators have a precedence that determines a sort of
        hierarchy: each call scans using the last tier in op_tiers, and if a
        whole run can't be merged, it recurses with the remaining tiers to
        create shorter runs that can be merged.

        op_tiers is not modified.
        """
        if len(segments) <= 1 or not op_tiers:
            return segments

        # work on a copy so the caller's list of tiers is never consumed
        remaining_tiers = list(op_tiers)
        precedence = remaining_tiers.pop()

        head = 0
        new_segments: List[Segment] = []
        for i, segment in enumerate(segments[1:], start=1):
            if not self._segment_continues_operator_sequence(segment, precedence):
                # the run segments[head:i] has ended; try to merge it
                new_segments.extend(
                    self._try_merge_operator_segments(
                        segments[head:i], remaining_tiers
                    )
                )
                head = i

        # try one more time to merge the final run, everything after head
        new_segments.extend(
            self._try_merge_operator_segments(segments[head:], remaining_tiers)
        )

        return new_segments

    @classmethod
    def _segment_continues_operator_sequence(
        cls, segment: Segment, max_precedence: OperatorPrecedence
    ) -> bool:
        """
        Returns True if the segment is blank, or if its first non-blank line
        starts with a comma or with an operator whose precedence is
        <= max_precedence
        """
        try:
            line, _ = segment.head
        except SqlfmtSegmentError:
            # if a segment is blank, keep scanning
            return True
        else:
            return (
                line.starts_with_operator
                and OperatorPrecedence.from_node(line.nodes[0]) <= max_precedence
            ) or line.starts_with_comma

    def _try_merge_operator_segments(
        self, segments: List[Segment], op_tiers: List[OperatorPrecedence]
    ) -> List[Segment]:
        """
        Attempts to merge segments into a single line; if that fails,
        recurses into _maybe_merge_operators with the remaining op_tiers
        """
        if len(segments) <= 1:
            return segments

        try:
            new_segments = [
                Segment(
                    self.create_merged_line(
                        list(itertools.chain.from_iterable(segments))
                    )
                )
            ]
        except CannotMergeException:
            new_segments = self._maybe_merge_operators(segments, op_tiers)

        return new_segments

    def _maybe_stubbornly_merge(self, segments: List[Segment]) -> List[Segment]:
        """
        We prefer some operators, like `as`, `over()`, `exclude()`, and
        array or dictionary accessing with `[]` to be
        forced onto the prior line, even if the contents of their brackets
        don't fit there. This is also true for most operators that open
        a bracket, like `in ()` or `+ ()`, as long as the preceding segment
        does not also start with an operator.

        This method scans for segments that start with
        such operators and partially merges those segments with the prior
        segments by calling _stubbornly_merge()
        """
        if len(segments) <= 1:
            return segments

        # first pass: always stubbornly merge P0 operators (e.g., `over`).
        # Blank segments and segments starting with a comma also satisfy
        # _segment_continues_operator_sequence, so they go through
        # _stubbornly_merge as well.
        new_segments = [segments[0]]
        for segment in segments[1:]:
            if self._segment_continues_operator_sequence(
                segment, max_precedence=OperatorPrecedence.OTHER_TIGHT
            ):
                new_segments = self._stubbornly_merge(new_segments, segment)
            else:
                new_segments.append(segment)

        if len(new_segments) == 1:
            return new_segments

        # second pass: stubbornly merge qualifying P1 operators. A segment
        # qualifies if it starts with a P1 operator, the segment before it
        # does not, and its own lines (merged if possible) satisfy
        # tail_closes_head (presumably: its last line closes the bracket
        # opened on its first line).
        segments = new_segments
        starts_with_p1_operator = [
            self._segment_continues_operator_sequence(
                segment, max_precedence=OperatorPrecedence.COMPARATORS
            )
            for segment in segments
        ]
        new_segments = [segments[0]]
        for prev_is_p1, this_is_p1, segment in zip(
            starts_with_p1_operator, starts_with_p1_operator[1:], segments[1:]
        ):
            if (
                not prev_is_p1
                and this_is_p1
                and Segment(self.safe_create_merged_line(segment)).tail_closes_head
            ):
                new_segments = self._stubbornly_merge(new_segments, segment)
            else:
                new_segments.append(segment)

        return new_segments

    def _stubbornly_merge(
        self, prev_segments: List[Segment], segment: Segment
    ) -> List[Segment]:
        """
        Attempts several methods of merging the last segment in prev_segments
        with segment, in order of preference:

        1. merge the first non-blank line of segment onto the whole previous
           segment;
        2. merge all of segment onto the last non-blank line of the previous
           segment;
        3. merge just the first non-blank line of segment onto the last
           non-blank line of the previous segment.

        Returns a new list of segments representing the best possible merger.
        prev_segments itself is not modified, but the previous Segment object
        may be updated in place by methods 2 and 3. If no method works, or if
        either segment has no non-blank line, the two segments are returned
        unmerged.
        """
        new_segments = prev_segments.copy()
        prev_segment = new_segments.pop()
        try:
            head, i = segment.head
        except SqlfmtSegmentError:
            new_segments.extend([prev_segment, segment])
            return new_segments

        # 1. try to merge the first line of this segment with the previous segment
        try:
            merged_with_head = self.create_merged_line(prev_segment + [head])
        except CannotMergeException:
            pass
        else:
            merged_segment = Segment(merged_with_head)
            merged_segment.extend(segment[i + 1 :])
            new_segments.append(merged_segment)
            return new_segments

        try:
            # k is the number of lines after last_line in prev_segment
            last_line, k = prev_segment.tail
        except SqlfmtSegmentError:
            # the previous segment is blank, so there is no last line to
            # merge onto; give up rather than crash
            new_segments.extend([prev_segment, segment])
            return new_segments

        # 2. try to add this segment to the last line of the previous segment
        try:
            new_last_lines = self.create_merged_line([last_line] + segment)
        except CannotMergeException:
            pass
        else:
            prev_segment[-(k + 1) :] = new_last_lines
            new_segments.append(prev_segment)
            return new_segments

        # 3. try to add just the first line of this segment to the last
        # line of the previous segment
        try:
            new_last_lines = self.create_merged_line([last_line, head])
        except CannotMergeException:
            # give up and just return the original segments
            new_segments.extend([prev_segment, segment])
        else:
            prev_segment[-(k + 1) :] = new_last_lines
            prev_segment.extend(segment[i + 1 :])
            new_segments.append(prev_segment)

        return new_segments