tulir/mautrix-telegram

View on GitHub
mautrix_telegram/util/parallel_file_transfer.py

Summary

Maintainability
B
5 hrs
Test Coverage
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2019 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import AsyncGenerator, Awaitable, Union, cast
from collections import defaultdict
import asyncio
import hashlib
import logging
import math
import time

from aiohttp import ClientResponse
from telethon import helpers, utils
from telethon.crypto import AuthKey
from telethon.network import MTProtoSender
from telethon.tl.alltlobjects import LAYER
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import (
    GetFileRequest,
    SaveBigFilePartRequest,
    SaveFilePartRequest,
)
from telethon.tl.types import (
    Document,
    InputDocumentFileLocation,
    InputFile,
    InputFileBig,
    InputFileLocation,
    InputPeerPhotoFileLocation,
    InputPhotoFileLocation,
    TypeInputFile,
)

from mautrix.appservice import IntentAPI
from mautrix.types import ContentURI, EncryptedFile
from mautrix.util.logging import TraceLogger

from ..db import TelegramFile as DBTelegramFile
from ..tgclient import MautrixTelegramClient

# Attachment encryption is optional: if the import fails, the name is bound to
# None and code in this module checks it for truthiness before encrypting.
try:
    from mautrix.crypto.attachments import async_encrypt_attachment
except ImportError:
    async_encrypt_attachment = None

# Module-wide logger; the cast lets type checkers see the TraceLogger API
# (this module calls log.trace in addition to log.debug).
log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util"))

# The file-reference types accepted as a download source by this module.
TypeLocation = Union[
    Document,
    InputDocumentFileLocation,
    InputPeerPhotoFileLocation,
    InputFileLocation,
    InputPhotoFileLocation,
]


class DownloadSender:
    """Fetches a fixed number of file parts through one dedicated sender.

    A single ``GetFileRequest`` is reused for every part: after each send its
    offset is moved forward by ``stride`` and the remaining-part counter is
    decremented.
    """

    sender: MTProtoSender
    request: GetFileRequest
    remaining: int
    stride: int

    def __init__(
        self,
        sender: MTProtoSender,
        file: TypeLocation,
        offset: int,
        limit: int,
        stride: int,
        count: int,
    ) -> None:
        """Store the sender and build the reusable request.

        Args:
            sender: Connection used to send the requests.
            file: Location of the file to fetch.
            offset: Offset used for the first request.
            limit: ``limit`` value passed to every request.
            stride: Amount added to the request offset after each part.
            count: Number of parts this sender will fetch.
        """
        self.remaining = count
        self.stride = stride
        self.request = GetFileRequest(file, offset=offset, limit=limit)
        self.sender = sender

    async def next(self) -> bytes | None:
        """Fetch the next part, or return ``None`` once no parts remain."""
        if self.remaining:
            response = await self.sender.send(self.request)
            self.remaining -= 1
            # Point the shared request at this sender's next part.
            self.request.offset += self.stride
            return response.bytes
        return None

    def disconnect(self) -> Awaitable[None]:
        """Return the awaitable produced by disconnecting the underlying sender."""
        return self.sender.disconnect()


class UploadSender:
    """Uploads file parts through one dedicated sender, one part in flight at a time.

    A single save-part request object is reused: each send fills in the part
    data, and after the send completes the part index is advanced by
    ``stride``.
    """

    sender: MTProtoSender
    # Either request flavour may be stored, depending on the ``big`` flag.
    request: SaveFilePartRequest | SaveBigFilePartRequest
    part_count: int
    stride: int
    # The most recently scheduled send, if any.
    previous: asyncio.Task | None
    # Stored for callers; not otherwise used inside this class.
    loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        sender: MTProtoSender,
        file_id: int,
        part_count: int,
        big: bool,
        index: int,
        stride: int,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the sender and build the reusable request.

        Args:
            sender: Connection used to send the requests.
            file_id: Identifier of the file being uploaded.
            part_count: Total number of parts in the file.
            big: Whether to build a ``SaveBigFilePartRequest`` instead of a
                ``SaveFilePartRequest``.
            index: Part index used for the first request.
            stride: Amount added to the part index after each send.
            loop: Event loop kept on the instance.
        """
        self.sender = sender
        self.part_count = part_count
        if big:
            self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
        else:
            self.request = SaveFilePartRequest(file_id, index, b"")
        self.stride = stride
        self.previous = None
        self.loop = loop

    async def next(self, data: bytes) -> None:
        """Schedule ``data`` as the next part.

        Waits for the previously scheduled send (if any) to finish first, then
        starts the new send as a task and returns without awaiting it.
        """
        if self.previous:
            await self.previous
        self.previous = asyncio.create_task(self._next(data))

    async def _next(self, data: bytes) -> None:
        """Send one part and advance the request to this sender's next part index."""
        self.request.bytes = data
        log.trace(
            f"Sending file part {self.request.file_part}/{self.part_count}"
            f" with {len(data)} bytes"
        )
        await self.sender.send(self.request)
        self.request.file_part += self.stride

    async def disconnect(self) -> None:
        """Wait for the last scheduled send (if any), then disconnect the sender."""
        if self.previous:
            await self.previous
        return await self.sender.disconnect()


