projectsyn/commodore

View on GitHub
commodore/gitrepo/__init__.py

Summary

Maintainability
A
0 mins
Test Coverage
A
97%
from __future__ import annotations

import hashlib
import re
import shutil

from collections import namedtuple
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import Optional

import click

from git import Actor, BadName, FetchInfo, GitCommandError, PushInfo, Repo
from git.objects import Tree

from url_normalize import url_normalize
from url_normalize.tools import deconstruct_url, reconstruct_url

from .diff import DiffFunc, default_difffunc, process_diff


class RefError(ValueError):
    """Raised when a requested Git revision can't be found or checked out."""


class MergeConflict(ValueError):
    """Raised when the index contains an unmerged (conflicting) path."""


# Result of resolving a version: the commit to check out, plus the name of the
# remote branch or the tag which matched the requested version (if any).
CommitInfo = namedtuple("CommitInfo", "commit branch tag")


def _normalize_git_ssh(url: str) -> str:
    """Normalize a Git SSH URL.

    Accepts both `ssh://user@host/path` URLs and the scp-like `user@host:path`
    form. The scp-like form is rewritten to `user@host/path` first, then the
    `ssh` scheme is added (if missing) and the userinfo, host and path parts are
    normalized with the helpers of `url_normalize`.

    Returns the reconstructed, normalized URL.
    """
    # Import url_normalize internal methods here, so they're not visible in the file
    # scope of gitrepo.py
    # pylint: disable=import-outside-toplevel
    from url_normalize.url_normalize import (
        normalize_userinfo,
        normalize_host,
        normalize_path,
        provide_url_scheme,
    )

    if "@" in url and not url.startswith("ssh://"):
        # Assume git@host:repo format, reformat so url_normalize understands
        # the URL. Only split on the first colon, so that additional colons in
        # the repo path don't break the unpacking.
        host, repo = url.split(":", 1)
        url = f"{host}/{repo}"
    # Reuse normalization logic from url_normalize, simplify for Git-SSH use case.
    # We can't do `url_normalize(url, "ssh"), because the library doesn't know "ssh" as
    # a scheme, and fails to look up the default port for "ssh".
    url = provide_url_scheme(url, "ssh")
    urlparts = deconstruct_url(url)
    urlparts = urlparts._replace(
        userinfo=normalize_userinfo(urlparts.userinfo),
        host=normalize_host(urlparts.host),
        # Path normalization is scheme dependent in url_normalize, reuse the
        # rules for "https" since "ssh" is unknown to the library.
        path=normalize_path(urlparts.path, scheme="https"),
    )
    return reconstruct_url(urlparts)


def normalize_git_url(url: str) -> str:
    """Normalize HTTP(s) and SSH Git URLs.

    URLs containing an `@` which either have no scheme (scp-like syntax) or use
    the `ssh://` scheme are normalized as SSH URLs. `http://` and `https://` URLs
    are normalized with `url_normalize`. Any other URL is returned unchanged.
    """
    if "@" in url and (url.startswith("ssh://") or "://" not in url):
        return _normalize_git_ssh(url)
    if url.startswith(("http://", "https://")):
        return url_normalize(url)
    return url


