From a667be21d5887a427ae3525f9441c15a1430206b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dj8yf0=CE=BCl?= Date: Wed, 6 Dec 2023 19:47:24 +0200 Subject: [PATCH 1/5] test: add test of enum with mixed variants --- .../src/internals/serialize/enums/mod.rs | 16 ++++++++++ .../snapshots/mixed_with_unit_variants.snap | 31 +++++++++++++++++++ .../test_simple_structs__mixed_enum-2.snap | 7 +++++ .../test_simple_structs__mixed_enum-3.snap | 15 +++++++++ .../test_simple_structs__mixed_enum-4.snap | 7 +++++ .../test_simple_structs__mixed_enum.snap | 9 ++++++ borsh/tests/test_simple_structs.rs | 27 ++++++++++++++++ 7 files changed, 112 insertions(+) create mode 100644 borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap create mode 100644 borsh/tests/snapshots/test_simple_structs__mixed_enum-2.snap create mode 100644 borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap create mode 100644 borsh/tests/snapshots/test_simple_structs__mixed_enum-4.snap create mode 100644 borsh/tests/snapshots/test_simple_structs__mixed_enum.snap diff --git a/borsh-derive/src/internals/serialize/enums/mod.rs b/borsh-derive/src/internals/serialize/enums/mod.rs index 4d7114407..0499d8fce 100644 --- a/borsh-derive/src/internals/serialize/enums/mod.rs +++ b/borsh-derive/src/internals/serialize/enums/mod.rs @@ -425,4 +425,20 @@ mod tests { local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); } + + #[test] + fn mixed_with_unit_variants() { + let item_enum: ItemEnum = syn::parse2(quote! 
{ + enum X { + A(u16), + B, + C {x: i32, y: i32}, + D, + } + }) + .unwrap(); + let actual = process(&item_enum, default_cratename()).unwrap(); + + local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap()); + } } diff --git a/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap b/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap new file mode 100644 index 000000000..2f3682d61 --- /dev/null +++ b/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap @@ -0,0 +1,31 @@ +--- +source: borsh-derive/src/internals/serialize/enums/mod.rs +expression: pretty_print_syn_str(&actual).unwrap() +--- +impl borsh::ser::BorshSerialize for X { + fn serialize( + &self, + writer: &mut W, + ) -> ::core::result::Result<(), borsh::io::Error> { + let variant_idx: u8 = match self { + X::A(..) => 0u8, + X::B => 1u8, + X::C { .. } => 2u8, + X::D => 3u8, + }; + writer.write_all(&variant_idx.to_le_bytes())?; + match self { + X::A(id0) => { + borsh::BorshSerialize::serialize(id0, writer)?; + } + X::B => {} + X::C { x, y, .. 
} => { + borsh::BorshSerialize::serialize(x, writer)?; + borsh::BorshSerialize::serialize(y, writer)?; + } + X::D => {} + } + Ok(()) + } +} + diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum-2.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum-2.snap new file mode 100644 index 000000000..e70b0c847 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum-2.snap @@ -0,0 +1,7 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 1, +] diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap new file mode 100644 index 000000000..121d21f21 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum-3.snap @@ -0,0 +1,15 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 2, + 132, + 0, + 0, + 0, + 239, + 255, + 255, + 255, +] diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum-4.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum-4.snap new file mode 100644 index 000000000..130ef3f52 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum-4.snap @@ -0,0 +1,7 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 3, +] diff --git a/borsh/tests/snapshots/test_simple_structs__mixed_enum.snap b/borsh/tests/snapshots/test_simple_structs__mixed_enum.snap new file mode 100644 index 000000000..ac859bf65 --- /dev/null +++ b/borsh/tests/snapshots/test_simple_structs__mixed_enum.snap @@ -0,0 +1,9 @@ +--- +source: borsh/tests/test_simple_structs.rs +expression: encoded +--- +[ + 0, + 13, + 0, +] diff --git a/borsh/tests/test_simple_structs.rs b/borsh/tests/test_simple_structs.rs index 647bf3fe3..5385e2a8c 100644 --- a/borsh/tests/test_simple_structs.rs +++ b/borsh/tests/test_simple_structs.rs @@ -223,3 +223,30 @@ fn test_object_length() { assert_eq!(encoded_a_len, len_helper_result); } + +#[derive(BorshSerialize, 
BorshDeserialize, PartialEq, Debug)] +enum MixedWithUnitVariants { + A(u16), + B, + C { x: i32, y: i32 }, + D, +} + +#[test] +fn test_mixed_enum() { + let vars = vec![ + MixedWithUnitVariants::A(13), + MixedWithUnitVariants::B, + MixedWithUnitVariants::C { x: 132, y: -17 }, + MixedWithUnitVariants::D, + ]; + for variant in vars { + let encoded = to_vec(&variant).unwrap(); + #[cfg(feature = "std")] + insta::assert_debug_snapshot!(encoded); + + let decoded = from_slice::<MixedWithUnitVariants>(&encoded).unwrap(); + + assert_eq!(variant, decoded); + } +} From fd1d1172bb217db7bdb6b349d41c30365f17305b Mon Sep 17 00:00:00 2001 From: Mateusz Kowalczyk Date: Wed, 6 Dec 2023 11:09:20 +0900 Subject: [PATCH 2/5] Do not produce useless match cases for unit enums Given enums like these: ```rust enum Foo { A(u16), B, C(i32, i32), } enum Bar { A, B, C, } ``` the serialise derive produces something like this: ```rust // Recursive expansion of borsh::BorshSerialize macro // =================================================== impl borsh::ser::BorshSerialize for Foo { fn serialize( &self, writer: &mut W, ) -> ::core::result::Result<(), borsh::io::Error> { let variant_idx: u8 = match self { Foo::A(..) => 0u8, Foo::B => 1u8, Foo::C(..)
=> 2u8, }; writer.write_all(&variant_idx.to_le_bytes())?; match self { Foo::A(id0) => { borsh::BorshSerialize::serialize(id0, writer)?; } Foo::B => {} Foo::C(id0, id1) => { borsh::BorshSerialize::serialize(id0, writer)?; borsh::BorshSerialize::serialize(id1, writer)?; } } Ok(()) } } // Recursive expansion of borsh::BorshSerialize macro // =================================================== impl borsh::ser::BorshSerialize for Bar { fn serialize( &self, writer: &mut W, ) -> ::core::result::Result<(), borsh::io::Error> { let variant_idx: u8 = match self { Bar::A => 0u8, Bar::B => 1u8, Bar::C => 2u8, }; writer.write_all(&variant_idx.to_le_bytes())?; match self { Bar::A => {} Bar::B => {} Bar::C => {} } Ok(()) } } ``` Notably in `Bar` case, the whole `match self` is useless because there's nothing left to serialise. With this patch, the derives now look like this: ```rust // Recursive expansion of borsh::BorshSerialize macro // =================================================== impl borsh::ser::BorshSerialize for Foo { fn serialize( &self, writer: &mut W, ) -> ::core::result::Result<(), borsh::io::Error> { let variant_idx: u8 = match self { Foo::A(..) => 0u8, Foo::B => 1u8, Foo::C(..) => 2u8, }; writer.write_all(&variant_idx.to_le_bytes())?; match self { Foo::A(id0) => { borsh::BorshSerialize::serialize(id0, writer)?; } Foo::C(id0, id1) => { borsh::BorshSerialize::serialize(id0, writer)?; borsh::BorshSerialize::serialize(id1, writer)?; } _ => {} } Ok(()) } } // Recursive expansion of borsh::BorshSerialize macro // =================================================== impl borsh::ser::BorshSerialize for Bar { fn serialize( &self, writer: &mut W, ) -> ::core::result::Result<(), borsh::io::Error> { let variant_idx: u8 = match self { Bar::A => 0u8, Bar::B => 1u8, Bar::C => 2u8, }; writer.write_all(&variant_idx.to_le_bytes())?; Ok(()) } } ``` Notably, the whole `match self` is gone for `Bar`. For `Foo`, any unit field cases are now inside a `_ => {}` catch-all. 
What's the point? Well, it's just nice to produce less for the compiler to deal with, reducing upstream build times. Further, it makes it much nicer for anyone reading/copying the expanded macros. However, this patch does add some amount of code complexity, so it's up to the maintainers to decide if it's worth taking. --- .../src/internals/serialize/enums/mod.rs | 118 ++++++++++++------ .../snapshots/borsh_discriminant_false.snap | 8 -- .../snapshots/borsh_discriminant_true.snap | 8 -- 3 files changed, 79 insertions(+), 55 deletions(-) diff --git a/borsh-derive/src/internals/serialize/enums/mod.rs b/borsh-derive/src/internals/serialize/enums/mod.rs index 0499d8fce..c1ed0dc03 100644 --- a/borsh-derive/src/internals/serialize/enums/mod.rs +++ b/borsh-derive/src/internals/serialize/enums/mod.rs @@ -18,6 +18,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut fields_body = TokenStream2::new(); let use_discriminant = item::contains_use_discriminant(input)?; let discriminants = Discriminants::new(&input.variants); + let mut blank_variants = false; for (variant_idx, variant) in input.variants.iter().enumerate() { let variant_ident = &variant.ident; @@ -30,15 +31,42 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { &mut generics_output, )?; all_variants_idx_body.extend(variant_output.variant_idx_body); - let (variant_header, variant_body) = (variant_output.header, variant_output.body); - fields_body.extend(quote!( - #enum_ident::#variant_ident #variant_header => { - #variant_body - } - )) + match variant_output.body { + VariantBody::Blank => blank_variants = true, + VariantBody::Fields(VariantFields { header, body }) => fields_body.extend(quote!( + #enum_ident::#variant_ident #header => { + #body + } + )), + } } generics_output.extend(&mut where_clause, &cratename); + let fields_match = if fields_body.is_empty() { + // If we no variants with fields, there's nothing to match against. Just + // re-use the empty token stream.
+ fields_body + } else { + let unit_fields_catchall = if blank_variants { + // We had some variants with unit fields, create a catch-all for + // these to be used at the bottom. + quote!( + _ => {} + ) + } else { + TokenStream2::new() + }; + // Create a match that serialises all the fields for each non-unit + // variant and add a catch-all at the bottom if we do have unit + // variants. + quote!( + match self { + #fields_body + #unit_fields_catchall + } + ) + }; + Ok(quote! { impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause { fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), #cratename::io::Error> { @@ -47,29 +75,29 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { }; writer.write_all(&variant_idx.to_le_bytes())?; - match self { - #fields_body - } + #fields_match Ok(()) } } }) } -struct VariantOutput { +#[derive(Default)] +struct VariantFields { header: TokenStream2, body: TokenStream2, - variant_idx_body: TokenStream2, } -impl VariantOutput { - fn new() -> Self { - Self { - body: TokenStream2::new(), - header: TokenStream2::new(), - variant_idx_body: TokenStream2::new(), - } - } +enum VariantBody { + // No body variant, unit enum variant. 
+ Blank, + // Variant with body (fields) + Fields(VariantFields), +} + +struct VariantOutput { + body: VariantBody, + variant_idx_body: TokenStream2, } fn process_variant( @@ -80,36 +108,48 @@ fn process_variant( generics: &mut serialize::GenericsOutput, ) -> syn::Result { let variant_ident = &variant.ident; - let mut variant_output = VariantOutput::new(); - match &variant.fields { + let variant_output = match &variant.fields { Fields::Named(fields) => { + let mut variant_fields = VariantFields::default(); for field in &fields.named { let field_id = serialize::FieldId::Enum(field.ident.clone().unwrap()); - process_field(field, field_id, cratename, generics, &mut variant_output)?; + process_field(field, field_id, cratename, generics, &mut variant_fields)?; + } + let header = variant_fields.header; + VariantOutput { + body: VariantBody::Fields(VariantFields { + // `..` pattern matching works even if all fields were specified + header: quote! { { #header.. }}, + body: variant_fields.body, + }), + variant_idx_body: quote!( + #enum_ident::#variant_ident {..} => #discriminant_value, + ), } - let header = variant_output.header; - // `..` pattern matching works even if all fields were specified - variant_output.header = quote! { { #header.. }}; - variant_output.variant_idx_body = quote!( - #enum_ident::#variant_ident {..} => #discriminant_value, - ); } Fields::Unnamed(fields) => { + let mut variant_fields = VariantFields::default(); for (field_idx, field) in fields.unnamed.iter().enumerate() { let field_id = serialize::FieldId::new_enum_unnamed(field_idx)?; - process_field(field, field_id, cratename, generics, &mut variant_output)?; + process_field(field, field_id, cratename, generics, &mut variant_fields)?; + } + let header = variant_fields.header; + VariantOutput { + body: VariantBody::Fields(VariantFields { + header: quote! { ( #header )}, + body: variant_fields.body, + }), + variant_idx_body: quote!( + #enum_ident::#variant_ident(..) 
=> #discriminant_value, + ), } - let header = variant_output.header; - variant_output.header = quote! { ( #header )}; - variant_output.variant_idx_body = quote!( - #enum_ident::#variant_ident(..) => #discriminant_value, - ); } - Fields::Unit => { - variant_output.variant_idx_body = quote!( + Fields::Unit => VariantOutput { + body: VariantBody::Blank, + variant_idx_body: quote!( #enum_ident::#variant_ident => #discriminant_value, - ); - } + ), + }, }; Ok(variant_output) } @@ -119,7 +159,7 @@ fn process_field( field_id: serialize::FieldId, cratename: &Path, generics: &mut serialize::GenericsOutput, - output: &mut VariantOutput, + output: &mut VariantFields, ) -> syn::Result<()> { let parsed = field::Attributes::parse(&field.attrs)?; diff --git a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap index f2d9971f7..9fc0d1487 100644 --- a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap +++ b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_false.snap @@ -16,14 +16,6 @@ impl borsh::ser::BorshSerialize for X { X::F => 5u8, }; writer.write_all(&variant_idx.to_le_bytes())?; - match self { - X::A => {} - X::B => {} - X::C => {} - X::D => {} - X::E => {} - X::F => {} - } Ok(()) } } diff --git a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap index 7191ea5bb..75ae04424 100644 --- a/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap +++ b/borsh-derive/src/internals/serialize/enums/snapshots/borsh_discriminant_true.snap @@ -16,14 +16,6 @@ impl borsh::ser::BorshSerialize for X { X::F => 10 + 1, }; writer.write_all(&variant_idx.to_le_bytes())?; - match self { - X::A => {} - X::B => {} - X::C => {} - X::D => {} - X::E => {} - X::F => {} - } Ok(()) } } From 
0b7ecb05ba581f4876c89e3de625c4f91c50c36a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dj8yf0=CE=BCl?= Date: Wed, 6 Dec 2023 19:56:03 +0200 Subject: [PATCH 3/5] test: changed snapshot of expansion of mixed enum serialize derive --- .../serialize/enums/snapshots/mixed_with_unit_variants.snap | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap b/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap index 2f3682d61..6de338b4b 100644 --- a/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap +++ b/borsh-derive/src/internals/serialize/enums/snapshots/mixed_with_unit_variants.snap @@ -18,12 +18,11 @@ impl borsh::ser::BorshSerialize for X { X::A(id0) => { borsh::BorshSerialize::serialize(id0, writer)?; } - X::B => {} X::C { x, y, .. } => { borsh::BorshSerialize::serialize(x, writer)?; borsh::BorshSerialize::serialize(y, writer)?; } - X::D => {} + _ => {} } Ok(()) } From b43074331b431e9e4caa2a6c1ad0865d1efbe0c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dj8yf0=CE=BCl?= Date: Wed, 6 Dec 2023 20:37:43 +0200 Subject: [PATCH 4/5] chore: extract function `optimize_fields_body` --- .../src/internals/serialize/enums/mod.rs | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/borsh-derive/src/internals/serialize/enums/mod.rs b/borsh-derive/src/internals/serialize/enums/mod.rs index c1ed0dc03..149b6517a 100644 --- a/borsh-derive/src/internals/serialize/enums/mod.rs +++ b/borsh-derive/src/internals/serialize/enums/mod.rs @@ -18,7 +18,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { let mut fields_body = TokenStream2::new(); let use_discriminant = item::contains_use_discriminant(input)?; let discriminants = Discriminants::new(&input.variants); - let mut blank_variants = false; + let mut has_unit_variant = false; for (variant_idx, variant) in input.variants.iter().enumerate() { 
let variant_ident = &variant.ident; @@ -32,7 +32,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { )?; all_variants_idx_body.extend(variant_output.variant_idx_body); match variant_output.body { - VariantBody::Blank => blank_variants = true, + VariantBody::Unit => has_unit_variant = true, VariantBody::Fields(VariantFields { header, body }) => fields_body.extend(quote!( #enum_ident::#variant_ident #header => { #body @@ -40,14 +40,31 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { )), } } + let fields_body = optimize_fields_body(fields_body, has_unit_variant); generics_output.extend(&mut where_clause, &cratename); - let fields_match = if fields_body.is_empty() { + Ok(quote! { + impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause { + fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), #cratename::io::Error> { + let variant_idx: u8 = match self { + #all_variants_idx_body + }; + writer.write_all(&variant_idx.to_le_bytes())?; + + #fields_body + Ok(()) + } + } + }) +} + +fn optimize_fields_body(fields_body: TokenStream2, has_unit_variant: bool) -> TokenStream2 { + if fields_body.is_empty() { // If we no variants with fields, there's nothing to match against. Just // re-use the empty token stream. fields_body } else { - let unit_fields_catchall = if blank_variants { + let unit_fields_catchall = if has_unit_variant { // We had some variants with unit fields, create a catch-all for // these to be used at the bottom. quote!( @@ -65,21 +82,7 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result { #unit_fields_catchall } ) - }; - - Ok(quote! 
{ - impl #impl_generics #cratename::ser::BorshSerialize for #enum_ident #ty_generics #where_clause { - fn serialize(&self, writer: &mut W) -> ::core::result::Result<(), #cratename::io::Error> { - let variant_idx: u8 = match self { - #all_variants_idx_body - }; - writer.write_all(&variant_idx.to_le_bytes())?; - - #fields_match - Ok(()) - } - } - }) + } } #[derive(Default)] @@ -90,7 +93,7 @@ struct VariantFields { enum VariantBody { // No body variant, unit enum variant. - Blank, + Unit, // Variant with body (fields) Fields(VariantFields), } @@ -145,7 +148,7 @@ fn process_variant( } } Fields::Unit => VariantOutput { - body: VariantBody::Blank, + body: VariantBody::Unit, variant_idx_body: quote!( #enum_ident::#variant_ident => #discriminant_value, ), From 015d190e1921452ac94160ac1fef9f29a7483899 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dj8yf0=CE=BCl?= Date: Wed, 6 Dec 2023 23:15:50 +0200 Subject: [PATCH 5/5] chore: add methods on introduced `VariantFields` --- .../src/internals/serialize/enums/mod.rs | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/borsh-derive/src/internals/serialize/enums/mod.rs b/borsh-derive/src/internals/serialize/enums/mod.rs index 149b6517a..923f1ee62 100644 --- a/borsh-derive/src/internals/serialize/enums/mod.rs +++ b/borsh-derive/src/internals/serialize/enums/mod.rs @@ -91,6 +91,26 @@ struct VariantFields { body: TokenStream2, } +impl VariantFields { + fn named_header(self) -> Self { + let header = self.header; + + VariantFields { + // `..` pattern matching works even if all fields were specified + header: quote! { { #header.. }}, + body: self.body, + } + } + fn unnamed_header(self) -> Self { + let header = self.header; + + VariantFields { + header: quote! { ( #header )}, + body: self.body, + } + } +} + enum VariantBody { // No body variant, unit enum variant. 
Unit, @@ -118,13 +138,8 @@ fn process_variant( let field_id = serialize::FieldId::Enum(field.ident.clone().unwrap()); process_field(field, field_id, cratename, generics, &mut variant_fields)?; } - let header = variant_fields.header; VariantOutput { - body: VariantBody::Fields(VariantFields { - // `..` pattern matching works even if all fields were specified - header: quote! { { #header.. }}, - body: variant_fields.body, - }), + body: VariantBody::Fields(variant_fields.named_header()), variant_idx_body: quote!( #enum_ident::#variant_ident {..} => #discriminant_value, ), @@ -136,12 +151,8 @@ fn process_variant( let field_id = serialize::FieldId::new_enum_unnamed(field_idx)?; process_field(field, field_id, cratename, generics, &mut variant_fields)?; } - let header = variant_fields.header; VariantOutput { - body: VariantBody::Fields(VariantFields { - header: quote! { ( #header )}, - body: variant_fields.body, - }), + body: VariantBody::Fields(variant_fields.unnamed_header()), variant_idx_body: quote!( #enum_ident::#variant_ident(..) => #discriminant_value, ),