class ParallelTransferrer:
    """Moves one file to or from Telegram over several connections at once.

    An instance drives either a download (``download``) or an upload
    (``init_upload``, ``upload``, ``finish_upload``). In both cases the open
    connections are kept in ``senders`` until they are cleaned up.
    """

    client: MautrixTelegramClient
    loop: asyncio.AbstractEventLoop
    dc_id: int
    senders: list[DownloadSender | UploadSender] | None
    # None until an authorization has been imported for a foreign DC.
    auth_key: AuthKey | None
    # Index of the sender that receives the next uploaded part.
    upload_ticker: int

    def __init__(self, client: MautrixTelegramClient, dc_id: int | None = None) -> None:
        """Bind to ``client`` and pick the target DC.

        Args:
            client: Client whose session, DC lookup and connection settings are used.
            dc_id: Target DC; falls back to the client session's DC when falsy.
        """
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id or self.client.session.dc_id
        # The session's auth key is only reused when the target DC is the
        # session's own DC; otherwise it is obtained in _create_sender.
        self.auth_key = (
            None if dc_id and self.client.session.dc_id != dc_id else self.client.session.auth_key
        )
        self.senders = None
        self.upload_ticker = 0

    async def _cleanup(self) -> None:
        """Disconnect every open sender and forget them; a no-op when none are open."""
        senders, self.senders = self.senders, None
        if senders:
            await asyncio.gather(*(sender.disconnect() for sender in senders))

    @staticmethod
    def _get_connection_count(
        file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
    ) -> int:
        """Scale the connection count linearly with file size.

        Files larger than ``full_size`` get ``max_count`` connections; smaller
        files get a proportional share, rounded up, but never fewer than one.
        """
        if file_size > full_size:
            return max_count
        # A zero-byte file would otherwise yield 0 connections, which makes
        # the divmod in _init_download raise ZeroDivisionError.
        return max(1, math.ceil((file_size / full_size) * max_count))

    async def _init_download(
        self, connections: int, file: TypeLocation, part_count: int, part_size: int
    ) -> None:
        """Create ``connections`` download senders and split ``part_count`` parts among them.

        Sender ``i`` starts at offset ``i * part_size`` and advances by
        ``connections * part_size`` per part. Parts that do not divide evenly
        are handed out one extra each, in sender creation order.
        """
        minimum, remainder = divmod(part_count, connections)

        def get_part_count() -> int:
            nonlocal remainder
            if remainder > 0:
                remainder -= 1
                return minimum + 1
            return minimum

        # The first cross-DC sender will export+import the authorization, so we always create it
        # before creating any other senders.
        self.senders = [
            await self._create_download_sender(
                file, 0, part_size, connections * part_size, get_part_count()
            ),
            *await asyncio.gather(
                *(
                    self._create_download_sender(
                        file, i, part_size, connections * part_size, get_part_count()
                    )
                    for i in range(1, connections)
                )
            ),
        ]

    async def _create_download_sender(
        self, file: TypeLocation, index: int, part_size: int, stride: int, part_count: int
    ) -> DownloadSender:
        """Open a connection and wrap it in a DownloadSender starting at part ``index``."""
        return DownloadSender(
            await self._create_sender(), file, index * part_size, part_size, stride, part_count
        )

    async def _init_upload(
        self, connections: int, file_id: int, part_count: int, big: bool
    ) -> None:
        """Create the upload senders; sender ``i`` starts at part ``i`` with stride ``connections``.

        As with downloads, the first sender is created on its own before the
        rest are created concurrently.
        """
        self.senders = [
            await self._create_upload_sender(file_id, part_count, big, 0, connections),
            *await asyncio.gather(
                *(
                    self._create_upload_sender(file_id, part_count, big, i, connections)
                    for i in range(1, connections)
                )
            ),
        ]

    async def _create_upload_sender(
        self, file_id: int, part_count: int, big: bool, index: int, stride: int
    ) -> UploadSender:
        """Open a connection and wrap it in an UploadSender starting at part ``index``."""
        return UploadSender(
            await self._create_sender(), file_id, part_count, big, index, stride, loop=self.loop
        )

    async def _create_sender(self) -> MTProtoSender:
        """Connect a new MTProtoSender to ``self.dc_id``.

        When no auth key is known yet, an authorization is exported through the
        main client and imported on the new connection; the resulting key is
        stored on ``self.auth_key`` for later senders. Note that this path
        overwrites ``self.client._init_request.query``.
        """
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, loggers=self.client._log)
        await sender.connect(
            self.client._connection(
                dc.ip_address, dc.port, dc.id, loggers=self.client._log, proxy=self.client._proxy
            )
        )
        if not self.auth_key:
            log.debug(f"Exporting auth to DC {self.dc_id}")
            auth = await self.client(ExportAuthorizationRequest(self.dc_id))
            self.client._init_request.query = ImportAuthorizationRequest(
                id=auth.id, bytes=auth.bytes
            )
            req = InvokeWithLayerRequest(LAYER, self.client._init_request)
            await sender.send(req)
            self.auth_key = sender.auth_key
        return sender

    async def init_upload(
        self,
        file_id: int,
        file_size: int,
        part_size_kb: float | None = None,
        connection_count: int | None = None,
    ) -> tuple[int, int, bool]:
        """Open the upload connections for a file.

        Args:
            file_id: Identifier of the file being uploaded.
            file_size: Total size in bytes.
            part_size_kb: Part size in KiB; chosen from ``file_size`` when falsy.
            connection_count: Number of connections; derived from
                ``file_size`` when falsy.

        Returns:
            ``(part_size, part_count, is_large)`` where ``part_size`` is in
            bytes and ``is_large`` is true for files over 10 MiB.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
        part_count = (file_size + part_size - 1) // part_size
        is_large = file_size > 10 * 1024 * 1024
        await self._init_upload(connection_count, file_id, part_count, is_large)
        return part_size, part_count, is_large

    async def upload(self, part: bytes) -> None:
        """Hand ``part`` to the next sender in round-robin order."""
        await self.senders[self.upload_ticker].next(part)
        self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)

    async def finish_upload(self) -> None:
        """Disconnect all upload senders (each waits for its last scheduled part first)."""
        await self._cleanup()

    async def download(
        self,
        file: TypeLocation,
        file_size: int,
        part_size_kb: float | None = None,
        connection_count: int | None = None,
    ) -> AsyncGenerator[bytes, None]:
        """Download ``file`` in parallel, yielding its parts in order.

        Args:
            file: Location of the file to fetch.
            file_size: Total size in bytes.
            part_size_kb: Part size in KiB; chosen from ``file_size`` when falsy.
            connection_count: Number of connections; derived from
                ``file_size`` when falsy.

        The connections are closed when the generator finishes, fails, or is
        closed early by its consumer.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
        part_count = math.ceil(file_size / part_size)
        log.debug(
            f"Starting parallel download: {connection_count} {part_size} {part_count} {file!s}"
        )
        try:
            await self._init_download(connection_count, file, part_count, part_size)

            part = 0
            while part < part_count:
                # Start one request per sender, then consume the results in
                # sender order so parts are yielded in file order.
                tasks = [asyncio.create_task(sender.next()) for sender in self.senders]
                for task in tasks:
                    data = await task
                    if not data:
                        break
                    yield data
                    part += 1
                    log.trace(f"Part {part} downloaded")

            log.debug("Parallel download finished, cleaning up connections")
        finally:
            # Runs on success, on error, and when the consumer closes the
            # generator early, so connections are never left open.
            await self._cleanup()


