Skip to content

Commit

Permalink
Sync commands on plugin reload
Browse files Browse the repository at this point in the history
  • Loading branch information
4Kaylum committed Jan 18, 2024
1 parent 2affdf4 commit c8bf510
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 22 deletions.
15 changes: 15 additions & 0 deletions novus/ext/client/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,32 @@ def create_console(bot: client.Client) -> AsynchronousCli:

plugin_parser = ArgumentParser()
plugin_parser.add_argument("plugin")
sync_parser = ArgumentParser()
sync_parser.add_argument("to_sync", nargs=REMAINDER)
run_parser = ArgumentParser()
run_parser.add_argument("command", nargs=REMAINDER)
command_locals: dict[str, Any] = {}

async def add(reader: Any, writer: Any, plugin: str) -> None:
bot.add_plugin_file(plugin, load=True)
await bot.sync_commands(create=False, edit=False, delete=False)

async def remove(reader: Any, writer: Any, plugin: str) -> None:
bot.remove_plugin_file(plugin)
await bot.sync_commands(create=False, edit=False, delete=False)

async def reload(reader: Any, writer: Any, plugin: str) -> None:
bot.remove_plugin_file(plugin)
bot.add_plugin_file(plugin, load=True, reload_import=True)
await bot.sync_commands(create=False, edit=False, delete=False)

async def sync(reader: Any, writer: Any, to_sync: list[str]) -> None:
args = {}
if to_sync:
args = {"create": False, "edit": False, "delete": False}
for i in to_sync:
args[i] = True
await bot.sync_commands(**args)

async def run_state(reader: Any, writer: asyncio.StreamWriter, command: list[str]) -> None:
command_full = " ".join(command)
Expand Down Expand Up @@ -154,6 +167,8 @@ async def run_state(reader: Any, writer: asyncio.StreamWriter, command: list[str
"remove-plugin": (remove, plugin_parser,),
"reload-plugin": (reload, plugin_parser,),

"sync-commands": (sync, plugin_parser,),

"add": (add, plugin_parser,),
"remove": (remove, plugin_parser,),
"reload": (reload, plugin_parser,),
Expand Down
64 changes: 42 additions & 22 deletions novus/ext/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,28 +531,32 @@ async def _handle_command_sync(
self,
application_id: int,
guild_id: int | None,
commands: dict[str, Command]) -> None:
commands: dict[str, Command],
*,
create: bool = True,
edit: bool = True,
delete: bool = True) -> None:
"""
Handle a guild's list of commands being changed.
"""

# Set up our requests
state = self.state.interaction
if guild_id is None:
get = partial(state.get_global_application_commands, application_id, with_localizations=True)
create = partial(state.create_global_application_command, application_id)
edit = partial(state.edit_global_application_command, application_id)
delete = partial(state.delete_global_application_command, application_id)
bulk = partial(state.bulk_overwrite_global_application_commands, application_id)
get_ = partial(state.get_global_application_commands, application_id, with_localizations=True)
create_ = partial(state.create_global_application_command, application_id)
edit_ = partial(state.edit_global_application_command, application_id)
delete_ = partial(state.delete_global_application_command, application_id)
bulk_ = partial(state.bulk_overwrite_global_application_commands, application_id)
else:
get = partial(state.get_guild_application_commands, application_id, guild_id, with_localizations=True)
create = partial(state.create_guild_application_command, application_id, guild_id)
edit = partial(state.edit_guild_application_command, application_id, guild_id)
delete = partial(state.delete_guild_application_command, application_id, guild_id)
bulk = partial(state.bulk_overwrite_guild_application_commands, application_id, guild_id)
get_ = partial(state.get_guild_application_commands, application_id, guild_id, with_localizations=True)
create_ = partial(state.create_guild_application_command, application_id, guild_id)
edit_ = partial(state.edit_guild_application_command, application_id, guild_id)
delete_ = partial(state.delete_guild_application_command, application_id, guild_id)
bulk_ = partial(state.bulk_overwrite_guild_application_commands, application_id, guild_id)

# See what we need to do
on_server = await get()
on_server = await get_()
unchecked_local = commands.copy()
to_add: list[Command] = []
to_delete: list[int] = []
Expand All @@ -575,42 +579,51 @@ async def _handle_command_sync(
to_add = list(unchecked_local.values())

# Bulk change
if len(to_add) + len(to_delete) + len(to_edit) > int(os.getenv("NOVUS_BULK_COMMAND_LIMIT", 10)):
command_limit = int(os.getenv("NOVUS_BULK_COMMAND_LIMIT", 5))
can_bulk = edit and delete and create
command_change_count = len(to_add) + len(to_delete) + len(to_edit)
if command_change_count > command_limit and can_bulk:
local_commands = [
i.application_command._to_data()
for i in commands.values()
]
log.info("Bulk updating %s app commands in guild %s", len(local_commands), guild_id)
for dis_com in on_server:
self._commands_by_id.pop(dis_com.id, None)
on_server = await bulk(local_commands)
on_server = await bulk_(local_commands)
for dis_com in on_server:
commands[dis_com.name].add_id(guild_id, dis_com.id)
self._commands_by_id[dis_com.id] = (
self._commands[(guild_id, dis_com.name)]
)
return

# Add new command
if to_add:
if to_add and create:
for comm in to_add:
log.info("Adding app command %s in guild %s", comm, guild_id)
on_server = await create(**comm.application_command._to_data())
on_server = await create_(**comm.application_command._to_data())
comm.add_id(guild_id, on_server.id)
self._commands_by_id[on_server.id] = comm

# Delete command
if to_delete:
if to_delete and delete:
for comm in to_delete:
log.info("Deleting app command %s in guild %s", comm, guild_id)
await delete(comm)
await delete_(comm)

# Edit single command
if to_edit:
if to_edit and edit:
for id, comm in to_edit.items():
log.info("Editing app command %s %s in guild %s", id, comm, guild_id)
await edit(id, **comm.application_command._to_data())
await edit_(id, **comm.application_command._to_data())

async def sync_commands(self) -> None:
async def sync_commands(
self,
*,
create: bool = True,
edit: bool = True,
delete: bool = True) -> None:
"""
Get all commands from Discord. Determine if they all already exist. If
not, PUT them there. If so, save command IDs.
Expand Down Expand Up @@ -641,7 +654,14 @@ async def sync_commands(self) -> None:

# See which commands we have that exist already
for guild_id, commands in commands_by_guild.items():
await self._handle_command_sync(aid, guild_id, commands)
await self._handle_command_sync(
aid,
guild_id,
commands,
create=create,
edit=edit,
delete=delete,
)

async def connect(self, check_concurrency: bool = False, sleep: bool = False) -> None:
"""
Expand Down

0 comments on commit c8bf510

Please sign in to comment.