Skip to content

Commit

Permalink
refs #4286 -- allow setting submodule on declarative pymodules
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Jun 29, 2024
1 parent 8f7450e commit d69a102
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 19 deletions.
4 changes: 3 additions & 1 deletion guide/src/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ The `#[pymodule]` macro automatically sets the `module` attribute of the `#[pycl
For nested modules, the name of the parent module is automatically added.
In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested
but the `Ext` class will have for `module` the default `builtins` because it not nested.

You can provide the `submodule` argument to `pymodule()` for modules that are not top-level modules.
```rust
# mod declarative_module_module_attr_test {
use pyo3::prelude::*;
Expand All @@ -168,7 +170,7 @@ mod my_extension {
#[pymodule_export]
use super::Ext;

#[pymodule]
#[pymodule(submodule)]
mod submodule {
use super::*;
// This is a submodule
Expand Down
32 changes: 19 additions & 13 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl PyModuleOptions {
}
}

pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
pub fn pymodule_module_impl(mut module: syn::ItemMod, is_submodule: bool) -> Result<TokenStream> {
let syn::ItemMod {
attrs,
vis,
Expand Down Expand Up @@ -286,7 +286,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
}

let initialization = module_initialization(&name, ctx);
let initialization = module_initialization(&name, ctx, is_submodule);
Ok(quote!(
#(#attrs)*
#vis mod #ident {
Expand Down Expand Up @@ -335,7 +335,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
let vis = &function.vis;
let doc = get_doc(&function.attrs, None, ctx);

let initialization = module_initialization(&name, ctx);
let initialization = module_initialization(&name, ctx, false);

// Module function called with optional Python<'_> marker as first arg, followed by the module.
let mut module_args = Vec::new();
Expand Down Expand Up @@ -400,28 +400,34 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
})
}

fn module_initialization(name: &syn::Ident, ctx: &Ctx) -> TokenStream {
fn module_initialization(name: &syn::Ident, ctx: &Ctx, is_submodule: bool) -> TokenStream {
let Ctx { pyo3_path, .. } = ctx;
let pyinit_symbol = format!("PyInit_{}", name);
let name = name.to_string();
let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);

quote! {
let mut base = quote! {
#[doc(hidden)]
pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;

pub(super) struct MakeDef;
#[doc(hidden)]
pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = MakeDef::make_def();

/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[doc(hidden)]
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
}
};
if !is_submodule {
base = quote! {
#base

/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[doc(hidden)]
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
}
};
}
base
}

/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
Expand Down
24 changes: 20 additions & 4 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Span, TokenStream as TokenStream2};
use pyo3_macros_backend::{
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
pymodule_function_impl, pymodule_module_impl, PyClassArgs, PyClassMethodsType,
Expand Down Expand Up @@ -35,10 +35,26 @@ use syn::{parse::Nothing, parse_macro_input, Item};
/// [1]: https://pyo3.rs/latest/module.html
#[proc_macro_attribute]
pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream {
parse_macro_input!(args as Nothing);
match parse_macro_input!(input as Item) {
Item::Mod(module) => pymodule_module_impl(module),
Item::Fn(function) => pymodule_function_impl(function),
Item::Mod(module) => {
let is_submodule = match parse_macro_input!(args as Option<syn::Ident>) {
Some(i) if i == "submodule" => true,
Some(_) => {
return syn::Error::new(
Span::call_site(),
"#[pymodule] only accepts submodule as an argument",
)
.into_compile_error()
.into();
}
None => false,
};
pymodule_module_impl(module, is_submodule)
}
Item::Fn(function) => {
parse_macro_input!(args as Nothing);
pymodule_function_impl(function)
}
unsupported => Err(syn::Error::new_spanned(
unsupported,
"#[pymodule] only supports modules and functions.",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ mod declarative_module {
}
}

#[pymodule]
#[pymodule(submodule)]
#[pyo3(module = "custom_root")]
mod inner_custom_root {
use super::*;
Expand Down

0 comments on commit d69a102

Please sign in to comment.