diff --git a/internal/editions/editions.go b/internal/editions/editions.go
index b1de386..850ac36 100644
--- a/internal/editions/editions.go
+++ b/internal/editions/editions.go
@@ -91,6 +91,12 @@ func ResolveFeature(
 			features = withFeatures.GetFeatures()
 		}
 
+		// TODO: adaptFeatureSet is only looking at the first field. But if we needed to
+		// support an extension field inside a custom feature, we'd really need
+		// to check all fields. That gets particularly complicated if the traversal
+		// path of fields includes list and map values. Luckily, features are not
+		// supposed to be repeated and not supposed to themselves have extensions.
+		// So this should be fine, at least for now.
 		msgRef, err := adaptFeatureSet(features, fields[0])
 		if err != nil {
 			return protoreflect.Value{}, err
@@ -254,45 +260,52 @@ func GetFeatureDefault(edition descriptorpb.Edition, container protoreflect.Mess
 
 func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescriptor) (protoreflect.Message, error) {
 	msgRef := msg.ProtoReflect()
-	if field.IsExtension() {
-		// Extensions can always be used directly with the feature set, even if
-		// field.ContainingMessage() != FeatureSetDescriptor.
-		if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 {
-			return msgRef, nil
-		}
-		// The field is not present, but the message has unrecognized values. So
-		// let's try to parse the unrecognized bytes, just in case they contain
-		// this extension.
-		temp := &descriptorpb.FeatureSet{}
-		unmarshaler := proto.UnmarshalOptions{
-			AllowPartial: true,
-			Resolver:     resolverForExtension{field},
-		}
-		if err := unmarshaler.Unmarshal(msgRef.GetUnknown(), temp); err != nil {
-			return nil, fmt.Errorf("failed to parse unrecognized fields of FeatureSet: %w", err)
+	var actualField protoreflect.FieldDescriptor
+	switch {
+	case field.IsExtension():
+		// Extensions can be used directly with the feature set, even if
+		// field.ContainingMessage() != FeatureSetDescriptor. But only if
+		// the value is either not a message or is a message with the
+		// right descriptor, i.e. val.Descriptor() == field.Message().
+		if actualField = actualDescriptor(msgRef, field); actualField == nil || actualField == field {
+			if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 {
+				return msgRef, nil
+			}
+			// The field is not present, but the message has unrecognized values. So
+			// let's try to parse the unrecognized bytes, just in case they contain
+			// this extension.
+			temp := &descriptorpb.FeatureSet{}
+			unmarshaler := proto.UnmarshalOptions{
+				AllowPartial: true,
+				Resolver:     resolverForExtension{field},
+			}
+			if err := unmarshaler.Unmarshal(msgRef.GetUnknown(), temp); err != nil {
+				return nil, fmt.Errorf("failed to parse unrecognized fields of FeatureSet: %w", err)
+			}
+			return temp.ProtoReflect(), nil
 		}
-		return temp.ProtoReflect(), nil
-	}
-
-	if field.ContainingMessage() == FeatureSetDescriptor {
+	case field.ContainingMessage() == FeatureSetDescriptor:
 		// Known field, not dynamically generated. Can directly use with the feature set.
 		return msgRef, nil
+	default:
+		actualField = FeatureSetDescriptor.Fields().ByNumber(field.Number())
 	}
 
-	// If we get here, we have a dynamic field descriptor. We want to copy its
-	// value into a dynamic message, which requires marshalling/unmarshalling.
-	msgField := FeatureSetDescriptor.Fields().ByNumber(field.Number())
+	// If we get here, we have a dynamic field descriptor or an extension
+	// descriptor whose message type does not match the descriptor of the
+	// stored value. We need to copy its value into a dynamic message,
+	// which requires marshalling/unmarshalling.
 	// We only need to copy over the unrecognized bytes (if any)
 	// and the same field (if present).
 	data := msgRef.GetUnknown()
-	if msgField != nil && msgRef.Has(msgField) {
+	if actualField != nil && msgRef.Has(actualField) {
 		subset := &descriptorpb.FeatureSet{}
-		subset.ProtoReflect().Set(msgField, msgRef.Get(msgField))
-		fieldBytes, err := proto.MarshalOptions{AllowPartial: true}.Marshal(subset)
+		subset.ProtoReflect().Set(actualField, msgRef.Get(actualField))
+		var err error
+		data, err = proto.MarshalOptions{AllowPartial: true}.MarshalAppend(data, subset)
 		if err != nil {
 			return nil, fmt.Errorf("failed to marshal FeatureSet field %s to bytes: %w", field.Name(), err)
 		}
-		data = append(data, fieldBytes...)
 	}
 	if len(data) == 0 {
 		// No relevant data to copy over, so we can just return
@@ -354,3 +367,54 @@ func computeSupportedEditions(min, max descriptorpb.Edition) map[string]descript
 	}
 	return supportedEditions
 }
+
+// actualDescriptor returns the actual field descriptor referenced by msg that
+// corresponds to the given ext (i.e. same number). It returns nil if msg has
+// no reference, if the actual descriptor is the same as ext, or if ext is
+// otherwise safe to use as is.
+func actualDescriptor(msg protoreflect.Message, ext protoreflect.ExtensionDescriptor) protoreflect.FieldDescriptor {
+	if !msg.Has(ext) || ext.Message() == nil {
+		// nothing to match; safe as is
+		return nil
+	}
+	val := msg.Get(ext)
+	switch {
+	case ext.IsMap(): // should not actually be possible
+		expectedDescriptor := ext.MapValue().Message()
+		if expectedDescriptor == nil {
+			return nil // nothing to match
+		}
+		// We know msg.Has(field) is true, from above, so there's at least one entry.
+		var matches bool
+		val.Map().Range(func(_ protoreflect.MapKey, val protoreflect.Value) bool {
+			matches = val.Message().Descriptor() == expectedDescriptor
+			return false
+		})
+		if matches {
+			return nil
+		}
+	case ext.IsList():
+		// We know msg.Has(field) is true, from above, so there's at least one entry.
+		if val.List().Get(0).Message().Descriptor() == ext.Message() {
+			return nil
+		}
+	case !ext.IsMap():
+		if val.Message().Descriptor() == ext.Message() {
+			return nil
+		}
+	}
+	// The underlying message descriptors do not match. So we need to return
+	// the actual field descriptor. Sadly, protoreflect.Message provides no way
+	// to query the field descriptor in a message by number. For non-extensions,
+	// one can query the associated message descriptor. But for extensions, we
+	// have to do the slow thing, and range through all fields looking for it.
+	var actualField protoreflect.FieldDescriptor
+	msg.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
+		if fd.Number() == ext.Number() {
+			actualField = fd
+			return false
+		}
+		return true
+	})
+	return actualField
+}
diff --git a/protoutil/editions_test.go b/protoutil/editions_test.go
index 4c20548..2024234 100644
--- a/protoutil/editions_test.go
+++ b/protoutil/editions_test.go
@@ -25,6 +25,7 @@ import (
 	"google.golang.org/protobuf/reflect/protoregistry"
 	"google.golang.org/protobuf/types/descriptorpb"
 	"google.golang.org/protobuf/types/dynamicpb"
+	"google.golang.org/protobuf/types/gofeaturespb"
 
 	"github.com/bufbuild/protocompile"
 	"github.com/bufbuild/protocompile/internal/editions"
@@ -377,6 +378,7 @@ func TestResolveCustomFeature(t *testing.T) {
 		}),
 	}
 	file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto)
+	// First we resolve the feature with the given file.
 	// Then we'll do a second pass where we resolve the
 	// feature, but all extensions are unrecognized. Both
 	// ways should work.
@@ -414,6 +416,142 @@ func TestResolveCustomFeature(t *testing.T) {
 	})
 }
 
+func TestResolveCustomFeature_Generated(t *testing.T) {
+	t.Parallel()
+	descriptorProto := protodesc.ToFileDescriptorProto(
+		(*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile(),
+	)
+	goFeaturesProto := protodesc.ToFileDescriptorProto(
+		(*gofeaturespb.GoFeatures)(nil).ProtoReflect().Descriptor().ParentFile(),
+	)
+
+	// We can do proto2 and proto3 in the same way since they
+	// can't override feature values.
+	preEditionsTestCases := []struct {
+		syntax        string
+		expectedValue bool
+	}{
+		{
+			syntax:        "proto2",
+			expectedValue: true,
+		},
+		{
+			syntax:        "proto3",
+			expectedValue: false,
+		},
+	}
+	for _, testCase := range preEditionsTestCases {
+		testCase := testCase
+		t.Run(testCase.syntax, func(t *testing.T) {
+			t.Parallel()
+			sourceResolver := &protocompile.SourceResolver{
+				Accessor: protocompile.SourceAccessorFromMap(map[string]string{
+					"test.proto": `
+						syntax = "` + testCase.syntax + `";
+						import "google/protobuf/go_features.proto";
+						enum Foo {
+							ZERO = 0;
+						}`,
+				}),
+			}
+			file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto, goFeaturesProto)
+			// First we resolve the feature with the given file.
+			// Then we'll do a second pass where we resolve the
+			// feature, but all extensions are unrecognized. Both
+			// ways should work.
+			for _, clearKnownExts := range []bool{false, true} {
+				if clearKnownExts {
+					clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file))
+				}
+
+				extType := gofeaturespb.E_Go
+				feature := gofeaturespb.E_Go.TypeDescriptor().Message().Fields().ByName("legacy_unmarshal_json_enum")
+				require.NotNil(t, feature)
+
+				// Default for edition
+				val, err := protoutil.ResolveCustomFeature(file, extType, feature)
+				require.NoError(t, err)
+				require.Equal(t, testCase.expectedValue, val.Bool())
+
+				// Same value for an element therein
+				elem := file.FindDescriptorByName("Foo")
+				require.NotNil(t, elem)
+				val, err = protoutil.ResolveCustomFeature(elem, extType, feature)
+				require.NoError(t, err)
+				require.Equal(t, testCase.expectedValue, val.Bool())
+			}
+		})
+	}
+
+	editionsTestCases := []struct {
+		name          string
+		source        string
+		expectedValue bool
+	}{
+		{
+			name: "editions-2023-default",
+			source: `
+				edition = "2023";
+				import "google/protobuf/go_features.proto";
+				enum Foo {
+					ZERO = 0;
+				}`,
+			expectedValue: false,
+		},
+		{
+			name: "editions-override",
+			source: `
+				edition = "2023";
+				import "google/protobuf/go_features.proto";
+				enum Foo {
+					option features.(pb.go).legacy_unmarshal_json_enum = true;
+					ZERO = 0;
+				}`,
+			expectedValue: true,
+		},
+	}
+
+	for _, testCase := range editionsTestCases {
+		testCase := testCase
+		t.Run(testCase.name, func(t *testing.T) {
+			t.Parallel()
+
+			sourceResolver := &protocompile.SourceResolver{
+				Accessor: protocompile.SourceAccessorFromMap(map[string]string{
+					"test.proto": testCase.source,
+				}),
+			}
+			file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto, goFeaturesProto)
+			// First we resolve the feature with the given file.
+			// Then we'll do a second pass where we resolve the
+			// feature, but all extensions are unrecognized. Both
+			// ways should work.
+			for _, clearKnownExts := range []bool{false, true} {
+				if clearKnownExts {
+					clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file))
+				}
+
+				extType := gofeaturespb.E_Go
+				feature := gofeaturespb.E_Go.TypeDescriptor().Message().Fields().ByName("legacy_unmarshal_json_enum")
+				require.NotNil(t, feature)
+
+				val, err := protoutil.ResolveCustomFeature(file, extType, feature)
+				require.NoError(t, err)
+				// Edition default is false, and can't be overridden at the file level,
+				// so this should always be false.
+				require.False(t, val.Bool())
+
+				// Override
+				elem := file.FindDescriptorByName("Foo")
+				require.NotNil(t, elem)
+				val, err = protoutil.ResolveCustomFeature(elem, extType, feature)
+				require.NoError(t, err)
+				require.Equal(t, testCase.expectedValue, val.Bool())
+			}
+		})
+	}
+}
+
 func compileFile(
 	t *testing.T,
 	filename string,
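
Why adaptFeatureSet can no longer trust an extension descriptor's own message type: when google/protobuf/go_features.proto is compiled from source (as the new tests do), the rebuilt descriptors carry the same full names as the generated gofeaturespb ones but are distinct objects, so a value stored under the rebuilt descriptor no longer satisfies the val.Descriptor() == field.Message() condition mentioned in the adaptFeatureSet comment. Below is a minimal standalone sketch of that identity mismatch, illustrative only and not part of the diff above; it assumes protodesc.NewFile accepts the editions-based go_features.proto, and the variable names are made up for the example.

package main

import (
	"fmt"

	"google.golang.org/protobuf/reflect/protodesc"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/types/dynamicpb"
	"google.golang.org/protobuf/types/gofeaturespb"
)

func main() {
	// Round-trip go_features.proto through its FileDescriptorProto, the way a
	// compiler rebuilds it from source. The rebuilt file describes the same
	// schema, but its descriptors are not the same objects as the generated ones.
	goFeaturesFile := (*gofeaturespb.GoFeatures)(nil).ProtoReflect().Descriptor().ParentFile()
	rebuilt, err := protodesc.NewFile(protodesc.ToFileDescriptorProto(goFeaturesFile), protoregistry.GlobalFiles)
	if err != nil {
		panic(err)
	}
	rebuiltExt := dynamicpb.NewExtensionType(rebuilt.Extensions().ByName("go"))

	genMsg := gofeaturespb.E_Go.TypeDescriptor().Message()
	dynMsg := rebuiltExt.TypeDescriptor().Message()
	fmt.Println(genMsg.FullName() == dynMsg.FullName()) // true: same schema
	fmt.Println(genMsg == dynMsg)                       // false: different descriptor identity
}

Because the two descriptors differ only in identity, actualDescriptor ranges over the FeatureSet's populated fields to find the descriptor that actually holds the value, and adaptFeatureSet round-trips that value through bytes so it can be read back under the caller's descriptor.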
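
The caller-facing pattern that TestResolveCustomFeature_Generated exercises, sketched as a small helper. The package and helper names are hypothetical, and the ResolveCustomFeature signature is inferred from the test calls above, which pass both a file and an enum descriptor as the element.

package example

import (
	"fmt"

	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/gofeaturespb"

	"github.com/bufbuild/protocompile/protoutil"
)

// resolveLegacyJSONEnum reports whether the Go feature legacy_unmarshal_json_enum
// is in effect for the given element (a file, enum, field, etc.), taking overrides
// on the element and its parents as well as edition defaults into account.
func resolveLegacyJSONEnum(element protoreflect.Descriptor) (bool, error) {
	extType := gofeaturespb.E_Go
	feature := extType.TypeDescriptor().Message().Fields().ByName("legacy_unmarshal_json_enum")
	if feature == nil {
		return false, fmt.Errorf("gofeaturespb.GoFeatures has no legacy_unmarshal_json_enum field")
	}
	val, err := protoutil.ResolveCustomFeature(element, extType, feature)
	if err != nil {
		return false, err
	}
	return val.Bool(), nil
}

With the adaptFeatureSet change above, the same call should work whether the element's FeatureSet stores the Go feature under a dynamically rebuilt descriptor or only as unrecognized bytes, which is what the clearKnownExts loop in the tests toggles.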