diff --git a/examples/group_testing/main.rs b/examples/group_testing/main.rs new file mode 100644 index 000000000000..c4a7f4ada96b --- /dev/null +++ b/examples/group_testing/main.rs @@ -0,0 +1,116 @@ +use poise::{serenity_prelude as serenity, Command, CommandGroup}; +use std::{env::var, sync::Arc, time::Duration, vec}; +// Types used by all command functions +type Error = Box; +type Context<'a> = poise::Context<'a, Data, Error>; + +// Custom user data passed to all command functions +pub struct Data {} + +// Group struct +struct Test {} + +#[poise::group(category = "Foo")] +impl Test { + // Just a test + #[poise::command(slash_command, prefix_command, rename = "test")] + async fn test_command(ctx: Context<'_>) -> Result<(), Error> { + let name = ctx.author(); + ctx.say(format!("Hello, {}", name)).await?; + Ok(()) + } +} + +// Handlers +async fn on_error(error: poise::FrameworkError<'_, Data, Error>) { + // This is our custom error handler + // They are many errors that can occur, so we only handle the ones we want to customize + // and forward the rest to the default handler + match error { + poise::FrameworkError::Setup { error, .. } => panic!("Failed to start bot: {:?}", error), + poise::FrameworkError::Command { error, ctx, .. } => { + println!("Error in command `{}`: {:?}", ctx.command().name, error,); + } + error => { + if let Err(e) = poise::builtins::on_error(error).await { + println!("Error while handling error: {}", e) + } + } + } +} + +#[tokio::main] +async fn main() { + // FrameworkOptions contains all of poise's configuration option in one struct + // Every option can be omitted to use its default value + // println!("{:#?}", Test::commands()); + let commands: Vec> = Test::commands(); + + let options = poise::FrameworkOptions { + commands: commands, + prefix_options: poise::PrefixFrameworkOptions { + prefix: Some("--".into()), + edit_tracker: Some(Arc::new(poise::EditTracker::for_timespan( + Duration::from_secs(3600), + ))), + ..Default::default() + }, + // The global error handler for all error cases that may occur + on_error: |error| Box::pin(on_error(error)), + // This code is run before every command + pre_command: |ctx| { + Box::pin(async move { + println!("Executing command {}...", ctx.command().qualified_name); + }) + }, + // This code is run after a command if it was successful (returned Ok) + post_command: |ctx| { + Box::pin(async move { + println!("Executed command {}!", ctx.command().qualified_name); + }) + }, + // Every command invocation must pass this check to continue execution + command_check: Some(|ctx| { + Box::pin(async move { + if ctx.author().id == 123456789 { + return Ok(false); + } + Ok(true) + }) + }), + // Enforce command checks even for owners (enforced by default) + // Set to true to bypass checks, which is useful for testing + skip_checks_for_owners: false, + event_handler: |_ctx, event, _framework, _data| { + Box::pin(async move { + println!( + "Got an event in event handler: {:?}", + event.snake_case_name() + ); + Ok(()) + }) + }, + ..Default::default() + }; + + let framework = poise::Framework::builder() + .setup(move |ctx, _ready, framework| { + Box::pin(async move { + println!("Logged in as {}", _ready.user.name); + poise::builtins::register_globally(ctx, &framework.options().commands).await?; + Ok(Data {}) + }) + }) + .options(options) + .build(); + let token = var("DISCORD_TOKEN") + .expect("Missing `DISCORD_TOKEN` env var, see README for more information."); + let intents = + serenity::GatewayIntents::non_privileged() | serenity::GatewayIntents::MESSAGE_CONTENT; + + let client = serenity::ClientBuilder::new(token, intents) + .framework(framework) + .await; + + client.unwrap().start().await.unwrap() +} diff --git a/macros/src/command/mod.rs b/macros/src/command/mod.rs index 8c7f8bbd6136..ca1df93807c6 100644 --- a/macros/src/command/mod.rs +++ b/macros/src/command/mod.rs @@ -57,6 +57,227 @@ pub struct CommandArgs { member_cooldown: Option, } +impl CommandArgs { + // Check if a field has the default value + fn is_default(value: &T) -> bool { + value == &T::default() + } + + // create a new CommandArgs from self, with default fields replaced by value from GroupArgs + pub fn from_group_args(&self, group_args: &GroupArgs) -> CommandArgs { + CommandArgs { + prefix_command: if Self::is_default(&self.prefix_command) { + group_args.prefix_command + } else { + self.prefix_command + }, + slash_command: if Self::is_default(&self.slash_command) { + group_args.slash_command + } else { + self.slash_command + }, + context_menu_command: if Self::is_default(&self.context_menu_command) { + group_args.context_menu_command.clone() + } else { + self.context_menu_command.clone() + }, + subcommands: self.subcommands.clone(), // `GroupArgs` doesn't have `subcommands` + aliases: self.aliases.clone(), // `GroupArgs` doesn't have `aliases` + subcommand_required: self.subcommand_required, // `GroupArgs` doesn't have `subcommand_required` + invoke_on_edit: if Self::is_default(&self.invoke_on_edit) { + group_args.invoke_on_edit + } else { + self.invoke_on_edit + }, + reuse_response: if Self::is_default(&self.reuse_response) { + group_args.reuse_response + } else { + self.reuse_response + }, + track_deletion: if Self::is_default(&self.track_deletion) { + group_args.track_deletion + } else { + self.track_deletion + }, + track_edits: if Self::is_default(&self.track_edits) { + group_args.track_edits + } else { + self.track_edits + }, + broadcast_typing: if Self::is_default(&self.broadcast_typing) { + group_args.broadcast_typing + } else { + self.broadcast_typing + }, + help_text_fn: if Self::is_default(&self.help_text_fn) { + group_args.help_text_fn.clone() + } else { + self.help_text_fn.clone() + }, + check: if Self::is_default(&self.check) { + group_args.check.clone() + } else { + self.check.clone() + }, + on_error: if Self::is_default(&self.on_error) { + group_args.on_error.clone() + } else { + self.on_error.clone() + }, + rename: self.rename.clone(), // `GroupArgs` doesn't have `rename` + name_localized: if Self::is_default(&self.name_localized) { + group_args.name_localized.clone() + } else { + self.name_localized.clone() + }, + description_localized: if Self::is_default(&self.description_localized) { + group_args.description_localized.clone() + } else { + self.description_localized.clone() + }, + discard_spare_arguments: if Self::is_default(&self.discard_spare_arguments) { + group_args.discard_spare_arguments + } else { + self.discard_spare_arguments + }, + hide_in_help: if Self::is_default(&self.hide_in_help) { + group_args.hide_in_help + } else { + self.hide_in_help + }, + ephemeral: if Self::is_default(&self.ephemeral) { + group_args.ephemeral + } else { + self.ephemeral + }, + default_member_permissions: if Self::is_default(&self.default_member_permissions) { + group_args.default_member_permissions.clone() + } else { + self.default_member_permissions.clone() + }, + required_permissions: if Self::is_default(&self.required_permissions) { + group_args.required_permissions.clone() + } else { + self.required_permissions.clone() + }, + required_bot_permissions: if Self::is_default(&self.required_bot_permissions) { + group_args.required_bot_permissions.clone() + } else { + self.required_bot_permissions.clone() + }, + owners_only: if Self::is_default(&self.owners_only) { + group_args.owners_only + } else { + self.owners_only + }, + guild_only: if Self::is_default(&self.guild_only) { + group_args.guild_only + } else { + self.guild_only + }, + dm_only: if Self::is_default(&self.dm_only) { + group_args.dm_only + } else { + self.dm_only + }, + nsfw_only: if Self::is_default(&self.nsfw_only) { + group_args.nsfw_only + } else { + self.nsfw_only + }, + identifying_name: self.identifying_name.clone(), // `GroupArgs` doesn't have `identifying_name` + category: if Self::is_default(&self.category) { + group_args.category.clone() + } else { + self.category.clone() + }, + custom_data: if Self::is_default(&self.custom_data) { + group_args.custom_data.clone() + } else { + self.custom_data.clone() + }, + global_cooldown: if Self::is_default(&self.global_cooldown) { + group_args.global_cooldown + } else { + self.global_cooldown + }, + user_cooldown: if Self::is_default(&self.user_cooldown) { + group_args.user_cooldown + } else { + self.user_cooldown + }, + guild_cooldown: if Self::is_default(&self.guild_cooldown) { + group_args.guild_cooldown + } else { + self.guild_cooldown + }, + channel_cooldown: if Self::is_default(&self.channel_cooldown) { + group_args.channel_cooldown + } else { + self.channel_cooldown + }, + member_cooldown: if Self::is_default(&self.member_cooldown) { + group_args.member_cooldown + } else { + self.member_cooldown + }, + } + } +} + +/// Representation of the group attribute arguments (`#[group(...)]`) +/// +/// Same as `CommandArgs`, but with the following removed: +/// - subcommands +/// - aliases +/// - subcommand_required +/// - rename +/// - identifying_name +/// +#[derive(Default, Debug, darling::FromMeta)] +#[darling(default)] +pub struct GroupArgs { + prefix_command: bool, + slash_command: bool, + context_menu_command: Option, + + // When changing these, document it in parent file! + // TODO: decide why darling(multiple) feels wrong here but not in e.g. localizations (because + // if it's actually irrational, the inconsistency should be fixed) + invoke_on_edit: bool, + reuse_response: bool, + track_deletion: bool, + track_edits: bool, + broadcast_typing: bool, + help_text_fn: Option, + #[darling(multiple)] + check: Vec, + on_error: Option, + #[darling(multiple)] + name_localized: Vec>, + #[darling(multiple)] + description_localized: Vec>, + discard_spare_arguments: bool, + hide_in_help: bool, + ephemeral: bool, + default_member_permissions: Option>, + required_permissions: Option>, + required_bot_permissions: Option>, + owners_only: bool, + guild_only: bool, + dm_only: bool, + nsfw_only: bool, + category: Option, + custom_data: Option, + + // In seconds + global_cooldown: Option, + user_cooldown: Option, + guild_cooldown: Option, + channel_cooldown: Option, + member_cooldown: Option, +} + /// Representation of the function parameter attribute arguments #[derive(Default, Debug, darling::FromMeta)] #[darling(default)] diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 5316651af233..feff53f55b3f 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -8,6 +8,8 @@ mod modal; mod util; use proc_macro::TokenStream; +use quote::{quote, ToTokens}; +use syn::spanned::Spanned; /** This macro transforms plain functions into poise bot commands. @@ -277,3 +279,146 @@ pub fn modal(input: TokenStream) -> TokenStream { Err(e) => e.write_errors().into(), } } + +/** +Use this macro on an impl block to implement the CommandGroup trait. + +It implements a `commands()` function which returns a Vec with all the commands defined by `#[poise::command]`. + +# Usage + +The following code defines a command group with two commands, +both of which will have the `slash_command` and `prefix_command` attributes. + +`command_one` will have the default user_cooldown of 1000, while `command_two` overrides it to 2000. + +```rust +struct MyCommands; + +#[poise::group(slash_command, prefix_command, user_cooldown=1000)] +impl MyCommands { + /// This is a command + #[poise::command()] + async fn command_one(ctx: Context<'_>) -> Result<(), Error> { + // code + } + + /// This is another command + #[poise::command(user_cooldown=2000)] + async fn command_two(ctx: Context<'_>) -> Result<(), Error> { + // code + } +} +``` +*/ +#[proc_macro_attribute] +pub fn group(args: TokenStream, input_item: TokenStream) -> TokenStream { + match group_impl(args, input_item) { + Ok(x) => x, + Err(err) => err.write_errors().into(), + } +} + +fn group_impl(args: TokenStream, input_item: TokenStream) -> Result { + let args = darling::ast::NestedMeta::parse_meta_list(args.into())?; + + let group_args = ::from_list(&args)?; + + // let item_impl = syn::parse_macro_input!(input_item as syn::ItemImpl); + let item_impl = match syn::parse::(input_item) { + Ok(syntax_tree) => syntax_tree, + Err(err) => return Err(err.into()), + }; + let name = item_impl.self_ty; + + // vector of all #[poise::command(...)] command idents + let mut command_idents = vec![]; + + // collect each ImplItem in a stream + let mut impl_body = proc_macro2::TokenStream::new(); + + // context type for correct type inference in CommandGroup + let mut ctx_type_with_static: Option = None; + + for item in item_impl.items.iter() { + let mut item_stream = quote!(#item); + + // if it's a function... + if let syn::ImplItem::Fn(f) = item { + // ... and it's a command + if let Some(attr) = f.attrs.iter().find(|attr| is_command_attr(attr)) { + // add to command list + command_idents.push(f.sig.ident.clone()); + + // Turn a syn::Attribute into command::CommandArgs + let attr_args = &attr.meta.require_list()?.tokens; + + let command_args = + darling::ast::NestedMeta::parse_meta_list(attr_args.to_token_stream())?; + let command_args = + ::from_list(&command_args)?; + + let new_args = command_args.from_group_args(&group_args); + let function = syn::ItemFn { + attrs: vec![], + vis: f.vis.clone(), + sig: f.sig.clone(), + block: Box::new(f.block.clone()), + }; + + if ctx_type_with_static.is_none() { + let context_type = match function.sig.inputs.first() { + Some(syn::FnArg::Typed(syn::PatType { ty, .. })) => Some(&**ty), + _ => { + return Err(syn::Error::new( + function.sig.span(), + "expected a Context parameter", + ) + .into()) + } + }; + // Needed because we're not allowed to have lifetimes in the hacky use case below (in command::mod.rs) + ctx_type_with_static = Some(syn::fold::fold_type( + &mut crate::util::AllLifetimesToStatic, + context_type + .expect("context_type should have already been set") + .clone(), + )); + } + item_stream = command::command(new_args, function)?.into(); + } + } + impl_body = quote!( + #impl_body + #item_stream + ); + } + + Ok(quote! { + impl ::poise::CommandGroup for #name { + type Data = <#ctx_type_with_static as poise::_GetGenerics>::U; + type Error = <#ctx_type_with_static as poise::_GetGenerics>::E; + + fn commands() -> Vec<::poise::Command< + <#ctx_type_with_static as poise::_GetGenerics>::U, + <#ctx_type_with_static as poise::_GetGenerics>::E, + >> { + vec![#(#name::#command_idents()),*] + } + } + impl #name { + #impl_body + } + } + .into()) +} + +/// Returns true if an `Attribute` has `path` equal to "poise::command" or "command" +fn is_command_attr(attr: &syn::Attribute) -> bool { + let mut segments = attr.path().segments.iter(); + match [segments.next(), segments.next(), segments.next()] { + [Some(first), Some(second), None] => first.ident == "poise" && second.ident == "command", + [Some(first), None, None] => first.ident == "command", + [_, _, _] => false, + } +} diff --git a/macros/src/util.rs b/macros/src/util.rs index 5683ec4bef55..1609f5383a93 100644 --- a/macros/src/util.rs +++ b/macros/src/util.rs @@ -52,7 +52,7 @@ impl syn::fold::Fold for AllLifetimesToStatic { } /// Darling utility type that accepts a list of things, e.g. `#[attr(thing1, thing2...)]` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct List(pub Vec); impl darling::FromMeta for List { fn from_list(items: &[::darling::ast::NestedMeta]) -> darling::Result { @@ -70,7 +70,7 @@ impl Default for List { } /// Darling utility type that accepts a 2-tuple list of things, e.g. `#[attr(thing1, thing2)]` -#[derive(Debug)] +#[derive(Debug, PartialEq, Clone)] pub struct Tuple2(pub T, pub T); impl darling::FromMeta for Tuple2 { fn from_list(items: &[::darling::ast::NestedMeta]) -> darling::Result { diff --git a/src/group.rs b/src/group.rs new file mode 100644 index 000000000000..427c61c96fb3 --- /dev/null +++ b/src/group.rs @@ -0,0 +1,14 @@ +//! CommandGroup trait + +/// Trait for a struct with a group of commands. +/// +/// You should not implement this yourself, but instead use the `poise::group` macro +pub trait CommandGroup { + /// User data, which is stored and accessible in all command invocations + type Data; + /// The error type of your commands + type Error; + /// Return a Vec of the `poise::commands` defined in this group + /// Automatically generated by the macro + fn commands() -> Vec>; +} diff --git a/src/lib.rs b/src/lib.rs index ef3c164d6fa9..2ca6fe228f8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -393,6 +393,7 @@ pub mod choice_parameter; pub mod cooldown; pub mod dispatch; pub mod framework; +pub mod group; pub mod modal; pub mod prefix_argument; pub mod reply; @@ -407,7 +408,7 @@ pub mod macros { #[doc(no_inline)] pub use { - choice_parameter::*, cooldown::*, dispatch::*, framework::*, macros::*, modal::*, + choice_parameter::*, cooldown::*, dispatch::*, framework::*, group::*, macros::*, modal::*, prefix_argument::*, reply::*, slash_argument::*, structs::*, track_edits::*, };