tulir/mautrix-telegram

View on GitHub
mautrix_telegram/db/telethon_session.py

Summary

Maintainability
A
1 hr
Test Coverage
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Iterable
import asyncio
import datetime

from telethon import utils
from telethon.crypto import AuthKey
from telethon.sessions import MemorySession
from telethon.tl.types import PeerChannel, PeerChat, PeerUser, updates

from mautrix.util.async_db import Database, Scheme

# Typing-only placeholder used as the default for ``PgSession.db``: the Database object is
# only created when TYPE_CHECKING is true (i.e. for static analysers); at runtime this is None.
# NOTE(review): presumably the real Database is assigned to ``PgSession.db`` during startup
# elsewhere in the project - confirm against the initialisation code.
fake_db = Database.create("") if TYPE_CHECKING else None


class PgSession(MemorySession):
    """Telethon session whose data is read from and written to SQL tables.

    All queries go through the class-level ``db`` handle and are scoped by ``session_id``.
    """

    # Database handle shared by every instance (class-level, not per session).
    db: ClassVar[Database] = fake_db

    # Key identifying this session's rows in the telethon_* tables.
    session_id: str
    # Connection details and auth key, set in __init__ from the constructor arguments.
    _dc_id: int
    _server_address: str | None
    _port: int | None
    _auth_key: AuthKey | None
    _takeout_id: int | None
    # Serialises process_entities() calls made on this instance.
    _process_entities_lock: asyncio.Lock

    def __init__(
        self,
        session_id: str,
        dc_id: int = 0,
        server_address: str | None = None,
        port: int | None = None,
        auth_key: AuthKey | None = None,
        takeout_id: int | None = None,
    ) -> None:
        """Create an in-memory session object; nothing is read from or written to the database.

        Args:
            session_id: Key used for this session's rows in the database.
            dc_id: Data center ID to start with.
            server_address: Server address to start with, if known.
            port: Server port to start with, if known.
            auth_key: Authorization key to start with, if any.
            takeout_id: Takeout ID to start with, if any.
        """
        # The parent initialiser runs first so that the values below override whatever it sets.
        super().__init__()
        self._process_entities_lock = asyncio.Lock()
        self.session_id = session_id
        self._dc_id, self._server_address, self._port = dc_id, server_address, port
        self._auth_key = auth_key
        self._takeout_id = takeout_id

    def clone(self, to_instance=None) -> MemorySession:
        """Clone this session into a fresh ``MemorySession``.

        The ``to_instance`` argument is ignored: the clone target is always a new
        ``MemorySession``, so data of clones is never stored.
        """
        # Clones are used for temporarily connecting to different DCs,
        # and we don't want to store their data.
        in_memory_target = MemorySession()
        return super().clone(in_memory_target)

    @property
    def auth_key_bytes(self) -> bytes | None:
        """The ``key`` of the current auth key, or ``None`` when the auth key is unset/falsy."""
        if not self._auth_key:
            return None
        return self._auth_key.key

    @classmethod
    async def get(cls, session_id: str) -> PgSession:
        """Load the session with the given ID from ``telethon_sessions``.

        Returns:
            A session built from the stored row, or a session with default values
            (only ``session_id`` set) when no row exists.
        """
        stored = await cls.db.fetchrow(
            "SELECT session_id, dc_id, server_address, port, auth_key FROM telethon_sessions "
            "WHERE session_id=$1",
            session_id,
        )
        if stored is None:
            return cls(session_id)
        fields = {**stored}
        # The auth key column is wrapped in an AuthKey object; the remaining
        # columns map directly onto constructor parameters.
        raw_auth_key = fields.pop("auth_key", None)
        return cls(auth_key=AuthKey(raw_auth_key), **fields)

    @classmethod
    async def has(cls, session_id: str) -> bool:
        """Return whether ``telethon_sessions`` contains a row for ``session_id``."""
        matching_rows = await cls.db.fetchval(
            "SELECT COUNT(*) FROM telethon_sessions WHERE session_id=$1", session_id
        )
        return matching_rows > 0

    async def save(self) -> None:
        """Insert this session's row into ``telethon_sessions``, or update it if it exists.

        Stores the DC ID, server address, port and raw auth key bytes; the row is keyed
        by ``session_id``.
        """
        upsert = (
            "INSERT INTO telethon_sessions (session_id, dc_id, server_address, port, auth_key) "
            "VALUES ($1, $2, $3, $4, $5) ON CONFLICT (session_id) "
            "DO UPDATE SET dc_id=$2, server_address=$3, port=$4, auth_key=$5"
        )
        values = (
            self.session_id,
            self.dc_id,
            self.server_address,
            self.port,
            self.auth_key_bytes,
        )
        await self.db.execute(upsert, *values)

    # Names of the tables that delete() clears for this session;
    # every one is filtered by its session_id column there.
    _tables: ClassVar[tuple[str, ...]] = (
        "telethon_sessions",
        "telethon_entities",
        "telethon_sent_files",
        "telethon_update_state",
    )

    async def delete(self) -> None:
        """Delete this session's rows from every table in ``_tables`` in one transaction."""
        async with self.db.acquire() as conn:
            async with conn.transaction():
                for table_name in self._tables:
                    # Table names come from the class constant, never from outside input.
                    statement = f"DELETE FROM {table_name} WHERE session_id=$1"
                    await conn.execute(statement, self.session_id)

    async def close(self) -> None:
        """Do nothing: the DB connection is global, so there is nothing to close per session."""
        return None

    async def get_update_state(self, entity_id: int) -> updates.State | None:
        """Load the stored update state of one entity for this session.

        Args:
            entity_id: ID of the entity whose state row should be read.

        Returns:
            An ``updates.State`` built from the stored row, or ``None`` when there is
            no row. The state's date is a timezone-aware UTC datetime.
        """
        q = (
            "SELECT pts, qts, date, seq, unread_count FROM telethon_update_state "
            "WHERE session_id=$1 AND entity_id=$2"
        )
        row = await self.db.fetchrow(q, self.session_id, entity_id)
        if row is None:
            return None
        # The column holds a POSIX timestamp. Build an aware UTC datetime instead of the naive
        # one utcfromtimestamp() gives: datetime.timestamp() reads naive values as local time,
        # so a naive result would shift by the host's UTC offset when converted back.
        # (utcfromtimestamp() is also deprecated since Python 3.12.)
        date = datetime.datetime.fromtimestamp(row["date"], tz=datetime.timezone.utc)
        return updates.State(row["pts"], row["qts"], date, row["seq"], row["unread_count"])

    # Single-row upsert into telethon_update_state, keyed by (session_id, entity_id).
    # Parameters, in order: session_id, entity_id, pts, qts, date, seq, unread_count.
    _set_update_state_q = """
    INSERT INTO telethon_update_state (session_id, entity_id, pts, qts, date, seq, unread_count)
    VALUES ($1, $2, $3, $4, $5, $6, $7)
    ON CONFLICT (session_id, entity_id) DO UPDATE SET
        pts=excluded.pts, qts=excluded.qts, date=excluded.date, seq=excluded.seq,
        unread_count=excluded.unread_count
    """

    async def set_update_state(self, entity_id: int, row: updates.State) -> None:
        """Insert or update the stored update state of one entity for this session.

        Args:
            entity_id: ID of the entity the state belongs to.
            row: State to store. Its date is stored as a POSIX timestamp; a naive
                date is interpreted as UTC.
        """
        date = row.date
        if date.tzinfo is None:
            # datetime.timestamp() would read a naive value as *local* time, making the stored
            # value depend on the host's timezone. Pin naive values to UTC instead.
            date = date.replace(tzinfo=datetime.timezone.utc)
        await self.db.execute(
            self._set_update_state_q,
            self.session_id,
            entity_id,
            row.pts,
            row.qts,
            date.timestamp(),
            row.seq,
            row.unread_count,
        )

    async def set_update_states(self, rows: list[tuple[int, updates.State]]) -> None:
        """Insert or update the stored update states of several entities at once.

        Args:
            rows: ``(entity_id, state)`` pairs to store. Each state's date is stored as a
                POSIX timestamp; naive dates are interpreted as UTC. An empty list is a
                no-op.
        """
        if not rows:
            # Nothing to store. Without this guard the Postgres branch below would fail:
            # zip(*[]) yields nothing, so the 7-way unpack raises ValueError.
            return
        utc = datetime.timezone.utc
        db_rows = [
            (
                self.session_id,
                entity_id,
                state.pts,
                state.qts,
                # datetime.timestamp() reads naive values as local time; pin them to UTC so
                # the stored value does not depend on the host's timezone.
                (
                    state.date if state.date.tzinfo is not None
                    else state.date.replace(tzinfo=utc)
                ).timestamp(),
                state.seq,
                state.unread_count,
            )
            for entity_id, state in rows
        ]
        if self.db.scheme == Scheme.POSTGRES:
            # One statement for all rows: each column is passed as an array and expanded
            # with unnest(); the session ID is the same scalar for every row.
            q = """
            INSERT INTO telethon_update_state (
                session_id, entity_id, pts, qts, date, seq, unread_count
            )
            VALUES (
                $1,
                unnest($2::bigint[]), unnest($3::bigint[]), unnest($4::bigint[]),
                unnest($5::bigint[]), unnest($6::bigint[]), unnest($7::integer[])
            )
            ON CONFLICT (session_id, entity_id) DO UPDATE SET
                pts=excluded.pts, qts=excluded.qts, date=excluded.date, seq=excluded.seq,
                unread_count=excluded.unread_count
            """
            # Transpose the row tuples into per-column tuples (session ID column dropped).
            _, entity_ids, ptses, qtses, timestamps, seqs, unread_counts = zip(*db_rows)
            await self.db.execute(
                q, self.session_id, entity_ids, ptses, qtses, timestamps, seqs, unread_counts
            )
        else:
            # Other schemes run the single-row upsert once per row.
            await self.db.executemany(self._set_update_state_q, db_rows)

    async def delete_update_state(self, entity_id: int) -> None:
        """Delete the stored update state of one entity for this session."""
        await self.db.execute(
            "DELETE FROM telethon_update_state WHERE session_id=$1 AND entity_id=$2",
            self.session_id,
            entity_id,
        )

    async def get_update_states(self) -> Iterable[tuple[int, updates.State]]:
        """Load every stored update state of this session.

        Returns:
            A lazily evaluated iterable of ``(entity_id, state)`` pairs built from the
            rows that were fetched. Each state's date is a timezone-aware UTC datetime.
        """
        q = (
            "SELECT entity_id, pts, qts, date, seq, unread_count FROM telethon_update_state "
            "WHERE session_id=$1"
        )
        rows = await self.db.fetch(q, self.session_id)
        utc = datetime.timezone.utc
        return (
            (
                row["entity_id"],
                updates.State(
                    row["pts"],
                    row["qts"],
                    # Aware UTC datetime rather than the naive utcfromtimestamp() result:
                    # datetime.timestamp() reads naive values as local time, so a naive
                    # date would shift by the host's UTC offset when converted back.
                    datetime.datetime.fromtimestamp(row["date"], tz=utc),
                    row["seq"],
                    row["unread_count"],
                ),
            )
            for row in rows
        )

    def _entity_values_to_row(
        self, id: int, hash: int, username: str | None, phone: str | int | None, name: str | None
    ) -> tuple[str, int, int, str | None, str | None, str | None]:
        """Build a ``telethon_entities`` row tuple, prefixed with this session's ID.

        The phone is converted to a string; any falsy phone value becomes ``None``.
        """
        phone_text = None
        if phone:
            phone_text = str(phone)
        return (self.session_id, id, hash, username, phone_text, name)

    async def process_entities(self, tlo) -> None:
        """Store the entities found in ``tlo``, one call at a time per session instance."""
        # Simultaneous upserts tend to deadlock in Postgres, so the whole operation is
        # serialised behind a lock.
        # TODO: make sure postgres doesn't deadlock on upserts when session_id is different
        entity_lock = self._process_entities_lock
        async with entity_lock:
            await self._locked_process_entities(tlo)

    async def _locked_process_entities(self, tlo) -> None:
        """Upsert the entity rows extracted from ``tlo`` into ``telethon_entities``.

        Rows come from ``self._entities_to_rows`` (not defined in this part of the file).
        Nothing is executed when there are no rows. process_entities() calls this while
        holding the entity lock.
        """
        entity_rows: list[tuple[str, int, int, str | None, str | None, str | None]] = (
            self._entities_to_rows(tlo)
        )
        if not entity_rows:
            return

        if self.db.scheme != Scheme.POSTGRES:
            # Non-Postgres schemes: run a plain single-row upsert once per row.
            per_row_upsert = (
                "INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) "
                "VALUES ($1, $2, $3, $4, $5, $6) "
                "ON CONFLICT (session_id, id) DO UPDATE "
                "    SET hash=$3, username=$4, phone=$5, name=$6"
            )
            await self.db.executemany(per_row_upsert, entity_rows)
            return

        # Postgres: a single statement, with each column passed as an array and expanded
        # through unnest(); the session ID is the same scalar for every row.
        bulk_upsert = (
            "INSERT INTO telethon_entities (session_id, id, hash, username, phone, name) "
            "VALUES ($1, unnest($2::bigint[]), unnest($3::bigint[]), "
            "        unnest($4::text[]), unnest($5::text[]), unnest($6::text[])) "
            "ON CONFLICT (session_id, id) DO UPDATE"
            "  SET hash=excluded.hash, username=excluded.username,"
            "      phone=excluded.phone, name=excluded.name"
        )
        # Transpose the row tuples into per-column tuples (session ID column dropped).
        _, ids, hashes, usernames, phones, names = zip(*entity_rows)
        await self.db.execute(bulk_upsert, self.session_id, ids, hashes, usernames, phones, names)

    async def _select_entity(
        self, constraint: str, *args: str | int | tuple[int, ...]
    ) -> tuple[int, int] | None:
        """Fetch the ``(id, hash)`` of one of this session's entities.

        Args:
            constraint: SQL condition appended verbatim after ``session_id=$1 AND``;
                its placeholders therefore start at ``$2``. It is inserted into the
                query text as-is, so it must never contain outside input.
            *args: Values for the placeholders used in ``constraint``.

        Returns:
            The ``(id, hash)`` pair of the fetched row, or ``None`` when no row matches.
        """
        query = "SELECT id, hash FROM telethon_entities WHERE session_id=$1 AND " + constraint
        match = await self.db.fetchrow(query, self.session_id, *args)
        return None if match is None else (match["id"], match["hash"])

    async def get_entity_rows_by_phone(self, key: str | int) -> tuple[int, int] | None:
        """Look up an entity's ``(id, hash)`` by phone; the key is compared as a string."""
        phone_text = str(key)
        return await self._select_entity("phone=$2", phone_text)

    async def get_entity_rows_by_username(self, key: str) -> tuple[int, int] | None:
        """Look up an entity's ``(id, hash)`` by exact username match."""
        found = await self._select_entity("username=$2", key)
        return found

    async def get_entity_rows_by_name(self, key: str) -> tuple[int, int] | None:
        """Look up an entity's ``(id, hash)`` by exact name match."""
        found = await self._select_entity("name=$2", key)
        return found

    async def get_entity_rows_by_id(self, key: int, exact: bool = True) -> tuple[int, int] | None:
        """Look up an entity's ``(id, hash)`` by ID.

        Args:
            key: The ID to search for.
            exact: When true, ``key`` is matched against the stored ID as-is. Otherwise
                ``key`` is converted to a peer ID as a user, a chat and a channel, and a
                row matching any of the three is returned.
        """
        if exact:
            return await self._select_entity("id=$2", key)

        candidate_ids = tuple(
            utils.get_peer_id(peer_type(key)) for peer_type in (PeerUser, PeerChat, PeerChannel)
        )
        if self.db.scheme not in (Scheme.POSTGRES, Scheme.COCKROACH):
            # Other schemes get one placeholder per candidate ID.
            return await self._select_entity("id IN ($2, $3, $4)", *candidate_ids)
        # Postgres and CockroachDB take all candidates as a single array parameter.
        return await self._select_entity("id=ANY($2)", candidate_ids)