# One lock per parallel-transfer ID; a lock is created the first time its key is looked up.
parallel_transfer_locks: defaultdict[int, asyncio.Lock] = defaultdict(asyncio.Lock)


async def parallel_transfer_to_matrix(
    client: MautrixTelegramClient,
    intent: IntentAPI,
    loc_id: str,
    location: TypeLocation,
    filename: str,
    encrypt: bool,
    parallel_id: int,
) -> DBTelegramFile:
    """Stream a Telegram file into Matrix media storage and describe the result.

    Args:
        client: Telegram client used for the parallel download.
        intent: Matrix intent whose ``upload_media`` receives the stream.
        loc_id: Identifier stored as the ``id`` of the returned record.
        location: Telegram file to transfer; its ``size`` and ``mime_type``
            are read before it is converted to an input location.
        filename: File name passed to the upload.
        encrypt: Whether to encrypt the stream; only honoured when the
            encryption helper was importable.
        parallel_id: Key selecting the lock that serializes transfers.

    Returns:
        A ``DBTelegramFile`` carrying the uploaded content URI, the original
        mime type and size, and the decryption info when one was captured.
    """
    size, mime_type = location.size, location.mime_type
    dc_id, location = utils.get_input_location(location)
    decryption_info = None

    async def strip_key_info(stream):
        # Items that are EncryptedFile instances are captured instead of
        # being forwarded to the upload; everything else passes through.
        nonlocal decryption_info
        async for item in async_encrypt_attachment(stream):
            if not isinstance(item, EncryptedFile):
                yield item
            else:
                decryption_info = item

    # We lock the transfers because telegram has connection count limits
    async with parallel_transfer_locks[parallel_id]:
        payload = ParallelTransferrer(client, dc_id).download(location, size)
        if encrypt and async_encrypt_attachment:
            payload = strip_key_info(payload)
            up_mime_type = "application/octet-stream"
        else:
            up_mime_type = mime_type
        content_uri = await intent.upload_media(
            payload,
            mime_type=up_mime_type,
            filename=filename,
            size=None if encrypt else size,
        )
        if decryption_info:
            decryption_info.url = content_uri

    return DBTelegramFile(
        id=loc_id,
        mxc=content_uri,
        mime_type=mime_type,
        was_converted=False,
        timestamp=int(time.time()),
        size=size,
        width=None,
        height=None,
        decryption_info=decryption_info,
    )


