1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-23 19:44:20 +00:00

Recreate the FileAudioSource (might break some stuff!)

This commit is contained in:
Steffo 2019-11-17 18:57:02 +01:00
parent 5bbae7da8c
commit 4764e27d72
7 changed files with 144 additions and 11 deletions

View file

@ -23,6 +23,8 @@ class SummonCommand(Command):
if self.interface.name != "discord": if self.interface.name != "discord":
# TODO: use a Herald Event to remotely connect the bot # TODO: use a Herald Event to remotely connect the bot
raise UnsupportedError() raise UnsupportedError()
if discord is None:
raise ConfigurationError("'discord' extra is not installed.")
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
message: discord.Message = data.message message: discord.Message = data.message
member: Union[discord.User, discord.Member] = message.author member: Union[discord.User, discord.Member] = message.author
@ -92,7 +94,7 @@ class SummonCommand(Command):
# Try to connect to the voice channel # Try to connect to the voice channel
try: try:
voice: discord.VoiceClient = await channel.connect() await channel.connect()
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise ExternalError("Timed out while trying to connect to the channel") raise ExternalError("Timed out while trying to connect to the channel")
except discord.opus.OpusNotLoaded: except discord.opus.OpusNotLoaded:

View file

@ -0,0 +1,78 @@
import typing
import re
import os
from contextlib import asynccontextmanager
from royalnet.utils import asyncify, MultiLock
from .ytdlinfo import YtdlInfo
from .ytdlfile import YtdlFile
try:
import discord
from royalnet.serf.discord import FileAudioSource
except ImportError:
discord = None
try:
import ffmpeg
except ImportError:
ffmpeg = None
class YtdlDiscord:
"""A representation of a YtdlFile conversion to the :mod:`discord` PCM format."""
def __init__(self, ytdl_file: YtdlFile):
self.ytdl_file: YtdlFile = ytdl_file
self.pcm_filename: typing.Optional[str] = None
self.lock: MultiLock = MultiLock()
@property
def is_converted(self):
"""Has the file been converted?"""
return self.pcm_filename is not None
async def convert_to_pcm(self) -> None:
"""Convert the file to pcm with :mod:`ffmpeg`."""
if ffmpeg is None:
raise ImportError("'bard' extra is not installed")
await self.ytdl_file.download_file()
if self.pcm_filename is None:
async with self.ytdl_file.lock.normal():
destination_filename = re.sub(r"\.[^.]+$", ".pcm", self.ytdl_file.filename)
async with self.lock.exclusive():
await asyncify(
ffmpeg.input(self.ytdl_file.filename)
.output(destination_filename, format="s16le", ac=2, ar="48000")
.overwrite_output()
.run
)
self.pcm_filename = destination_filename
async def delete_asap(self) -> None:
"""Delete the mp3 file."""
if self.is_converted:
async with self.lock.exclusive():
os.remove(self.pcm_filename)
self.pcm_filename = None
@classmethod
async def from_url(cls, url, **ytdl_args) -> typing.List["YtdlDiscord"]:
"""Create a :class:`list` of :class:`YtdlMp3` from a URL."""
files = await YtdlFile.from_url(url, **ytdl_args)
dfiles = []
for file in files:
dfile = YtdlDiscord(file)
dfiles.append(dfile)
return dfiles
@property
def info(self) -> typing.Optional[YtdlInfo]:
"""Shortcut to get the :class:`YtdlInfo` of the object."""
return self.ytdl_file.info
@asynccontextmanager
async def spawn_audiosource(self):
if discord is None:
raise ImportError("'discord' extra is not installed")
await self.convert_to_pcm()
with open(self.pcm_filename, "rb") as stream:
yield FileAudioSource(stream)

View file

@ -1,14 +1,18 @@
import typing import typing
import re import re
import ffmpeg
import os import os
from royalnet.utils import asyncify, MultiLock from royalnet.utils import asyncify, MultiLock
from .ytdlinfo import YtdlInfo from .ytdlinfo import YtdlInfo
from .ytdlfile import YtdlFile from .ytdlfile import YtdlFile
try:
import ffmpeg
except ImportError:
ffmpeg = None
class YtdlMp3: class YtdlMp3:
"""A representation of a YtdlFile conversion to mp3.""" """A representation of a :class:`YtdlFile` conversion to mp3."""
def __init__(self, ytdl_file: YtdlFile): def __init__(self, ytdl_file: YtdlFile):
self.ytdl_file: YtdlFile = ytdl_file self.ytdl_file: YtdlFile = ytdl_file
self.mp3_filename: typing.Optional[str] = None self.mp3_filename: typing.Optional[str] = None
@ -20,7 +24,9 @@ class YtdlMp3:
return self.mp3_filename is not None return self.mp3_filename is not None
async def convert_to_mp3(self) -> None: async def convert_to_mp3(self) -> None:
"""Convert the file to mp3 with ``ffmpeg``.""" """Convert the file to mp3 with :mod:`ffmpeg`."""
if ffmpeg is None:
raise ImportError("'bard' extra is not installed")
await self.ytdl_file.download_file() await self.ytdl_file.download_file()
if self.mp3_filename is None: if self.mp3_filename is None:
async with self.ytdl_file.lock.normal(): async with self.ytdl_file.lock.normal():