class GitRepo:
    """Convenience wrapper around a GitPython `Repo`.

    Bundles the repository object with the author information which is used when
    creating commits.
    """

    # The wrapped GitPython repository object
    _repo: Repo
    # Cached commit author, resolved on first access of the `author` property
    _author: Optional[Actor]

    @classmethod
    def clone(cls, repository_url: str, directory: Path, cfg) -> "GitRepo":
        """Clone repository.

        Create initial commit on master branch, if remote is empty.

        The repository in `directory` is always (re-)initialized. If `cfg` is
        truthy, `cfg.username` and `cfg.usermail` are used as commit author.

        Raises `click.ClickException` if checking out the remote's contents fails
        for any reason other than a missing revision.
        """
        name = None
        email = None
        if cfg:
            name = cfg.username
            email = cfg.usermail
        # Instantiate through `cls`, so subclasses get an instance of their own
        # type from this alternate constructor.
        r = cls(
            repository_url,
            directory,
            force_init=True,
            author_name=name,
            author_email=email,
        )
        try:
            r.checkout()
        except RefError as e:
            # A `RefError` indicates that the remote doesn't have the default
            # version yet, i.e. the remote is empty.
            click.echo(f" > {e}, creating initial commit for {directory}")
            r.commit("Initial commit")
        except Exception as e:
            raise click.ClickException(
                f"While cloning git repository from {repository_url}: {e}"
            ) from e
        return r

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        remote: Optional[str],
        targetdir: Path,
        force_init=False,
        author_name: Optional[str] = None,
        author_email: Optional[str] = None,
        config=None,
        bare=False,
    ):
        """Open the repository in `targetdir` or initialize a new one.

        A new repository (with initial branch `master`) is initialized if
        `force_init` is set or if `targetdir` doesn't exist, otherwise the
        existing repository is opened. `bare` is only used when initializing.

        If `remote` is given, it's configured as the repo's remote URL. The
        debug and trace flags are read from `config`, if provided.
        """
        if force_init or not targetdir.exists():
            self._repo = Repo.init(targetdir, bare=bare, initial_branch="master")
        else:
            self._repo = Repo(targetdir)

        if remote:
            # Goes through the `remote` property setter
            self.remote = remote

        # The actual author is determined on first access of `self.author`
        self._author = None
        self._author_name, self._author_email = author_name, author_email

        self._debug = config.debug if config else False
        self._trace = config.trace if config else False

    @property
    def repo(self) -> Repo:
        """The wrapped GitPython `Repo` object."""
        wrapped = self._repo
        return wrapped

    @property
    def remote(self) -> str:
        """URL of the remote returned by `Repo.remote()`."""
        configured = self._repo.remote()
        return configured.url

    @remote.setter
    def remote(self, remote: str):
        """Set the (normalized) remote URL, creating remote `origin` if needed.

        For http(s) URLs, a push-over-SSH URL is configured in addition."""
        remote = normalize_git_url(remote)
        try:
            self._repo.remote().set_url(remote)
        except ValueError:
            self._repo.create_remote("origin", remote)

        if not remote.startswith(("http://", "https://")):
            return

        # Generate a best effort push-over-SSH URL for http(s) repo URLs.
        # Everything but the host and path parts of the URL is ignored when
        # building the push URL.
        parts = deconstruct_url(remote)
        self._repo.remote().set_url(f"ssh://git@{parts.host}{parts.path}", push=True)

    @property
    def working_tree_dir(self) -> Optional[Path]:
        """Working tree directory of the repo as `Path`.

        Returns `None` if the wrapped repo doesn't report a working tree
        directory."""
        tree_dir = self._repo.working_tree_dir
        return Path(tree_dir) if tree_dir else None

    @property
    def head_short_sha(self) -> str:
        """Abbreviated hash of the HEAD commit, computed by `git rev-parse` with
        `short=6`."""
        head_commit = self._repo.head.commit
        return self._repo.git.rev_parse(head_commit.hexsha, short=6)

    @property
    def _author_env(self) -> dict[str, str]:
        """Environment variables which set both Git author and committer to
        `self.author`.

        Missing name or email are replaced with the empty string."""
        who = self.author
        name = who.name or ""
        email = who.email or ""
        return {
            Actor.env_author_name: name,
            Actor.env_author_email: email,
            Actor.env_committer_name: name,
            Actor.env_committer_email: email,
        }

    @property
    def _null_tree(self) -> Tree:
        """Create the empty `Tree` object for the repo.

        Git represents an empty tree by the C string "tree 0", so the hash of the
        empty tree is always SHA1("tree 0\\0"). This property computes the
        hexdigest of that SHA1 and looks up the matching tree object in
        `self._repo`.
        """
        hasher = hashlib.sha1()  # nosec
        hasher.update(b"tree 0\0")
        return self._repo.tree(hasher.hexdigest())

    @property
    def author(self) -> Actor:
        """Commit author for the repo, determined once and cached.

        Uses the author name and email passed to the constructor if both are
        set. Otherwise the committer is read from the repo's Git configuration,
        falling back to `Commodore <commodore@syn.tools>`.
        """
        if self._author:
            return self._author

        if self._author_name and self._author_email:
            self._author = Actor(self._author_name, self._author_email)
            return self._author

        try:
            self._author = Actor.committer(self._repo.config_reader())
        except KeyError:
            # Handle case where UID-based user info fallback doesn't work
            click.echo(
                " > Can't determine author information, falling back to "
                + "`Commodore <commodore@syn.tools>`"
            )
            self._author = Actor("Commodore", "commodore@syn.tools")
        return self._author

    def _remote_prefix(self) -> str:
        """
        Return the name of the Git remote followed by a slash, usually 'origin/'.
        """
        remote_name = self._repo.remote().name
        return f"{remote_name}/"

    def _default_version(self) -> str:
        """
        Find default branch of the remote.

        If the remote HEAD isn't known locally, `git remote set-head origin --auto`
        is used to determine it. If that fails, `master` is returned.
        """

        def remote_head_target() -> str:
            # Name of the ref that the remote's HEAD points to
            return self._repo.remote().refs["HEAD"].reference.name

        try:
            version = remote_head_target()
        except IndexError:
            try:
                self._repo.git.remote("set-head", "origin", "--auto")
                version = remote_head_target()
            except GitCommandError:
                # If we don't have a remote HEAD, we fall back to creating a master
                # branch.
                version = f"{self._remote_prefix()}master"

        # Strip the remote name from the ref name
        prefix = self._remote_prefix()
        return version.replace(prefix, "", 1)

    def _find_commit_for_version(
        self, version: str, remote_heads: Iterable[FetchInfo]
    ) -> CommitInfo:
        """Resolve `version` to a commit.

        Entries of `remote_heads` whose name starts with the remote prefix are
        treated as branches, all other entries as tags. The first entry whose
        name (without remote prefix) equals `version` wins. If no entry matches,
        `version` is resolved with `rev_parse()` and neither branch nor tag are
        set in the result.
        """
        prefix = self._remote_prefix()
        for ref in remote_heads:
            name = ref.name
            if name.startswith(prefix):
                branch, tag = name.replace(prefix, "", 1), None
                candidate = branch
            else:
                branch, tag = None, name
                candidate = name

            if candidate == version:
                return CommitInfo(commit=ref.commit, branch=branch, tag=tag)

        # If we haven't found a branch or tag matching the requested version,
        # assume the version is a commit sha.
        return CommitInfo(commit=self._repo.rev_parse(version), branch=None, tag=None)

    def fetch(
        self, remote: str = "origin", tags: bool = True, prune: bool = True
    ) -> Iterable[FetchInfo]:
        """Fetch from the remote named `remote` and return the resulting fetch
        infos. `tags` and `prune` are passed on to the fetch operation."""
        source = self._repo.remote(remote)
        return source.fetch(tags=tags, prune=prune)

    def has_local_branches(self) -> bool:
        """Check whether the repo has local branches which weren't returned by a
        fetch from the remote. Always `False` for repos without remotes."""
        if not self.repo.remotes:
            # If we don't have a remote, the fact that we have local branches is
            # useless to determine whether to abort or continue a compile.
            return False
        local = {h.name for h in self.repo.heads}
        prefix = self._remote_prefix()
        upstream = {h.name.replace(prefix, "", 1) for h in self.fetch()}
        return bool(local - upstream)

    def has_local_changes(self) -> bool:
        """Check whether the repo is dirty or contains untracked files."""
        if self._repo.is_dirty():
            return True
        return len(self._repo.untracked_files) > 0

    def is_ahead_of_remote(self) -> bool:
        """Check whether the active branch has commits which aren't in its
        tracking branch.

        Returns `False` for detached heads and for branches without a tracking
        branch."""
        if self.repo.head.is_detached:
            # Always return False for repo which has a detached head checked out.
            return False

        local = self.repo.active_branch
        upstream = local.tracking_branch()
        if not upstream:
            # If there's no tracking branch there's no point in reporting that we're
            # ahead of anything.
            return False

        unpushed = list(self.repo.iter_commits(f"{upstream}..{local}"))
        return bool(unpushed)

    def _create_worktree(self, worktree: Path, version: str):
        """Create worktree in `worktree` with `version` checked out.

        Stale worktree entries are pruned first. This method expects `worktree` to
        not exist. Raises `RefError` if the worktree can't be created."""

        # We need to use `git.execute()` for the worktree commands as GitPython only has
        # basic support for worktrees.
        git = self._repo.git
        git.execute(["git", "worktree", "prune"])
        try:
            git.execute(["git", "worktree", "add", "-f", str(worktree), version])
        except GitCommandError as e:
            # Assume that GitCommandError is only caused by invalid versions
            raise RefError(f"Failed to checkout revision '{version}'") from e

    def _migrate_to_worktree(self, wtr: GitRepo, worktree: Path, version: str):
        """Migrate non-worktree checkout to worktree.

        The existing checkout `wtr` in directory `worktree` is deleted and
        replaced with a worktree for `version`. Raises `click.ClickException` if
        the existing checkout has local branches or local changes."""
        would_lose_work = wtr.has_local_branches() or wtr.has_local_changes()
        if would_lose_work:
            raise click.ClickException(
                f"Migrating dependency {worktree} to worktree-based checkout "
                + "would delete uncommitted changes, untracked files, or "
                + "unpushed branches, aborting..."
            )

        click.secho(f" > Removing non-worktree based checkout {worktree}", fg="yellow")
        shutil.rmtree(worktree)
        self._create_worktree(worktree, version)

    def _update_worktree_remote(self, wtr: GitRepo, worktree: Path, version: str):
        """Update existing worktree checkout to new remote.

        Updating the remote for a worktree needs special handling, since we generally
        don't want to update the remote URL for the previously used remote URL's bare
        clone, but instead recreate the worktree from the new remote's bare clone.

        Raises `click.ClickException` if the existing worktree has local changes."""
        if wtr.has_local_changes():
            raise click.ClickException(
                f"Switching remote for worktree '{worktree}' would delete "
                + "uncommitted changes or untracked files, aborting..."
            )

        # Remove stale worktree if there are no uncommitted changes
        old_common_dir = Path(wtr.repo.common_dir).resolve()
        notice = (
            f" > Removing stale, but clean worktree {worktree}, "
            + f"bare repo is available at {old_common_dir}"
        )
        click.secho(notice, fg="green")
        wtr.repo.git.execute(["git", "worktree", "remove", str(worktree)])
        self._create_worktree(worktree, version)

    def _checkout_existing_worktree(self, worktree: Path, version: str):
        """Perform checkout if requested worktree directory already exists.

        This method only decides how to handle the existing directory. The heavy
        work is generally done by `_migrate_to_worktree()`,
        `_update_worktree_remote()` or `checkout()`.
        """
        author = self.author
        wtr = GitRepo(
            None, worktree, author_name=author.name, author_email=author.email
        )

        if not wtr.repo.has_separate_working_tree():
            # If the worktree's common dir is stored in the repository working tree
            # root, we're migrating from a non-worktree checkout to a worktree checkout.
            self._migrate_to_worktree(wtr, worktree, version)
            return

        if wtr.remote != self.remote:
            # If the existing directory is already a worktree, but we're using a
            # different remote for the requested worktree, we need to recreate the
            # worktree from the new remote's bare clone.
            self._update_worktree_remote(wtr, worktree, version)
            return

        # Otherwise, we just need to update the worktree's version. We simply use
        # `checkout()` in the worktree to do so.
        wtr.checkout(version)

    def checkout_worktree(self, worktree: Path, version: Optional[str]):
        """Create worktree if it doesn't exist and check out `version` in it.

        If `version` is not provided, the remote's default branch is checked out.

        If a repo which isn't a worktree of `self` is found in the requested worktree
        location, the method will try to replace the old checkout with the requested
        worktree unless there's any local changes (untracked files, uncommitted changes,
        or branches which don't exist upstream).
        """
        # Try to fetch remote heads, so we can actually check them out. A
        # `ValueError` raised by the fetch is deliberately ignored.
        try:
            self.fetch()
        except ValueError:
            pass

        target = self._default_version() if version is None else version

        if worktree.is_dir():
            # The worktree directory exists already
            self._checkout_existing_worktree(worktree, target)
        else:
            # The worktree directory doesn't exist yet, create the worktree
            self._create_worktree(worktree, target)

    def initialize_worktree(
        self, worktree: Path, initial_branch: Optional[str] = None
    ) -> None:
        """Create a worktree in `worktree` on a new branch `initial_branch`, which
        starts at a freshly created empty initial commit.

        If `initial_branch` isn't provided, the remote's default branch name is
        used."""
        branch = initial_branch or self._default_version()

        # We need an initial commit to be able to create a worktree. Create initial
        # commit from empty tree.
        commit_tree_cmd = [
            "git",
            "commit-tree",
            "-m",
            "Initial commit",
            self._null_tree.hexsha,
        ]
        initsha = self._repo.git.execute(  # type: ignore[call-overload]
            command=commit_tree_cmd,
            env=self._author_env,
        )

        # Create worktree using the provided branch name
        worktree_add_cmd = ["git", "worktree", "add", str(worktree), initsha]
        worktree_add_cmd += ["-b", branch]
        self._repo.git.execute(worktree_add_cmd)

    @property
    def worktrees(self) -> list[GitRepo]:
        """List all worktrees for the repo.

        Note that accessing this property runs `git worktree prune` before
        listing the worktrees. Each worktree is returned as a `GitRepo` which
        uses the same author as `self`.
        """
        # First prune worktrees, to ensure repo worktree state is clean
        self._repo.git.execute(["git", "worktree", "prune"])
        worktrees: list[GitRepo] = []
        wt_list = self._repo.git.execute(
            ["git", "worktree", "list", "--porcelain"],
            as_process=False,
            with_extended_output=False,
            stdout_as_string=True,
        ).splitlines()
        for line in wt_list:
            # Each porcelain line is either a bare label or `<key> <value>`. Only
            # split on the first space, since values (worktree paths, lock
            # reasons) may contain spaces themselves.
            key, sep, value = line.partition(" ")
            if not sep or key != "worktree":
                continue
            worktrees.append(
                GitRepo(
                    None,
                    Path(value),
                    author_name=self.author.name,
                    author_email=self.author.email,
                )
            )

        return worktrees

    def checkout(self, version: Optional[str] = None):
        """Fetch the remote and check out `version`.

        `version` can be a remote branch, a tag or anything else that `rev_parse()`
        understands. If it's not provided, the remote's default branch is used.
        Branches are checked out as local branches tracking the remote branch,
        everything else as detached head. The index and working tree are reset to
        the checked out version.

        Raises `RefError` if the version can't be found or checked out.
        """
        # GitPython's fetch-info parsing chokes on lines like
        # "   (refs/remotes/origin/HEAD has become dangling)"
        # see also https://github.com/gitpython-developers/GitPython/issues/962.
        # We handle this case by simply performing a second fetch which should
        # return fetch-infos with flags = 4 (HEAD_UPTODATE).
        remote_heads = self.fetch() or self.fetch()

        if version is None:
            # Handle case where we want the default branch of the remote
            version = self._default_version()

        try:
            info = self._find_commit_for_version(version, remote_heads)
            if info.branch:
                # If we found a remote branch for the requested version, find
                # or create a local branch and point HEAD to the branch.
                local_head = next(
                    (h for h in self._repo.heads if h.name == info.branch), None
                )
                if local_head is None:
                    local_head = self._repo.create_head(
                        info.branch, commit=info.commit
                    )
                else:
                    local_head.commit = info.commit

                local_head.set_tracking_branch(self.repo.remote().refs[info.branch])
                target = local_head
            elif info.tag:
                # Simply create detached head pointing to tag
                target = info.tag
            else:
                # Create detached head by setting repo.head.reference as
                # direct ref to commit object.
                target = info.commit
            self._repo.head.set_reference(target)

            # Reset working tree to current HEAD reference
            self._repo.head.reset(index=True, working_tree=True)
        except GitCommandError as e:
            raise RefError(f"Failed to checkout revision '{version}'") from e
        except BadName as e:
            raise RefError(f"Revision '{version}' not found in repository") from e

    def _check_conflicts(self):
        """Look for unmerged entries in the index and raise `MergeConflict` with
        the path of the first entry which has a non-zero stage."""
        unmerged = self.repo.index.unmerged_blobs()
        for path, blobs in unmerged.items():
            if any(stage != 0 for stage, _blob in blobs):
                raise MergeConflict(path)

    def _compute_changed_files(
        self, ignore_pattern: Optional[re.Pattern]
    ) -> tuple[list[str], list[str]]:
        """Compute which files need to be added to and removed from the index.

        Returns a tuple `(to_add, to_remove)`. Untracked files and files with
        changes other than deletions are added, deleted files are removed.

        New or modified files matching `ignore_pattern` are not considered for staging.
        The `ignore_pattern` is never applied to files to be removed."""
        # We always want to stage untracked files. The implementation for
        # `untracked_files` respects the repo's `.gitignore`.
        to_add = self._repo.untracked_files
        # We don't want to remove anything by default.
        to_remove = []

        # Sort the unstaged changes into removals and other changes
        for change in self._repo.index.diff(None):
            is_removal = change.change_type == "D" or change.deleted_file
            if is_removal:
                # Track removed files for `index.remove()`
                to_remove.append(change.b_path)
            else:
                # Track changes which aren't deletions for `index.add()`
                to_add.append(change.a_path)

        if not ignore_pattern:
            return to_add, to_remove

        wanted = [f for f in to_add if not ignore_pattern.search(f)]
        return wanted, to_remove

    def stage_all(
        self,
        diff_func: DiffFunc = default_difffunc,
        ignore_pattern: Optional[re.Pattern] = None,
    ) -> tuple[str, bool]:
        """Stage all changes.
        This method currently doesn't handle hidden files correctly.

        New or modified files matching `ignore_pattern` aren't staged. The diff is
        rendered with `diff_func`.

        This method returns a tuple containing the colorized diff of the staged changes
        and a boolean indicating whether any changes were staged.

        The method can raise `MergeConflict` if staged changes contain merge conflicts.
        """
        additions, removals = self._compute_changed_files(ignore_pattern)

        index = self._repo.index
        if additions:
            index.add(items=additions)
        if removals:
            index.remove(items=removals)

        self._check_conflicts()

        # Compute diff of all changes
        try:
            diff = index.diff(self._repo.head.commit)
        except ValueError:
            # Assume that we're in an empty repo if we get a ValueError from
            # index.diff(repo.head.commit). Diff against empty tree.
            diff = index.diff(self._null_tree)

        changed = bool(diff)
        difftext: list[str] = []
        if changed:
            for ct in diff.change_type:
                # We need to disable type checking here since gitpython expects a value
                # of type `Lit_change_type` in iter_change_type() but returns plain
                # strings in `diff.change_type`.
                for change in diff.iter_change_type(ct):  # type: ignore[arg-type]
                    difftext += process_diff(ct, change, diff_func)

        return "\n".join(difftext), changed

    def stage_files(self, files: Sequence[str]):
        """Stage the given files, then check the index for conflicts.
        Can raise `MergeConflict` if staged changes contain merge conflicts."""
        index = self._repo.index
        index.add(files)
        self._check_conflicts()

    def commit(self, commit_message: str, amend=False):
        """Commit the index with message `commit_message`, using `self.author` as
        both author and committer.

        If `amend` is set, the previous commit is amended instead of creating a
        new commit."""
        author = self.author

        if self._trace:
            click.echo(f' > Using "{author.name} <{author.email}>" as commit author')

        if not amend:
            self._repo.index.commit(commit_message, author=author, committer=author)
            return

        # We need to call out to `git commit` for amending
        amend_cmd = [
            "git",
            "commit",
            "--amend",
            "--no-edit",
            "--reset-author",
            "-m",
            commit_message,
        ]
        self._repo.git.execute(  # type: ignore[call-overload]
            amend_cmd,
            env=self._author_env,
        )

    def push(
        self, remote: Optional[str] = None, version: Optional[str] = None
    ) -> Iterable[PushInfo]:
        """Push `version` to `remote` and return the resulting push infos.

        `remote` defaults to `origin`, `version` defaults to the remote's default
        branch."""
        target_remote = remote or "origin"
        refspec = version or self._default_version()
        return self._repo.remote(target_remote).push(refspec)

    def reset(self, working_tree: bool = False):
        """Reset the repo's HEAD, passing `working_tree` on to the underlying
        GitPython reset call."""
        head = self._repo.head
        head.reset(working_tree=working_tree)