yukihiko-shinoda/parallel-media-downloader

View on GitHub
parallelmediadownloader/media_download_coroutine.py

Summary

Maintainability
A
0 mins
Test Coverage
"""Coroutine of downloading media."""
import asyncio
from asyncio import Semaphore
from dataclasses import dataclass
from logging import getLogger
from typing import List

from aiohttp import ClientConnectorError, ClientResponse, ClientResponseError, ClientSession

from parallelmediadownloader.exceptions import HttpTimeoutError
from parallelmediadownloader.media_save_coroutine import MediaSaveCoroutine, SaveOrder
from parallelmediadownloader.modeia_download_result import MediaDownloadResult

__all__ = ["DownloadOrder", "MediaDownloadCoroutine"]

logger = getLogger(__name__)


@dataclass
class DownloadOrder:
    """Order to download one medium: the URL to fetch and how to save the result."""

    # URL the media is requested from.
    url: str
    # Handed to the save coroutine together with the downloaded bytes.
    save_order: SaveOrder


class MediaDownloadCoroutine:
    """Coroutine of downloading media.

    Requests a single URL through a shared client session, passes the response
    body to the injected save coroutine and reports the outcome as a
    ``MediaDownloadResult``.
    """

    def __init__(self, media_save_coroutine: MediaSaveCoroutine, *, allow_http_status: List[int] = None):
        """Stores the collaborators used by :meth:`execute`.

        Args:
            media_save_coroutine: Coroutine object whose ``execute`` receives the
                downloaded bytes and the order's ``save_order``.
            allow_http_status: HTTP error statuses that are reported as a result
                without media instead of being raised. ``None`` means no status
                is tolerated.
        """
        self.media_save_coroutine = media_save_coroutine
        self.allow_http_status = [] if allow_http_status is None else allow_http_status

    async def execute(
        self, semaphore: Semaphore, client_session: ClientSession, download_order: DownloadOrder
    ) -> MediaDownloadResult:
        """Downloads and saves one medium while holding the semaphore.

        The semaphore is held for the request, the body read and the save, so it
        bounds how many downloads run in parallel.

        Args:
            semaphore: Limits the number of concurrently running downloads.
            client_session: Session used to send the GET request.
            download_order: URL to fetch and the order describing how to save it.

        Returns:
            Result holding the URL, the HTTP status and the saved media file, or
            ``None`` as media file when the status is in ``allow_http_status``.

        Raises:
            ClientResponseError: The response status is an error that is not
                listed in ``allow_http_status``.
            HttpTimeoutError: Raised by :meth:`try_get` / :meth:`try_read_response`.
            ClientConnectorError: Raised by :meth:`try_get`.
        """
        url = download_order.url
        async with semaphore:
            response = await self.try_get(client_session, url)
            try:
                response.raise_for_status()
            except ClientResponseError as error:
                logger.exception("Error! URL = %s", url)
                if error.status not in self.allow_http_status:
                    # Bare raise re-raises the active exception unchanged.
                    raise
                logger.error("Media may be removed. URL = %s", url)
                return MediaDownloadResult(url, response.status, None)
            media = await self.try_read_response(response, url)
            media_file = await self.media_save_coroutine.execute(media, download_order.save_order)
            return MediaDownloadResult(url, response.status, media_file)

    async def try_get(self, client_session: ClientSession, url: str) -> ClientResponse:
        """Tries to get request.

        Args:
            client_session: Session used to send the GET request.
            url: URL to request.

        Returns:
            The response returned by the session.

        Raises:
            HttpTimeoutError: The request timed out.
            ClientConnectorError: The connection could not be established.
        """
        try:
            # 30 is the timeout in seconds handed to the session for this request.
            return await client_session.get(url, timeout=30)
        except asyncio.TimeoutError as error:
            logger.error("TimeoutError. URL = %s", url)
            raise HttpTimeoutError(url=url) from error
        except ClientConnectorError:
            logger.error("ClientConnectorError. URL = %s", url)
            raise

    async def try_read_response(self, response: ClientResponse, url: str) -> bytes:
        """Tries to read the whole response body.

        Args:
            response: Response whose body is read.
            url: URL the response came from; used for logging and error reporting.

        Returns:
            The response body.

        Raises:
            HttpTimeoutError: Reading timed out, or reading raised
                ``ClientResponseError``.
        """
        try:
            return await response.read()
        except asyncio.TimeoutError as error:
            # The request timeout can also expire while the body is still being
            # read; report it the same way try_get reports a timeout.
            logger.error("TimeoutError. URL = %s", url)
            raise HttpTimeoutError(url=url) from error
        except ClientResponseError as error:
            # NOTE(review): kept for backward compatibility; a response error is
            # reported as HttpTimeoutError here - confirm this mapping is intended.
            logger.error("ClientResponseError. URL = %s", url)
            raise HttpTimeoutError(url=url) from error