diff --git a/internal/editions/editions.go b/internal/editions/editions.go index 0897c8d..b1de386 100644 --- a/internal/editions/editions.go +++ b/internal/editions/editions.go @@ -264,7 +264,7 @@ func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescr // let's try to parse the unrecognized bytes, just in case they contain // this extension. temp := &descriptorpb.FeatureSet{} - unmarshaler := prototext.UnmarshalOptions{ + unmarshaler := proto.UnmarshalOptions{ AllowPartial: true, Resolver: resolverForExtension{field}, } diff --git a/protoutil/editions_test.go b/protoutil/editions_test.go index 7be4d54..4c20548 100644 --- a/protoutil/editions_test.go +++ b/protoutil/editions_test.go @@ -22,6 +22,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/dynamicpb" @@ -29,6 +30,7 @@ import ( "github.com/bufbuild/protocompile/internal/editions" "github.com/bufbuild/protocompile/linker" "github.com/bufbuild/protocompile/protoutil" + "github.com/bufbuild/protocompile/walk" ) func TestResolveFeature(t *testing.T) { @@ -318,30 +320,40 @@ func TestResolveCustomFeature(t *testing.T) { }), } file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto) - optionsFile := file.FindImportByPath("options.proto") - extType := dynamicpb.NewExtensionType(optionsFile.FindDescriptorByName("test.custom").(protoreflect.ExtensionDescriptor)) - feature := optionsFile.FindDescriptorByName("test.CustomFeatures.encabulate").(protoreflect.FieldDescriptor) //nolint:errcheck - - val, err := protoutil.ResolveCustomFeature(file, extType, feature) - require.NoError(t, err) - require.Equal(t, testCase.expectedEncabulate, val.Bool()) - - // Same value for an element therein - elem := file.FindDescriptorByName("Foo") - require.NotNil(t, elem) - val, err = protoutil.ResolveCustomFeature(elem, extType, feature) - require.NoError(t, err) - require.Equal(t, testCase.expectedEncabulate, val.Bool()) - - // Check the other feature field, too - feature = optionsFile.FindDescriptorByName("test.CustomFeatures.nitz").(protoreflect.FieldDescriptor) //nolint:errcheck - val, err = protoutil.ResolveCustomFeature(file, extType, feature) - require.NoError(t, err) - require.Equal(t, protoreflect.EnumNumber(testCase.expectedNitz), val.Enum()) + // First we resolve the feature with the given file. + // Then we'll do a second pass where we resolve the + // feature, but all extensions are unrecognized. Both + // ways should work. + for _, clearKnownExts := range []bool{false, true} { + if clearKnownExts { + clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file)) + } - val, err = protoutil.ResolveCustomFeature(elem, extType, feature) - require.NoError(t, err) - require.Equal(t, protoreflect.EnumNumber(testCase.expectedNitz), val.Enum()) + optionsFile := file.FindImportByPath("options.proto") + extType := dynamicpb.NewExtensionType(optionsFile.FindDescriptorByName("test.custom").(protoreflect.ExtensionDescriptor)) + feature := optionsFile.FindDescriptorByName("test.CustomFeatures.encabulate").(protoreflect.FieldDescriptor) //nolint:errcheck + + val, err := protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedEncabulate, val.Bool()) + + // Same value for an element therein + elem := file.FindDescriptorByName("Foo") + require.NotNil(t, elem) + val, err = protoutil.ResolveCustomFeature(elem, extType, feature) + require.NoError(t, err) + require.Equal(t, testCase.expectedEncabulate, val.Bool()) + + // Check the other feature field, too + feature = optionsFile.FindDescriptorByName("test.CustomFeatures.nitz").(protoreflect.FieldDescriptor) //nolint:errcheck + val, err = protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(testCase.expectedNitz), val.Enum()) + + val, err = protoutil.ResolveCustomFeature(elem, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(testCase.expectedNitz), val.Enum()) + } }) } @@ -365,35 +377,49 @@ func TestResolveCustomFeature(t *testing.T) { }), } file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto) - optionsFile := file.FindImportByPath("options.proto") - extType := dynamicpb.NewExtensionType(optionsFile.FindDescriptorByName("test.custom").(protoreflect.ExtensionDescriptor)) - feature := optionsFile.FindDescriptorByName("test.CustomFeatures.encabulate").(protoreflect.FieldDescriptor) //nolint:errcheck + // Then we'll do a second pass where we resolve the + // feature, but all extensions are unrecognized. Both + // ways should work. + for _, clearKnownExts := range []bool{false, true} { + if clearKnownExts { + clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file)) + } - val, err := protoutil.ResolveCustomFeature(file, extType, feature) - require.NoError(t, err) - // Default for edition - require.False(t, val.Bool()) + optionsFile := file.FindImportByPath("options.proto") + extType := dynamicpb.NewExtensionType(optionsFile.FindDescriptorByName("test.custom").(protoreflect.ExtensionDescriptor)) + feature := optionsFile.FindDescriptorByName("test.CustomFeatures.encabulate").(protoreflect.FieldDescriptor) //nolint:errcheck - // Override - field := file.FindDescriptorByName("Bar.name") - require.NotNil(t, field) - val, err = protoutil.ResolveCustomFeature(field, extType, feature) - require.NoError(t, err) - require.True(t, val.Bool()) + val, err := protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + // Default for edition + require.False(t, val.Bool()) - // Check the other feature field, too - feature = optionsFile.FindDescriptorByName("test.CustomFeatures.nitz").(protoreflect.FieldDescriptor) //nolint:errcheck - val, err = protoutil.ResolveCustomFeature(file, extType, feature) - require.NoError(t, err) - require.Equal(t, protoreflect.EnumNumber(3), val.Enum()) + // Override + field := file.FindDescriptorByName("Bar.name") + require.NotNil(t, field) + val, err = protoutil.ResolveCustomFeature(field, extType, feature) + require.NoError(t, err) + require.True(t, val.Bool()) - val, err = protoutil.ResolveCustomFeature(field, extType, feature) - require.NoError(t, err) - require.Equal(t, protoreflect.EnumNumber(2), val.Enum()) + // Check the other feature field, too + feature = optionsFile.FindDescriptorByName("test.CustomFeatures.nitz").(protoreflect.FieldDescriptor) //nolint:errcheck + val, err = protoutil.ResolveCustomFeature(file, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(3), val.Enum()) + + val, err = protoutil.ResolveCustomFeature(field, extType, feature) + require.NoError(t, err) + require.Equal(t, protoreflect.EnumNumber(2), val.Enum()) + } }) } -func compileFile(t *testing.T, filename string, sources *protocompile.SourceResolver, deps ...*descriptorpb.FileDescriptorProto) (result linker.File, featureSet protoreflect.MessageDescriptor) { +func compileFile( + t *testing.T, + filename string, + sources *protocompile.SourceResolver, + deps ...*descriptorpb.FileDescriptorProto, +) (result linker.File, featureSet protoreflect.MessageDescriptor) { t.Helper() if sources == nil { sources = &protocompile.SourceResolver{ @@ -441,3 +467,44 @@ func addDepsToResolver(resolver protocompile.Resolver, deps ...*descriptorpb.Fil return resolver.FindFileByPath(path) }) } + +func clearKnownExtensionsFromFile(t *testing.T, file *descriptorpb.FileDescriptorProto) { + t.Helper() + clearKnownExtensionsFromOptions(t, file.GetOptions()) + err := walk.DescriptorProtos(file, func(name protoreflect.FullName, element proto.Message) error { + switch element := element.(type) { + case *descriptorpb.DescriptorProto: + clearKnownExtensionsFromOptions(t, element.GetOptions()) + for _, extRange := range element.GetExtensionRange() { + clearKnownExtensionsFromOptions(t, extRange.GetOptions()) + } + case *descriptorpb.FieldDescriptorProto: + clearKnownExtensionsFromOptions(t, element.GetOptions()) + case *descriptorpb.OneofDescriptorProto: + clearKnownExtensionsFromOptions(t, element.GetOptions()) + case *descriptorpb.EnumDescriptorProto: + clearKnownExtensionsFromOptions(t, element.GetOptions()) + case *descriptorpb.EnumValueDescriptorProto: + clearKnownExtensionsFromOptions(t, element.GetOptions()) + case *descriptorpb.ServiceDescriptorProto: + clearKnownExtensionsFromOptions(t, element.GetOptions()) + case *descriptorpb.MethodDescriptorProto: + clearKnownExtensionsFromOptions(t, element.GetOptions()) + } + return nil + }) + require.NoError(t, err) +} + +func clearKnownExtensionsFromOptions(t *testing.T, options proto.Message) { + t.Helper() + if options == nil || !options.ProtoReflect().IsValid() { + return // nothing to do + } + data, err := proto.Marshal(options) + require.NoError(t, err) + // We unmarshal from bytes, with a nil resolver, so all extensions + // will remain unrecognized. + err = proto.UnmarshalOptions{Resolver: (*protoregistry.Types)(nil)}.Unmarshal(data, options) + require.NoError(t, err) +}