janw/podcast-archiver
podcast_archiver/utils.py

from __future__ import annotations

import os
import re
from contextlib import contextmanager
from string import Formatter
from typing import IO, TYPE_CHECKING, Any, Generator, Iterable, Iterator, Literal, TypedDict, overload

from pydantic import ValidationError
from requests import HTTPError
from slugify import slugify as _slugify

from podcast_archiver.exceptions import NotModified
from podcast_archiver.logging import logger, rprint

if TYPE_CHECKING:
    from pathlib import Path

    from podcast_archiver.config import Settings
    from podcast_archiver.models import Episode, FeedInfo

# Characters that are unsafe in file names on common filesystems; replaced with "-".
filename_safe_re = re.compile(r'[/\\?%*:|"<>]')
# Anything other than A-Za-z, digits, "-", "_" and "." is collapsed when slugifying.
slug_safe_re = re.compile(r"[^A-Za-z0-9-_\.]+")


# Known MIME types mapped to file extensions; unknown types fall back to "ext".
MIMETYPE_EXTENSION_MAPPING: dict[str, str] = {
    "audio/mp4": "m4a",
    "audio/mp3": "mp3",
    "audio/mpeg": "mp3",
}


def get_generic_extension(link_type: str) -> str:
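    """Map an enclosure MIME type to a file extension, falling back to 'ext'."""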
    return MIMETYPE_EXTENSION_MAPPING.get(link_type, "ext")


def make_filename_safe(value: str) -> str:
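    """Replace characters that are unsafe in file names with dashes."""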
    return filename_safe_re.sub("-", value)


def slugify(value: str) -> str:
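    """Slugify a value, preserving case and transliterating German umlauts."""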
    return _slugify(
        value,
        lowercase=False,
        regex_pattern=slug_safe_re,
        replacements=[
            ("Ü", "UE"),
            ("ü", "ue"),
            ("Ö", "OE"),
            ("ö", "oe"),
            ("Ä", "AE"),
            ("ä", "ae"),
        ],
    )


def truncate(value: str, max_length: int) -> str:
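    """Shorten a value that exceeds max_length, cutting at the last space where possible and appending an ellipsis."""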
    if len(value) <= max_length:
        return value
    truncated = value[:max_length]
    prefix, sep, suffix = truncated.rpartition(" ")
    if prefix and sep:
        return "".join((prefix, sep, "…"))
    return truncated[: max_length - 1] + "…"


class FormatterKwargs(TypedDict, total=False):
    episode: Episode
    show: FeedInfo
    ext: str


# Template fields holding datetimes receive this strftime format when no format spec is given.
DATETIME_FIELDS = {"episode.published_time"}
DEFAULT_DATETIME_FMT = "%Y-%m-%d"


class FilenameFormatter(Formatter):
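    """Formatter that renders the configured filename template for an episode.

    Substituted fields are slugified or made filename-safe, and the result is
    resolved relative to the archive directory.
    """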
    _template: str
    _slugify: bool
    _path_root: Path

    _parsed: list[tuple[str, str | None, str | None, str | None]]

    def __init__(self, settings: Settings) -> None:
        self._template = settings.filename_template
        self._slugify = settings.slugify_paths
        self._path_root = settings.archive_directory

    def parse(  # type: ignore[override]
        self,
        format_string: str,
    ) -> Iterable[tuple[str, str | None, str | None, str | None]]:
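        """Parse the format string, applying the default date format to datetime fields without an explicit spec."""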
        for literal_text, field_name, format_spec, conversion in super().parse(format_string):
            if field_name in DATETIME_FIELDS and not format_spec:
                format_spec = DEFAULT_DATETIME_FMT
            yield literal_text, field_name, format_spec, conversion

    def format_field(self, value: Any, format_spec: str) -> str:
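        """Format a single field, then slugify it or make it filename-safe."""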
        formatted: str = super().format_field(value, format_spec)
        if self._slugify:
            return slugify(formatted)
        return make_filename_safe(formatted)

    def format(self, episode: Episode, feed_info: FeedInfo) -> Path:  # type: ignore[override]
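        """Render the filename template for an episode below the archive root."""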
        kwargs: FormatterKwargs = {
            "episode": episode,
            "show": feed_info,
            "ext": episode.ext,
        }
        return self._path_root / self.vformat(self._template, args=(), kwargs=kwargs)


@overload
@contextmanager
def atomic_write(target: Path, mode: Literal["w"] = "w") -> Iterator[IO[str]]: ...


@overload
@contextmanager
def atomic_write(target: Path, mode: Literal["wb"]) -> Iterator[IO[bytes]]: ...


@contextmanager
def atomic_write(target: Path, mode: Literal["w", "wb"] = "w") -> Iterator[IO[bytes]] | Iterator[IO[str]]:
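    """Write to a temporary '.part' file, fsync it, then rename it over the target.

    On error, both the partial file and the target are removed.
    """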
    tempfile = target.with_suffix(".part")
    try:
        with tempfile.open(mode) as fp:
            yield fp
            fp.flush()
            os.fsync(fp.fileno())
        logger.debug("Moving file '%s' => '%s'", tempfile, target)
        os.rename(tempfile, target)
    except Exception:
        target.unlink(missing_ok=True)
        raise
    finally:
        tempfile.unlink(missing_ok=True)


@contextmanager
def handle_feed_request(url: str) -> Generator[None, Any, None]:
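    """Report feed retrieval problems (HTTP errors, invalid feeds, unchanged feeds) as log entries and console messages."""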
    try:
        yield
    except HTTPError as exc:
        logger.debug("Failed to request feed url %s", url, exc_info=exc)
        if (response := getattr(exc, "response", None)) is None:
            rprint(f"[error]Failed to retrieve feed {url}: {exc}[/]")
            return

        rprint(f"[error]Received status code {response.status_code} from {url}[/]")

    except ValidationError as exc:
        logger.debug("Feed validation failed for %s", url, exc_info=exc)
        rprint(f"[error]Received invalid feed from {url}[/]")

    except NotModified as exc:
        logger.debug("Skipping retrieval for %s", exc.info)
        rprint(f"\n[bar.finished]⏲ Feed of {exc.info} is unchanged, skipping.[/]")

    except Exception as exc:
        logger.debug("Unexpected error for url %s", url, exc_info=exc)
        rprint(f"[error]Failed to retrieve feed {url}: {exc}[/]")