hikari-py/hikari

View on GitHub
hikari/impl/voice.py

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
# -*- coding: utf-8 -*-
# cython: language_level=3
# Copyright (c) 2020 Nekokatt
# Copyright (c) 2021-present davfsa
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Implementation of a simple voice management system."""

from __future__ import annotations

__all__: typing.Sequence[str] = ("VoiceComponentImpl",)

import asyncio
import logging
import types
import typing

from hikari import errors
from hikari import snowflakes
from hikari.api import voice
from hikari.events import voice_events
from hikari.internal import ux

if typing.TYPE_CHECKING:
    # These names are only needed for annotations. Annotations are never evaluated at runtime here
    # (see `from __future__ import annotations`), so they are imported for type checkers only.
    from hikari import channels
    from hikari import guilds
    from hikari import traits

    # Type of the concrete voice connection class passed to `connect_to`; lets `connect_to` return
    # an instance typed as that same class rather than the base `voice.VoiceConnection`.
    _VoiceConnectionT = typing.TypeVar("_VoiceConnectionT", bound="voice.VoiceConnection")

# Logger for voice connection management in this module.
_LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.voice.management")


class VoiceComponentImpl(voice.VoiceComponent):
    """A standard voice component management implementation.

    This is the regular implementation you will generally use to connect to
    voice channels with.

    At most one voice connection is tracked per guild. While at least one
    connection is open, a single listener for voice events is subscribed and
    forwards each event to the connection belonging to the event's guild.
    """

    __slots__: typing.Sequence[str] = (
        "_app",
        "_connections",
        "connections",
        "_is_alive",
        "_is_closing",
        "_voice_listener",
    )

    # Mutable registry of open voice connections, keyed by guild ID.
    _connections: typing.Dict[snowflakes.Snowflake, voice.VoiceConnection]
    # Read-only public view over `_connections`.
    connections: typing.Mapping[snowflakes.Snowflake, voice.VoiceConnection]

    def __init__(self, app: traits.GatewayBotAware) -> None:
        self._app = app
        self._connections = {}
        self.connections = types.MappingProxyType(self._connections)
        self._is_alive = False
        self._is_closing = False
        # Whether `_on_voice_event` is currently subscribed to the event manager.
        self._voice_listener = False

    @property
    def is_alive(self) -> bool:
        return self._is_alive

    def _check_if_alive(self) -> None:
        """Raise if the component is not started or is in the middle of closing.

        Raises
        ------
        hikari.errors.ComponentStateConflictError
            If the component is not alive, or is currently closing.
        """
        if not self._is_alive:
            raise errors.ComponentStateConflictError("Component cannot be used while it's not alive")

        if self._is_closing:
            raise errors.ComponentStateConflictError("Component cannot be used while it's closing")

    async def disconnect(self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild]) -> None:
        self._check_if_alive()
        guild_id = snowflakes.Snowflake(guild)

        if guild_id not in self._connections:
            raise errors.VoiceError("This application doesn't have any active voice connection in this server")

        conn = self._connections[guild_id]
        # We rely on the assumption that _on_connection_close will be called here rather than explicitly
        # to remove the connection from self._connections.
        await conn.disconnect()

    async def _disconnect_all(self) -> None:
        """Disconnect every registered voice connection concurrently."""
        # We rely on the assumption that _on_connection_close will be called here rather than explicitly
        # emptying self._connections. Snapshot the values first so that callback-driven removals can never
        # interfere with the iteration.
        await asyncio.gather(*(c.disconnect() for c in tuple(self._connections.values())))

    async def disconnect_all(self) -> None:
        self._check_if_alive()
        await self._disconnect_all()

    async def close(self) -> None:
        self._check_if_alive()
        self._is_closing = True

        try:
            if self._voice_listener:
                self._app.event_manager.unsubscribe(voice_events.VoiceEvent, self._on_voice_event)
                # Clear the flag immediately so that _on_connection_close (triggered below for the last
                # connection) does not try to unsubscribe the same callback a second time.
                self._voice_listener = False

            if self._connections:
                _LOGGER.info("shutting down %s active voice connection(s)", len(self._connections))
                await self._disconnect_all()
        finally:
            # Always leave the "closing" state, even if a disconnect failed; otherwise the component could
            # never be used or restarted again.
            self._is_alive = False
            self._is_closing = False
            self._voice_listener = False

    def start(self) -> None:
        """Start this voice component."""
        if self._is_alive:
            raise errors.ComponentStateConflictError("Cannot start a voice component which is already running")

        self._is_alive = True
        self._voice_listener = False

    async def connect_to(
        self,
        guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
        channel: snowflakes.SnowflakeishOr[channels.GuildVoiceChannel],
        voice_connection_type: typing.Type[_VoiceConnectionT],
        *,
        deaf: bool = False,
        mute: bool = False,
        timeout: typing.Optional[int] = 5,
        **kwargs: typing.Any,
    ) -> _VoiceConnectionT:
        self._check_if_alive()
        guild_id = snowflakes.Snowflake(guild)

        if guild_id in self._connections:
            raise errors.VoiceError(
                "Already in a voice channel for that guild. Disconnect before attempting to connect again"
            )

        shard_id = snowflakes.calculate_shard_id(self._app, guild_id)
        try:
            shard = self._app.shards[shard_id]
        except KeyError:
            raise errors.VoiceError(
                f"Cannot connect to shard {shard_id} as it is not present in this application"
            ) from None

        user = self._app.cache.get_me()
        if not user:
            user = await self._app.rest.fetch_my_user()

        # Start both waiters *before* asking the gateway to move us, so the events produced by the
        # voice state update cannot be dispatched before anything is waiting for them.
        state_waiter = asyncio.create_task(
            self._app.event_manager.wait_for(
                voice_events.VoiceStateUpdateEvent,
                timeout=timeout,
                predicate=self._init_state_update_predicate(guild_id, user.id),
            )
        )
        server_waiter = asyncio.create_task(
            self._app.event_manager.wait_for(
                voice_events.VoiceServerUpdateEvent,
                timeout=timeout,
                predicate=self._init_server_update_predicate(guild_id),
            )
        )

        try:
            _LOGGER.log(
                ux.TRACE, "attempting to connect to voice channel %s in %s via shard %s", channel, guild, shard_id
            )

            await shard.update_voice_state(guild, channel, self_deaf=deaf, self_mute=mute)

            _LOGGER.log(
                ux.TRACE,
                "waiting for voice events for connecting to voice channel %s in %s via shard %s",
                channel,
                guild,
                shard_id,
            )

            try:
                state_event, server_event = await asyncio.gather(state_waiter, server_waiter)
            except asyncio.TimeoutError as e:
                # We may already have joined the channel (e.g. only the server update was missing), so
                # leave explicitly rather than lingering without a registered connection.
                await self._leave_voice_channel(shard_id, guild_id)
                raise errors.VoiceError(f"Could not connect to voice channel {channel} in guild {guild}.") from e
        finally:
            # asyncio.gather does not cancel the other awaitable when one fails; make sure neither waiter
            # outlives this call.
            for waiter in (state_waiter, server_waiter):
                if not waiter.done():
                    waiter.cancel()

        # We will never receive the first endpoint as [`None`][]
        assert server_event.endpoint is not None

        _LOGGER.debug(
            "joined voice channel %s in guild %s via shard %s using endpoint %s. Session will be %s. "
            "Delegating to voice websocket",
            state_event.state.channel_id,
            state_event.state.guild_id,
            shard_id,
            server_event.endpoint,
            state_event.state.session_id,
        )

        try:
            voice_connection = await voice_connection_type.initialize(
                channel_id=snowflakes.Snowflake(channel),
                endpoint=server_event.endpoint,
                guild_id=guild_id,
                on_close=self._on_connection_close,
                owner=self,
                session_id=state_event.state.session_id,
                shard_id=shard_id,
                token=server_event.token,
                user_id=user.id,
                **kwargs,
            )
        except Exception:
            _LOGGER.debug("error occurred in initialization, leaving voice channel %s in guild %s", channel, guild)
            await self._leave_voice_channel(shard_id, guild_id)
            raise

        # Subscribe the forwarding listener lazily, on the first successful connection.
        if not self._voice_listener:
            self._app.event_manager.subscribe(voice_events.VoiceEvent, self._on_voice_event)
            self._voice_listener = True

        self._connections[guild_id] = voice_connection
        return voice_connection

    async def _leave_voice_channel(self, shard_id: int, guild_id: snowflakes.Snowflake) -> None:
        """Best-effort request for the given shard to leave the voice channel in a guild.

        Waits at most 5 seconds for the voice state update to be sent; a timeout is ignored.
        """
        try:
            await asyncio.wait_for(self._app.shards[shard_id].update_voice_state(guild_id, None), timeout=5.0)
        except asyncio.TimeoutError:
            pass

    @staticmethod
    def _init_state_update_predicate(
        guild_id: snowflakes.Snowflake, user_id: snowflakes.Snowflake
    ) -> typing.Callable[[voice_events.VoiceStateUpdateEvent], bool]:
        """Build a predicate matching the voice state update of `user_id` in `guild_id`."""

        def predicate(event: voice_events.VoiceStateUpdateEvent) -> bool:
            return event.state.guild_id == guild_id and event.state.user_id == user_id

        return predicate

    @staticmethod
    def _init_server_update_predicate(
        guild_id: snowflakes.Snowflake,
    ) -> typing.Callable[[voice_events.VoiceServerUpdateEvent], bool]:
        """Build a predicate matching voice server updates for `guild_id`."""

        def predicate(event: voice_events.VoiceServerUpdateEvent) -> bool:
            return event.guild_id == guild_id

        return predicate

    async def _on_connection_close(self, connection: voice.VoiceConnection) -> None:
        """Unregister a closed voice connection and leave its voice channel.

        This is handed to every connection as its `on_close` callback in `connect_to`.
        """
        if self._connections.pop(connection.guild_id, None) is None:
            _LOGGER.warning(
                "ignored closure of phantom unregistered voice connection %s to guild %s. Perhaps this is a bug?",
                connection,
                connection.guild_id,
            )
            return

        # Stop listening once nothing is left to forward events to. The flag check prevents a second
        # unsubscribe when `close` has already removed the listener.
        if not self._connections and self._voice_listener:
            self._app.event_manager.unsubscribe(voice_events.VoiceEvent, self._on_voice_event)
            self._voice_listener = False

        shard = self._app.shards.get(connection.shard_id)
        if shard is None:
            _LOGGER.warning(
                "unregistered voice connection %s to guild %s but could not leave voice channel %s "
                "as shard %s is not present in this application",
                connection,
                connection.guild_id,
                connection.channel_id,
                connection.shard_id,
            )
            return

        # Leave the voice channel explicitly, otherwise we will just appear to
        # not leave properly.
        await shard.update_voice_state(guild=connection.guild_id, channel=None)

        _LOGGER.debug(
            "successfully unregistered voice connection %s to guild %s and left voice channel %s",
            connection,
            connection.guild_id,
            connection.channel_id,
        )

    async def _on_voice_event(self, event: voice_events.VoiceEvent) -> None:
        """Forward a voice event to the connection registered for its guild, if any."""
        connection = self._connections.get(event.guild_id)
        if connection is None:
            return

        _LOGGER.log(
            ux.TRACE, "notifying voice connection %s in guild %s of event %s", connection, event.guild_id, event
        )
        await connection.notify(event)