From 27e5e21d3950395a700d369952d317be85b69a62 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Fri, 10 May 2024 16:17:08 -0400 Subject: [PATCH 1/5] more thorough checks to see if we need to adapt (marshal+unmarshal) value to right descriptor --- internal/editions/editions.go | 80 ++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 21 deletions(-) diff --git a/internal/editions/editions.go b/internal/editions/editions.go index b1de386..560cb9c 100644 --- a/internal/editions/editions.go +++ b/internal/editions/editions.go @@ -91,6 +91,12 @@ func ResolveFeature( features = withFeatures.GetFeatures() } + // TODO: adaptFeatureSet is only looking at the first field. But if we needed to + // support an extension field inside a custom feature, we'd really need + // to check all fields. That gets particularly complicated if the traversal + // path of fields includes list and map values. Luckily, features are not + // supposed to be repeated and not supposed to themselves have extensions. + // So this should be fine, at least for now. msgRef, err := adaptFeatureSet(features, fields[0]) if err != nil { return protoreflect.Value{}, err @@ -255,32 +261,36 @@ func GetFeatureDefault(edition descriptorpb.Edition, container protoreflect.Mess func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescriptor) (protoreflect.Message, error) { msgRef := msg.ProtoReflect() if field.IsExtension() { - // Extensions can always be used directly with the feature set, even if - // field.ContainingMessage() != FeatureSetDescriptor. - if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 { - return msgRef, nil - } - // The field is not present, but the message has unrecognized values. So - // let's try to parse the unrecognized bytes, just in case they contain - // this extension. - temp := &descriptorpb.FeatureSet{} - unmarshaler := proto.UnmarshalOptions{ - AllowPartial: true, - Resolver: resolverForExtension{field}, - } - if err := unmarshaler.Unmarshal(msgRef.GetUnknown(), temp); err != nil { - return nil, fmt.Errorf("failed to parse unrecognized fields of FeatureSet: %w", err) + // Extensions can be used directly with the feature set, even if + // field.ContainingMessage() != FeatureSetDescriptor. But only if + // the value is either not a message or is a message with the + // right descriptor, i.e. val.Descriptor() == field.Message(). + if valueMatchesDescriptor(msgRef, field) { + if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 { + return msgRef, nil + } + // The field is not present, but the message has unrecognized values. So + // let's try to parse the unrecognized bytes, just in case they contain + // this extension. + temp := &descriptorpb.FeatureSet{} + unmarshaler := proto.UnmarshalOptions{ + AllowPartial: true, + Resolver: resolverForExtension{field}, + } + if err := unmarshaler.Unmarshal(msgRef.GetUnknown(), temp); err != nil { + return nil, fmt.Errorf("failed to parse unrecognized fields of FeatureSet: %w", err) + } + return temp.ProtoReflect(), nil } - return temp.ProtoReflect(), nil - } - - if field.ContainingMessage() == FeatureSetDescriptor { + } else if field.ContainingMessage() == FeatureSetDescriptor { // Known field, not dynamically generated. Can directly use with the feature set. return msgRef, nil } - // If we get here, we have a dynamic field descriptor. We want to copy its - // value into a dynamic message, which requires marshalling/unmarshalling. + // If we get here, we have a dynamic field descriptor or an extension + // descriptor whose message type does not match the descriptor of the + // stored value. We need to copy its value into a dynamic message, + // which requires marshalling/unmarshalling. msgField := FeatureSetDescriptor.Fields().ByNumber(field.Number()) // We only need to copy over the unrecognized bytes (if any) // and the same field (if present). @@ -354,3 +364,31 @@ func computeSupportedEditions(min, max descriptorpb.Edition) map[string]descript } return supportedEditions } + +func valueMatchesDescriptor(msg protoreflect.Message, field protoreflect.FieldDescriptor) bool { + if !msg.Has(field) || field.Message() == nil { + // nothing to match + return true + } + val := msg.Get(field) + switch { + case field.IsMap(): + if expectedDescriptor := field.MapValue().Message(); expectedDescriptor != nil { + // We know msg.Has(field) is true, from above, so there's at least one entry. + var matches bool + val.Map().Range(func(_ protoreflect.MapKey, val protoreflect.Value) bool { + matches = val.Message().Descriptor() == expectedDescriptor + return false + }) + return matches + } + return true // nothing to match + case field.IsList(): + // We know msg.Has(field) is true, from above, so there's at least one entry. + return val.List().Get(0).Message().Descriptor() == field.Message() + case !field.IsMap(): + return val.Message().Descriptor() == field.Message() + default: + return true + } +} From 4f70e7e59a2f8fb41a66dc915874644ca378b993 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Mon, 20 May 2024 13:00:41 -0400 Subject: [PATCH 2/5] add new test to exercise new flow --- protoutil/editions_test.go | 139 +++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/protoutil/editions_test.go b/protoutil/editions_test.go index 4c20548..f36f420 100644 --- a/protoutil/editions_test.go +++ b/protoutil/editions_test.go @@ -25,6 +25,7 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/dynamicpb" + "google.golang.org/protobuf/types/gofeaturespb" "github.com/bufbuild/protocompile" "github.com/bufbuild/protocompile/internal/editions" @@ -377,6 +378,7 @@ func TestResolveCustomFeature(t *testing.T) { }), } file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto) + // First we resolve the feature with the given file. // Then we'll do a second pass where we resolve the // feature, but all extensions are unrecognized. Both // ways should work. @@ -414,6 +416,143 @@ func TestResolveCustomFeature(t *testing.T) { }) } +func TestResolveCustomFeature_Generated(t *testing.T) { + t.Parallel() + descriptorProto := protodesc.ToFileDescriptorProto( + (*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile(), + ) + goFeaturesProto := protodesc.ToFileDescriptorProto( + (*gofeaturespb.GoFeatures)(nil).ProtoReflect().Descriptor().ParentFile(), + ) + + // We can do proto2 and proto3 in the same way since they + // can't override feature values. + preEditionsTestCases := []struct { + syntax string + expectedValue bool + }{ + { + syntax: "proto2", + expectedValue: true, + }, + { + syntax: "proto3", + expectedValue: false, + }, + } + for _, testCase := range preEditionsTestCases { + testCase := testCase + t.Run(testCase.syntax, func(t *testing.T) { + t.Parallel() + sourceResolver := &protocompile.SourceResolver{ + Accessor: protocompile.SourceAccessorFromMap(map[string]string{ + "test.proto": ` + syntax = "` + testCase.syntax + `"; + import "google/protobuf/go_features.proto"; + enum Foo { + ZERO = 0; + }`, + }), + } + file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto, goFeaturesProto) + // First we resolve the feature with the given file. + // Then we'll do a second pass where we resolve the + // feature, but all extensions are unrecognized. Both + // ways should work. + for _, clearKnownExts := range []bool{false, true} { + if clearKnownExts { + clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file)) + } + + extType := gofeaturespb.E_Go + feature := gofeaturespb.E_Go.TypeDescriptor().Message().Fields().ByName("legacy_unmarshal_json_enum") + require.NotNil(t, feature) + + // Default for edition + val, err := protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedValue, val.Bool()) + + // Same value for an element therein + elem := file.FindDescriptorByName("Foo") + require.NotNil(t, elem) + val, err = protoutil.ResolveCustomFeature(elem, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedValue, val.Bool()) + } + }) + } + + editionsTestCases := []struct { + name string + source string + expectedFileValue bool + expectedEnumValue bool + }{ + { + name: "editions-2023-default", + source: ` + edition = "2023"; + import "google/protobuf/go_features.proto"; + enum Foo { + ZERO = 0; + }`, + expectedFileValue: false, + expectedEnumValue: false, + }, + { + name: "editions-override", + source: ` + edition = "2023"; + import "google/protobuf/go_features.proto"; + enum Foo { + option features.(pb.go).legacy_unmarshal_json_enum = true; + ZERO = 0; + }`, + expectedFileValue: false, + expectedEnumValue: true, + }, + } + + for _, testCase := range editionsTestCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + sourceResolver := &protocompile.SourceResolver{ + Accessor: protocompile.SourceAccessorFromMap(map[string]string{ + "test.proto": testCase.source, + }), + } + file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto, goFeaturesProto) + // First we resolve the feature with the given file. + // Then we'll do a second pass where we resolve the + // feature, but all extensions are unrecognized. Both + // ways should work. + for _, clearKnownExts := range []bool{false, true} { + if clearKnownExts { + clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file)) + } + + extType := gofeaturespb.E_Go + feature := gofeaturespb.E_Go.TypeDescriptor().Message().Fields().ByName("legacy_unmarshal_json_enum") + require.NotNil(t, feature) + + val, err := protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedFileValue, val.Bool()) + + // Override + elem := file.FindDescriptorByName("Foo") + require.NotNil(t, elem) + val, err = protoutil.ResolveCustomFeature(elem, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedEnumValue, val.Bool()) + } + }) + } +} + func compileFile( t *testing.T, filename string, From 45da2c066a7fc301328c95565b4c37c2620272a5 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Mon, 20 May 2024 12:38:17 -0400 Subject: [PATCH 3/5] fix logic bugs in previous commit --- internal/editions/editions.go | 80 +++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/internal/editions/editions.go b/internal/editions/editions.go index 560cb9c..9a1526f 100644 --- a/internal/editions/editions.go +++ b/internal/editions/editions.go @@ -260,12 +260,13 @@ func GetFeatureDefault(edition descriptorpb.Edition, container protoreflect.Mess func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescriptor) (protoreflect.Message, error) { msgRef := msg.ProtoReflect() + var actualField protoreflect.FieldDescriptor if field.IsExtension() { // Extensions can be used directly with the feature set, even if // field.ContainingMessage() != FeatureSetDescriptor. But only if // the value is either not a message or is a message with the // right descriptor, i.e. val.Descriptor() == field.Message(). - if valueMatchesDescriptor(msgRef, field) { + if actualField = actualDescriptor(msgRef, field); actualField == nil || actualField == field { if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 { return msgRef, nil } @@ -285,24 +286,26 @@ func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescr } else if field.ContainingMessage() == FeatureSetDescriptor { // Known field, not dynamically generated. Can directly use with the feature set. return msgRef, nil + } else { + actualField = FeatureSetDescriptor.Fields().ByNumber(field.Number()) + } // If we get here, we have a dynamic field descriptor or an extension // descriptor whose message type does not match the descriptor of the // stored value. We need to copy its value into a dynamic message, // which requires marshalling/unmarshalling. - msgField := FeatureSetDescriptor.Fields().ByNumber(field.Number()) // We only need to copy over the unrecognized bytes (if any) // and the same field (if present). data := msgRef.GetUnknown() - if msgField != nil && msgRef.Has(msgField) { + if actualField != nil && msgRef.Has(actualField) { subset := &descriptorpb.FeatureSet{} - subset.ProtoReflect().Set(msgField, msgRef.Get(msgField)) - fieldBytes, err := proto.MarshalOptions{AllowPartial: true}.Marshal(subset) + subset.ProtoReflect().Set(actualField, msgRef.Get(actualField)) + var err error + data, err = proto.MarshalOptions{AllowPartial: true}.MarshalAppend(data, subset) if err != nil { return nil, fmt.Errorf("failed to marshal FeatureSet field %s to bytes: %w", field.Name(), err) } - data = append(data, fieldBytes...) } if len(data) == 0 { // No relevant data to copy over, so we can just return @@ -365,30 +368,53 @@ func computeSupportedEditions(min, max descriptorpb.Edition) map[string]descript return supportedEditions } -func valueMatchesDescriptor(msg protoreflect.Message, field protoreflect.FieldDescriptor) bool { - if !msg.Has(field) || field.Message() == nil { - // nothing to match - return true +// actualDescriptor returns the actual field descriptor referenced by msg that +// corresponds to the given ext (i.e. same number). It returns nil if msg has +// no reference, if the actual descriptor is the same as ext, or if ext is +// otherwise safe to use as is. +func actualDescriptor(msg protoreflect.Message, ext protoreflect.ExtensionDescriptor) protoreflect.FieldDescriptor { + if !msg.Has(ext) || ext.Message() == nil { + // nothing to match; safe as is + return nil } - val := msg.Get(field) + val := msg.Get(ext) switch { - case field.IsMap(): - if expectedDescriptor := field.MapValue().Message(); expectedDescriptor != nil { - // We know msg.Has(field) is true, from above, so there's at least one entry. - var matches bool - val.Map().Range(func(_ protoreflect.MapKey, val protoreflect.Value) bool { - matches = val.Message().Descriptor() == expectedDescriptor - return false - }) - return matches + case ext.IsMap(): // should not actually be possible + expectedDescriptor := ext.MapValue().Message() + if expectedDescriptor == nil { + return nil // nothing to match } - return true // nothing to match - case field.IsList(): // We know msg.Has(field) is true, from above, so there's at least one entry. - return val.List().Get(0).Message().Descriptor() == field.Message() - case !field.IsMap(): - return val.Message().Descriptor() == field.Message() - default: - return true + var matches bool + val.Map().Range(func(_ protoreflect.MapKey, val protoreflect.Value) bool { + matches = val.Message().Descriptor() == expectedDescriptor + return false + }) + if matches { + return nil + } + case ext.IsList(): + // We know msg.Has(field) is true, from above, so there's at least one entry. + if val.List().Get(0).Message().Descriptor() == ext.Message() { + return nil + } + case !ext.IsMap(): + if val.Message().Descriptor() == ext.Message() { + return nil + } } + // The underlying message descriptors do not match. So we need to return + // the actual field descriptor. Sadly, protoreflect.Message provides no way + // to query the field descriptor in a message by number. For non-extensions, + // one can query the associated message descriptor. But for extensions, we + // have to do the slow thing, and range through all fields looking for it. + var actualField protoreflect.FieldDescriptor + msg.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { + if fd.Number() == ext.Number() { + actualField = fd + return false + } + return true + }) + return actualField } From 2a47c77d8cc56436dcf9bc597f092d4f84df0b61 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Mon, 20 May 2024 13:29:15 -0400 Subject: [PATCH 4/5] make linter happy --- internal/editions/editions.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/editions/editions.go b/internal/editions/editions.go index 9a1526f..850ac36 100644 --- a/internal/editions/editions.go +++ b/internal/editions/editions.go @@ -261,7 +261,8 @@ func GetFeatureDefault(edition descriptorpb.Edition, container protoreflect.Mess func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescriptor) (protoreflect.Message, error) { msgRef := msg.ProtoReflect() var actualField protoreflect.FieldDescriptor - if field.IsExtension() { + switch { + case field.IsExtension(): // Extensions can be used directly with the feature set, even if // field.ContainingMessage() != FeatureSetDescriptor. But only if // the value is either not a message or is a message with the @@ -283,12 +284,11 @@ func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescr } return temp.ProtoReflect(), nil } - } else if field.ContainingMessage() == FeatureSetDescriptor { + case field.ContainingMessage() == FeatureSetDescriptor: // Known field, not dynamically generated. Can directly use with the feature set. return msgRef, nil - } else { + default: actualField = FeatureSetDescriptor.Fields().ByNumber(field.Number()) - } // If we get here, we have a dynamic field descriptor or an extension From 01bcced5bae2896468ca0a07514db77ed2402317 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 22 May 2024 09:03:40 -0400 Subject: [PATCH 5/5] minor cleanup in test since file value is always false --- protoutil/editions_test.go | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/protoutil/editions_test.go b/protoutil/editions_test.go index f36f420..2024234 100644 --- a/protoutil/editions_test.go +++ b/protoutil/editions_test.go @@ -484,10 +484,9 @@ func TestResolveCustomFeature_Generated(t *testing.T) { } editionsTestCases := []struct { - name string - source string - expectedFileValue bool - expectedEnumValue bool + name string + source string + exopectedValue bool }{ { name: "editions-2023-default", @@ -497,8 +496,7 @@ func TestResolveCustomFeature_Generated(t *testing.T) { enum Foo { ZERO = 0; }`, - expectedFileValue: false, - expectedEnumValue: false, + exopectedValue: false, }, { name: "editions-override", @@ -509,8 +507,7 @@ func TestResolveCustomFeature_Generated(t *testing.T) { option features.(pb.go).legacy_unmarshal_json_enum = true; ZERO = 0; }`, - expectedFileValue: false, - expectedEnumValue: true, + exopectedValue: true, }, } @@ -540,14 +537,16 @@ func TestResolveCustomFeature_Generated(t *testing.T) { val, err := protoutil.ResolveCustomFeature(file, extType, feature) require.NoError(t, err) - require.Equal(t, testCase.expectedFileValue, val.Bool()) + // Edition default is false, and can't be overridden at the file level, + // so this should always be false. + require.False(t, val.Bool()) // Override elem := file.FindDescriptorByName("Foo") require.NotNil(t, elem) val, err = protoutil.ResolveCustomFeature(elem, extType, feature) require.NoError(t, err) - require.Equal(t, testCase.expectedEnumValue, val.Bool()) + require.Equal(t, testCase.exopectedValue, val.Bool()) } }) }