Skip to content

Commit

Permalink
Address latest comments
Browse files Browse the repository at this point in the history
  • Loading branch information
newcomertv committed May 11, 2024
1 parent 791672b commit a85b94b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 66 deletions.
6 changes: 3 additions & 3 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -1207,16 +1207,16 @@ Python::with_gil(|py| {
assert isinstance(square, cls)
assert isinstance(square, cls.RegularPolygon)
assert square._0 == 4
assert square._1 == 10.0
assert square[0] == 4 # Gets _0 field
assert square[1] == 10.0 # Gets _1 field
def count_vertices(cls, shape):
match shape:
case cls.Circle():
return 0
case cls.Rectangle():
return 4
case cls.RegularPolygon(_0=n):
case cls.RegularPolygon(n):
return n
case cls.Nothing():
return 0
Expand Down
116 changes: 56 additions & 60 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -975,19 +975,19 @@ fn impl_complex_enum(
Ok(quote! {
#pytypeinfo

#pyclass_impls
#pyclass_impls

#[doc(hidden)]
#[allow(non_snake_case)]
impl #cls {}
#[doc(hidden)]
#[allow(non_snake_case)]
impl #cls {}

#(#variant_cls_zsts)*
#(#variant_cls_zsts)*

#(#variant_cls_pytypeinfos)*
#(#variant_cls_pytypeinfos)*

#(#variant_cls_pyclass_impls)*
#(#variant_cls_pyclass_impls)*

#(#variant_cls_impls)*
#(#variant_cls_impls)*
})
}

Expand All @@ -1006,6 +1006,36 @@ fn impl_complex_enum_variant_cls(
}
}

fn impl_complex_enum_variant_match_args(
ctx: &Ctx,
variant_cls_type: &syn::Type,
field_names: &mut Vec<Ident>,
) -> (MethodAndMethodDef, syn::ImplItemConst) {
let match_args_const_impl: syn::ImplItemConst = {
let args_tp = field_names.iter().map(|_| {
quote! { &'static str }
});
parse_quote! {
const __match_args__: ( #(#args_tp,)* ) = (
#(stringify!(#field_names),)*
);
}
};

let spec = ConstSpec {
rust_ident: format_ident!("__match_args__"),
attributes: ConstAttributes {
is_class_attr: true,
name: None,
deprecations: Deprecations::new(ctx),
},
};

let variant_match_args = gen_py_const(variant_cls_type, &spec, ctx);

(variant_match_args, match_args_const_impl)
}

fn impl_complex_enum_struct_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumStructVariant<'_>,
Expand Down Expand Up @@ -1043,6 +1073,11 @@ fn impl_complex_enum_struct_variant_cls(
field_getter_impls.push(field_getter_impl);
}

let (variant_match_args, match_args_const_impl) =
impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names);

field_getters.push(variant_match_args);

let cls_impl = quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
Expand All @@ -1052,6 +1087,8 @@ fn impl_complex_enum_struct_variant_cls(
#pyo3_path::PyClassInitializer::from(base_value).add_subclass(#variant_cls)
}

#match_args_const_impl

#(#field_getter_impls)*
}
};
Expand Down Expand Up @@ -1171,52 +1208,6 @@ fn impl_complex_enum_tuple_variant_getitem(
Ok((variant_getitem, get_item_method_impl))
}

