mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Recreate the FileAudioSource (might break some stuff!)
This commit is contained in:
parent
5bbae7da8c
commit
4764e27d72
7 changed files with 144 additions and 11 deletions
|
@ -23,6 +23,8 @@ class SummonCommand(Command):
|
||||||
if self.interface.name != "discord":
|
if self.interface.name != "discord":
|
||||||
# TODO: use a Herald Event to remotely connect the bot
|
# TODO: use a Herald Event to remotely connect the bot
|
||||||
raise UnsupportedError()
|
raise UnsupportedError()
|
||||||
|
if discord is None:
|
||||||
|
raise ConfigurationError("'discord' extra is not installed.")
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
message: discord.Message = data.message
|
message: discord.Message = data.message
|
||||||
member: Union[discord.User, discord.Member] = message.author
|
member: Union[discord.User, discord.Member] = message.author
|
||||||
|
@ -92,7 +94,7 @@ class SummonCommand(Command):
|
||||||
|
|
||||||
# Try to connect to the voice channel
|
# Try to connect to the voice channel
|
||||||
try:
|
try:
|
||||||
voice: discord.VoiceClient = await channel.connect()
|
await channel.connect()
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise ExternalError("Timed out while trying to connect to the channel")
|
raise ExternalError("Timed out while trying to connect to the channel")
|
||||||
except discord.opus.OpusNotLoaded:
|
except discord.opus.OpusNotLoaded:
|
||||||
|
|
78
royalnet/bard/ytdldiscord.py
Normal file
78
royalnet/bard/ytdldiscord.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
import typing
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from royalnet.utils import asyncify, MultiLock
|
||||||
|
from .ytdlinfo import YtdlInfo
|
||||||
|
from .ytdlfile import YtdlFile
|
||||||
|
|
||||||
|
try:
|
||||||
|
import discord
|
||||||
|
from royalnet.serf.discord import FileAudioSource
|
||||||
|
except ImportError:
|
||||||
|
discord = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ffmpeg
|
||||||
|
except ImportError:
|
||||||
|
ffmpeg = None
|
||||||
|
|
||||||
|
|
||||||
|
class YtdlDiscord:
|
||||||
|
"""A representation of a YtdlFile conversion to the :mod:`discord` PCM format."""
|
||||||
|
def __init__(self, ytdl_file: YtdlFile):
|
||||||
|
self.ytdl_file: YtdlFile = ytdl_file
|
||||||
|
self.pcm_filename: typing.Optional[str] = None
|
||||||
|
self.lock: MultiLock = MultiLock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_converted(self):
|
||||||
|
"""Has the file been converted?"""
|
||||||
|
return self.pcm_filename is not None
|
||||||
|
|
||||||
|
async def convert_to_pcm(self) -> None:
|
||||||
|
"""Convert the file to pcm with :mod:`ffmpeg`."""
|
||||||
|
if ffmpeg is None:
|
||||||
|
raise ImportError("'bard' extra is not installed")
|
||||||
|
await self.ytdl_file.download_file()
|
||||||
|
if self.pcm_filename is None:
|
||||||
|
async with self.ytdl_file.lock.normal():
|
||||||
|
destination_filename = re.sub(r"\.[^.]+$", ".pcm", self.ytdl_file.filename)
|
||||||
|
async with self.lock.exclusive():
|
||||||
|
await asyncify(
|
||||||
|
ffmpeg.input(self.ytdl_file.filename)
|
||||||
|
.output(destination_filename, format="s16le", ac=2, ar="48000")
|
||||||
|
.overwrite_output()
|
||||||
|
.run
|
||||||
|
)
|
||||||
|
self.pcm_filename = destination_filename
|
||||||
|
|
||||||
|
async def delete_asap(self) -> None:
|
||||||
|
"""Delete the mp3 file."""
|
||||||
|
if self.is_converted:
|
||||||
|
async with self.lock.exclusive():
|
||||||
|
os.remove(self.pcm_filename)
|
||||||
|
self.pcm_filename = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_url(cls, url, **ytdl_args) -> typing.List["YtdlDiscord"]:
|
||||||
|
"""Create a :class:`list` of :class:`YtdlMp3` from a URL."""
|
||||||
|
files = await YtdlFile.from_url(url, **ytdl_args)
|
||||||
|
dfiles = []
|
||||||
|
for file in files:
|
||||||
|
dfile = YtdlDiscord(file)
|
||||||
|
dfiles.append(dfile)
|
||||||
|
return dfiles
|
||||||
|
|
||||||
|
@property
|
||||||
|
def info(self) -> typing.Optional[YtdlInfo]:
|
||||||
|
"""Shortcut to get the :class:`YtdlInfo` of the object."""
|
||||||
|
return self.ytdl_file.info
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def spawn_audiosource(self):
|
||||||
|
if discord is None:
|
||||||
|
raise ImportError("'discord' extra is not installed")
|
||||||
|
await self.convert_to_pcm()
|
||||||
|
with open(self.pcm_filename, "rb") as stream:
|
||||||
|
yield FileAudioSource(stream)
|
|
@ -1,14 +1,18 @@
|
||||||
import typing
|
import typing
|
||||||
import re
|
import re
|
||||||
import ffmpeg
|
|
||||||
import os
|
import os
|
||||||
from royalnet.utils import asyncify, MultiLock
|
from royalnet.utils import asyncify, MultiLock
|
||||||
from .ytdlinfo import YtdlInfo
|
from .ytdlinfo import YtdlInfo
|
||||||
from .ytdlfile import YtdlFile
|
from .ytdlfile import YtdlFile
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ffmpeg
|
||||||
|
except ImportError:
|
||||||
|
ffmpeg = None
|
||||||
|
|
||||||
|
|
||||||
class YtdlMp3:
|
class YtdlMp3:
|
||||||
"""A representation of a YtdlFile conversion to mp3."""
|
"""A representation of a :class:`YtdlFile` conversion to mp3."""
|
||||||
def __init__(self, ytdl_file: YtdlFile):
|
def __init__(self, ytdl_file: YtdlFile):
|
||||||
self.ytdl_file: YtdlFile = ytdl_file
|
self.ytdl_file: YtdlFile = ytdl_file
|
||||||
self.mp3_filename: typing.Optional[str] = None
|
self.mp3_filename: typing.Optional[str] = None
|
||||||
|
@ -20,7 +24,9 @@ class YtdlMp3:
|
||||||
return self.mp3_filename is not None
|
return self.mp3_filename is not None
|
||||||
|
|
||||||
async def convert_to_mp3(self) -> None:
|
async def convert_to_mp3(self) -> None:
|
||||||
"""Convert the file to mp3 with ``ffmpeg``."""
|
"""Convert the file to mp3 with :mod:`ffmpeg`."""
|
||||||
|
if ffmpeg is None:
|
||||||
|
raise ImportError("'bard' extra is not installed")
|
||||||
await self.ytdl_file.download_file()
|
await self.ytdl_file.download_file()
|
||||||
if self.mp3_filename is None:
|
if self.mp3_filename is None:
|
||||||
async with self.ytdl_file.lock.normal():
|
async with self.ytdl_file.lock.normal():
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .create_rich_embed import create_rich_embed
|
from .createrichembed import create_rich_embed
|
||||||
from .escape import escape
|
from .escape import escape
|
||||||
from .discordserf import DiscordSerf
|
from .discordserf import DiscordSerf
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ class DiscordSerf(Serf):
|
||||||
interface: CommandInterface,
|
interface: CommandInterface,
|
||||||
session,
|
session,
|
||||||
loop: asyncio.AbstractEventLoop,
|
loop: asyncio.AbstractEventLoop,
|
||||||
message: discord.Message):
|
message: "discord.Message"):
|
||||||
super().__init__(interface=interface, session=session, loop=loop)
|
super().__init__(interface=interface, session=session, loop=loop)
|
||||||
data.message = message
|
data.message = message
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ class DiscordSerf(Serf):
|
||||||
await data.message.channel.send(escape(text))
|
await data.message.channel.send(escape(text))
|
||||||
|
|
||||||
async def get_author(data, error_if_none=False):
|
async def get_author(data, error_if_none=False):
|
||||||
user: discord.Member = data.message.author
|
user: "discord.Member" = data.message.author
|
||||||
query = data.session.query(self._master_table)
|
query = data.session.query(self._master_table)
|
||||||
for link in self._identity_chain:
|
for link in self._identity_chain:
|
||||||
query = query.join(link.mapper.class_)
|
query = query.join(link.mapper.class_)
|
||||||
|
@ -90,7 +90,7 @@ class DiscordSerf(Serf):
|
||||||
|
|
||||||
return DiscordData
|
return DiscordData
|
||||||
|
|
||||||
async def handle_message(self, message: discord.Message):
|
async def handle_message(self, message: "discord.Message"):
|
||||||
"""Handle a Discord message by calling a command if appropriate."""
|
"""Handle a Discord message by calling a command if appropriate."""
|
||||||
text = message.content
|
text = message.content
|
||||||
# Skip non-text messages
|
# Skip non-text messages
|
||||||
|
@ -100,7 +100,7 @@ class DiscordSerf(Serf):
|
||||||
if not text.startswith("!"):
|
if not text.startswith("!"):
|
||||||
return
|
return
|
||||||
# Skip bot messages
|
# Skip bot messages
|
||||||
author: Union[discord.User] = message.author
|
author: Union["discord.User"] = message.author
|
||||||
if author.bot:
|
if author.bot:
|
||||||
return
|
return
|
||||||
# Find and clean parameters
|
# Find and clean parameters
|
||||||
|
@ -129,11 +129,11 @@ class DiscordSerf(Serf):
|
||||||
if session is not None:
|
if session is not None:
|
||||||
await asyncify(session.close)
|
await asyncify(session.close)
|
||||||
|
|
||||||
def bot_factory(self) -> Type[discord.Client]:
|
def bot_factory(self) -> Type["discord.Client"]:
|
||||||
"""Create a custom class inheriting from :py:class:`discord.Client`."""
|
"""Create a custom class inheriting from :py:class:`discord.Client`."""
|
||||||
# noinspection PyMethodParameters
|
# noinspection PyMethodParameters
|
||||||
class DiscordClient(discord.Client):
|
class DiscordClient(discord.Client):
|
||||||
async def on_message(cli, message: discord.Message):
|
async def on_message(cli, message: "discord.Message"):
|
||||||
"""Handle messages received by passing them to the handle_message method of the bot."""
|
"""Handle messages received by passing them to the handle_message method of the bot."""
|
||||||
# TODO: keep reference to these tasks somewhere
|
# TODO: keep reference to these tasks somewhere
|
||||||
self.loop.create_task(self.handle_message(message))
|
self.loop.create_task(self.handle_message(message))
|
||||||
|
@ -144,6 +144,13 @@ class DiscordSerf(Serf):
|
||||||
|
|
||||||
return DiscordClient
|
return DiscordClient
|
||||||
|
|
||||||
|
def get_voice_client(self, guild: "discord.Guild") -> Optional["discord.VoiceClient"]:
|
||||||
|
voice_clients: List["discord.VoiceClient"] = self.client.voice_clients
|
||||||
|
for voice_client in voice_clients:
|
||||||
|
if voice_client.guild == guild:
|
||||||
|
return voice_client
|
||||||
|
return None
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await super().run()
|
await super().run()
|
||||||
token = self.get_secret("discord")
|
token = self.get_secret("discord")
|
||||||
|
|
40
royalnet/serf/discord/fileaudiosource.py
Normal file
40
royalnet/serf/discord/fileaudiosource.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
try:
|
||||||
|
import discord
|
||||||
|
except ImportError:
|
||||||
|
discord = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileAudioSource(discord.AudioSource):
|
||||||
|
"""A :py:class:`discord.AudioSource` that uses a :py:class:`io.BufferedIOBase` as an input instead of memory.
|
||||||
|
|
||||||
|
The stream should be in the usual PCM encoding.
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
This AudioSource will consume (and close) the passed stream."""
|
||||||
|
|
||||||
|
def __init__(self, file):
|
||||||
|
self.file = file
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.file.seekable():
|
||||||
|
return f"<{self.__class__.__name__} @{self.file.tell()}>"
|
||||||
|
else:
|
||||||
|
return f"<{self.__class__.__name__}>"
|
||||||
|
|
||||||
|
def is_opus(self):
|
||||||
|
"""This audio file isn't Opus-encoded, but PCM-encoded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``False``."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def read(self):
|
||||||
|
"""Reads 20ms worth of audio.
|
||||||
|
|
||||||
|
If the stream has ended, then return an empty :py:class:`bytes`-like object."""
|
||||||
|
data: bytes = self.file.read(discord.opus.Encoder.FRAME_SIZE)
|
||||||
|
# If there is no more data to be streamed
|
||||||
|
if len(data) != discord.opus.Encoder.FRAME_SIZE:
|
||||||
|
# Return that the stream has ended
|
||||||
|
return b""
|
||||||
|
return data
|
Loading…
Reference in a new issue