async def _internal_transfer_to_telegram(
    client: MautrixTelegramClient, response: ClientResponse
) -> tuple[TypeInputFile, int]:
    """Stream an HTTP response body to Telegram as an uploaded file.

    The body is re-chunked into parts of exactly ``part_size`` bytes (the last
    part may be shorter) and handed to a ``ParallelTransferrer``. For files
    that are not large, an MD5 of the whole body is computed along the way.

    Args:
        client: Telegram client used for the parallel upload.
        response: HTTP response whose ``content`` is streamed and whose
            ``content_length`` is used as the file size.

    Returns:
        ``(input_file, file_size)`` where ``input_file`` is an
        ``InputFileBig`` for large files and an ``InputFile`` (with MD5
        checksum) otherwise.
    """
    file_id = helpers.generate_random_long()
    # NOTE(review): content_length can presumably be None when the server
    # sends no Content-Length header; init_upload would then fail. Confirm
    # against callers before relying on this.
    file_size = response.content_length

    hash_md5 = hashlib.md5()
    uploader = ParallelTransferrer(client)
    part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
    buffer = bytearray()
    async for data in response.content:
        if not is_large:
            hash_md5.update(data)
        # Fast path: a chunk that is exactly one part needs no copying.
        if not buffer and len(data) == part_size:
            await uploader.upload(data)
            continue
        buffer.extend(data)
        # Emit as many full parts as the buffer holds. Looping (rather than
        # cutting once) keeps every part at exactly part_size even when a
        # single incoming chunk is larger than a part.
        while len(buffer) >= part_size:
            await uploader.upload(bytes(buffer[:part_size]))
            del buffer[:part_size]
    if buffer:
        # Trailing bytes form the final, shorter part.
        await uploader.upload(bytes(buffer))
    await uploader.finish_upload()
    if is_large:
        return InputFileBig(file_id, part_count, "upload"), file_size
    else:
        return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size


async def parallel_transfer_to_telegram(
    client: MautrixTelegramClient, intent: IntentAPI, uri: ContentURI, parallel_id: int
) -> tuple[TypeInputFile, int]:
    """Fetch Matrix media over HTTP and upload it to Telegram.

    Args:
        client: Telegram client used for the upload.
        intent: Matrix intent providing the download URL and HTTP session.
        uri: Content URI of the media to transfer.
        parallel_id: Key selecting the lock that serializes transfers.

    Returns:
        The uploaded input file and its size, as produced by
        ``_internal_transfer_to_telegram``.
    """
    download_url = intent.api.get_download_url(uri)
    # The lock is taken first; the HTTP request is only opened once it is held.
    async with parallel_transfer_locks[parallel_id], intent.api.session.get(
        download_url
    ) as http_response:
        return await _internal_transfer_to_telegram(client, http_response)