fn impl_complex_enum_tuple_variant_match_args(
ctx: &Ctx,
variant_cls_type: &syn::Type,
field_names: &mut Vec<Ident>,
) -> (MethodAndMethodDef, syn::ImplItemConst) {
let match_args_const_impl: syn::ImplItemConst = match field_names.len() {
// This covers the case where the tuple variant has no fields (valid Rust)
0 => parse_quote! {
const __match_args__: () = ();
},
1 => {
let ident = &field_names[0];
// We need the trailing comma to make it a tuple
parse_quote! {
const __match_args__: (&'static str ,) = (stringify!(#ident) , );
}
}
_ => {
let args_tp = field_names.iter().map(|_| {
quote! { &'static str }
});
parse_quote! {
const __match_args__: ( #(#args_tp),* ) = (
#(stringify!(#field_names),)*
);
}
}
};

let spec = ConstSpec {
rust_ident: format_ident!("__match_args__"),
attributes: ConstAttributes {
is_class_attr: true,
name: Some(NameAttribute {
kw: syn::parse_quote! { name },
value: NameLitStr(format_ident!("__match_args__")),
}),
deprecations: Deprecations::new(ctx),
},
};

let variant_match_args = gen_py_const(variant_cls_type, &spec, ctx);

(variant_match_args, match_args_const_impl)
}

fn impl_complex_enum_tuple_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumTupleVariant<'_>,
Expand Down Expand Up @@ -1256,7 +1247,7 @@ fn impl_complex_enum_tuple_variant_cls(
slots.push(variant_getitem);

let (variant_match_args, match_args_method_impl) =
impl_complex_enum_tuple_variant_match_args(ctx, &variant_cls_type, &mut field_names);
impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names);

field_getters.push(variant_match_args);

Expand Down Expand Up @@ -1477,10 +1468,6 @@ fn complex_enum_tuple_variant_new<'a>(
let arg_py_type: syn::Type = parse_quote!(#pyo3_path::Python<'_>);

let args = {
let mut no_pyo3_attrs = vec![];
let _attrs =
crate::pyfunction::PyFunctionArgPyO3Attributes::from_attrs(&mut no_pyo3_attrs)?;

let mut args = vec![FnArg::Py(PyArg {
name: &arg_py_ident,
ty: &arg_py_type,
Expand All @@ -1497,7 +1484,16 @@ fn complex_enum_tuple_variant_new<'a>(
}
args
};
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;

let signature = if let Some(constructor) = variant.options.constructor {
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
args,
constructor.into_signature(),
)?
} else {
crate::pyfunction::FunctionSignature::from_arguments(args)?
};

let spec = FnSpec {
tp: crate::method::FnType::FnNew,
name: &format_ident!("__pymethod_constructor__"),
Expand Down
3 changes: 3 additions & 0 deletions pytests/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,16 @@ enum SimpleTupleEnum {

#[pyclass]
pub enum TupleEnum {
#[pyo3(constructor = (_0 = 1, _1 = 1.0, _2 = true))]
FullWithDefault(i32, f64, bool),
Full(i32, f64, bool),
EmptyTuple(),
}

#[pyfunction]
pub fn do_tuple_stuff(thing: &TupleEnum) -> TupleEnum {
match thing {
TupleEnum::FullWithDefault(a, b, c) => TupleEnum::FullWithDefault(*a, *b, *c),
TupleEnum::Full(a, b, c) => TupleEnum::Full(*a, *b, *c),
TupleEnum::EmptyTuple() => TupleEnum::EmptyTuple(),
}
Expand Down
8 changes: 8 additions & 0 deletions pytests/tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def test_tuple_enum_variant_constructors():
@pytest.mark.parametrize(
"variant",
[
enums.TupleEnum.FullWithDefault(),
enums.TupleEnum.Full(42, 3.14, False),
enums.TupleEnum.EmptyTuple(),
],
Expand All @@ -158,6 +159,13 @@ def test_tuple_enum_variant_subclasses(variant: enums.TupleEnum):
assert isinstance(variant, enums.TupleEnum)


def test_tuple_enum_defaults():
variant = enums.TupleEnum.FullWithDefault()
assert variant._0 == 1
assert variant._1 == 1.0
assert variant._2 is True


def test_tuple_enum_field_getters():
tuple_variant = enums.TupleEnum.Full(42, 3.14, False)
assert tuple_variant._0 == 42
Expand Down
3 changes: 0 additions & 3 deletions pytests/tests/test_enums_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def test_complex_enum_pyfunction_in_out(variant: enums.ComplexEnum):
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
],
)
@pytest.mark.skip(
reason="__match_args__ is not supported for struct enums yet. TODO : Open an issue"
)
def test_complex_enum_partial_match(variant: enums.ComplexEnum):
match variant:
case enums.ComplexEnum.MultiFieldStruct(a):
Expand Down

0 comments on commit a85b94b

Please sign in to comment.