View file

@ -1,4 +1,4 @@
from .create_rich_embed import create_rich_embed from .createrichembed import create_rich_embed
from .escape import escape from .escape import escape
from .discordserf import DiscordSerf from .discordserf import DiscordSerf

View file

@ -67,7 +67,7 @@ class DiscordSerf(Serf):
interface: CommandInterface, interface: CommandInterface,
session, session,
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
message: discord.Message): message: "discord.Message"):
super().__init__(interface=interface, session=session, loop=loop) super().__init__(interface=interface, session=session, loop=loop)
data.message = message data.message = message
@ -75,7 +75,7 @@ class DiscordSerf(Serf):
await data.message.channel.send(escape(text)) await data.message.channel.send(escape(text))
async def get_author(data, error_if_none=False): async def get_author(data, error_if_none=False):
user: discord.Member = data.message.author user: "discord.Member" = data.message.author
query = data.session.query(self._master_table) query = data.session.query(self._master_table)
for link in self._identity_chain: for link in self._identity_chain:
query = query.join(link.mapper.class_) query = query.join(link.mapper.class_)
@ -90,7 +90,7 @@ class DiscordSerf(Serf):
return DiscordData return DiscordData
async def handle_message(self, message: discord.Message): async def handle_message(self, message: "discord.Message"):
"""Handle a Discord message by calling a command if appropriate.""" """Handle a Discord message by calling a command if appropriate."""
text = message.content text = message.content
# Skip non-text messages # Skip non-text messages
@ -100,7 +100,7 @@ class DiscordSerf(Serf):
if not text.startswith("!"): if not text.startswith("!"):
return return
# Skip bot messages # Skip bot messages
author: Union[discord.User] = message.author author: Union["discord.User"] = message.author
if author.bot: if author.bot:
return return
# Find and clean parameters # Find and clean parameters
@ -129,11 +129,11 @@ class DiscordSerf(Serf):
if session is not None: if session is not None:
await asyncify(session.close) await asyncify(session.close)
def bot_factory(self) -> Type[discord.Client]: def bot_factory(self) -> Type["discord.Client"]:
"""Create a custom class inheriting from :py:class:`discord.Client`.""" """Create a custom class inheriting from :py:class:`discord.Client`."""
# noinspection PyMethodParameters # noinspection PyMethodParameters
class DiscordClient(discord.Client): class DiscordClient(discord.Client):
async def on_message(cli, message: discord.Message): async def on_message(cli, message: "discord.Message"):
"""Handle messages received by passing them to the handle_message method of the bot.""" """Handle messages received by passing them to the handle_message method of the bot."""
# TODO: keep reference to these tasks somewhere # TODO: keep reference to these tasks somewhere
self.loop.create_task(self.handle_message(message)) self.loop.create_task(self.handle_message(message))
@ -144,6 +144,13 @@ class DiscordSerf(Serf):
return DiscordClient return DiscordClient
def get_voice_client(self, guild: "discord.Guild") -> Optional["discord.VoiceClient"]:
voice_clients: List["discord.VoiceClient"] = self.client.voice_clients
for voice_client in voice_clients:
if voice_client.guild == guild:
return voice_client
return None
async def run(self): async def run(self):
await super().run() await super().run()
token = self.get_secret("discord") token = self.get_secret("discord")

View file

@ -0,0 +1,40 @@
try:
import discord
except ImportError:
discord = None
class FileAudioSource(discord.AudioSource):
"""A :py:class:`discord.AudioSource` that uses a :py:class:`io.BufferedIOBase` as an input instead of memory.
The stream should be in the usual PCM encoding.
Warning:
This AudioSource will consume (and close) the passed stream."""
def __init__(self, file):
self.file = file
def __repr__(self):
if self.file.seekable():
return f"<{self.__class__.__name__} @{self.file.tell()}>"
else:
return f"<{self.__class__.__name__}>"
def is_opus(self):
"""This audio file isn't Opus-encoded, but PCM-encoded.
Returns:
``False``."""
return False
def read(self):
"""Reads 20ms worth of audio.
If the stream has ended, then return an empty :py:class:`bytes`-like object."""
data: bytes = self.file.read(discord.opus.Encoder.FRAME_SIZE)
# If there is no more data to be streamed
if len(data) != discord.opus.Encoder.FRAME_SIZE:
# Return that the stream has ended
return b""
return data