Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more thorough checks for when to adapt a value when resolving custom feature #306

Merged
merged 5 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 91 additions & 27 deletions internal/editions/editions.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ func ResolveFeature(
features = withFeatures.GetFeatures()
}

// TODO: adaptFeatureSet is only looking at the first field. But if we needed to
// support an extension field inside a custom feature, we'd really need
// to check all fields. That gets particularly complicated if the traversal
// path of fields includes list and map values. Luckily, features are not
// supposed to be repeated and not supposed to themselves have extensions.
// So this should be fine, at least for now.
msgRef, err := adaptFeatureSet(features, fields[0])
if err != nil {
return protoreflect.Value{}, err
Expand Down Expand Up @@ -254,45 +260,52 @@ func GetFeatureDefault(edition descriptorpb.Edition, container protoreflect.Mess

func adaptFeatureSet(msg *descriptorpb.FeatureSet, field protoreflect.FieldDescriptor) (protoreflect.Message, error) {
msgRef := msg.ProtoReflect()
if field.IsExtension() {
// Extensions can always be used directly with the feature set, even if
// field.ContainingMessage() != FeatureSetDescriptor.
if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 {
return msgRef, nil
}
// The field is not present, but the message has unrecognized values. So
// let's try to parse the unrecognized bytes, just in case they contain
// this extension.
temp := &descriptorpb.FeatureSet{}
unmarshaler := proto.UnmarshalOptions{
AllowPartial: true,
Resolver: resolverForExtension{field},
}
if err := unmarshaler.Unmarshal(msgRef.GetUnknown(), temp); err != nil {
return nil, fmt.Errorf("failed to parse unrecognized fields of FeatureSet: %w", err)
var actualField protoreflect.FieldDescriptor
switch {
case field.IsExtension():
// Extensions can be used directly with the feature set, even if
// field.ContainingMessage() != FeatureSetDescriptor. But only if
// the value is either not a message or is a message with the
// right descriptor, i.e. val.Descriptor() == field.Message().
if actualField = actualDescriptor(msgRef, field); actualField == nil || actualField == field {
if msgRef.Has(field) || len(msgRef.GetUnknown()) == 0 {
return msgRef, nil
}
// The field is not present, but the message has unrecognized values. So
// let's try to parse the unrecognized bytes, just in case they contain
// this extension.
temp := &descriptorpb.FeatureSet{}
unmarshaler := proto.UnmarshalOptions{
AllowPartial: true,
Resolver: resolverForExtension{field},
}
if err := unmarshaler.Unmarshal(msgRef.GetUnknown(), temp); err != nil {
return nil, fmt.Errorf("failed to parse unrecognized fields of FeatureSet: %w", err)
}
return temp.ProtoReflect(), nil
}
return temp.ProtoReflect(), nil
}

if field.ContainingMessage() == FeatureSetDescriptor {
case field.ContainingMessage() == FeatureSetDescriptor:
// Known field, not dynamically generated. Can directly use with the feature set.
return msgRef, nil
default:
actualField = FeatureSetDescriptor.Fields().ByNumber(field.Number())
}

// If we get here, we have a dynamic field descriptor. We want to copy its
// value into a dynamic message, which requires marshalling/unmarshalling.
msgField := FeatureSetDescriptor.Fields().ByNumber(field.Number())
// If we get here, we have a dynamic field descriptor or an extension
// descriptor whose message type does not match the descriptor of the
// stored value. We need to copy its value into a dynamic message,
// which requires marshalling/unmarshalling.
// We only need to copy over the unrecognized bytes (if any)
// and the same field (if present).
data := msgRef.GetUnknown()
if msgField != nil && msgRef.Has(msgField) {
if actualField != nil && msgRef.Has(actualField) {
subset := &descriptorpb.FeatureSet{}
subset.ProtoReflect().Set(msgField, msgRef.Get(msgField))
fieldBytes, err := proto.MarshalOptions{AllowPartial: true}.Marshal(subset)
subset.ProtoReflect().Set(actualField, msgRef.Get(actualField))
var err error
data, err = proto.MarshalOptions{AllowPartial: true}.MarshalAppend(data, subset)
if err != nil {
return nil, fmt.Errorf("failed to marshal FeatureSet field %s to bytes: %w", field.Name(), err)
}
data = append(data, fieldBytes...)
}
if len(data) == 0 {
// No relevant data to copy over, so we can just return
Expand Down Expand Up @@ -354,3 +367,54 @@ func computeSupportedEditions(min, max descriptorpb.Edition) map[string]descript
}
return supportedEditions
}

// actualDescriptor returns the actual field descriptor referenced by msg that
// corresponds to the given ext (i.e. same number). It returns nil if msg has
// no reference, if the actual descriptor is the same as ext, or if ext is
// otherwise safe to use as is.
func actualDescriptor(msg protoreflect.Message, ext protoreflect.ExtensionDescriptor) protoreflect.FieldDescriptor {
if !msg.Has(ext) || ext.Message() == nil {
// nothing to match; safe as is
return nil
}
val := msg.Get(ext)
switch {
case ext.IsMap(): // should not actually be possible
expectedDescriptor := ext.MapValue().Message()
if expectedDescriptor == nil {
return nil // nothing to match
}
// We know msg.Has(field) is true, from above, so there's at least one entry.
var matches bool
val.Map().Range(func(_ protoreflect.MapKey, val protoreflect.Value) bool {
matches = val.Message().Descriptor() == expectedDescriptor
return false
})
if matches {
return nil
}
case ext.IsList():
// We know msg.Has(field) is true, from above, so there's at least one entry.
if val.List().Get(0).Message().Descriptor() == ext.Message() {
return nil
}
case !ext.IsMap():
if val.Message().Descriptor() == ext.Message() {
return nil
}
}
// The underlying message descriptors do not match. So we need to return
// the actual field descriptor. Sadly, protoreflect.Message provides no way
// to query the field descriptor in a message by number. For non-extensions,
// one can query the associated message descriptor. But for extensions, we
// have to do the slow thing, and range through all fields looking for it.
var actualField protoreflect.FieldDescriptor
msg.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
if fd.Number() == ext.Number() {
actualField = fd
return false
}
return true
})
return actualField
}
138 changes: 138 additions & 0 deletions protoutil/editions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/gofeaturespb"

"github.com/bufbuild/protocompile"
"github.com/bufbuild/protocompile/internal/editions"
Expand Down Expand Up @@ -377,6 +378,7 @@ func TestResolveCustomFeature(t *testing.T) {
}),
}
file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto)
// First we resolve the feature with the given file.
// Then we'll do a second pass where we resolve the
// feature, but all extensions are unrecognized. Both
// ways should work.
Expand Down Expand Up @@ -414,6 +416,142 @@ func TestResolveCustomFeature(t *testing.T) {
})
}

func TestResolveCustomFeature_Generated(t *testing.T) {
t.Parallel()
descriptorProto := protodesc.ToFileDescriptorProto(
(*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile(),
)
goFeaturesProto := protodesc.ToFileDescriptorProto(
(*gofeaturespb.GoFeatures)(nil).ProtoReflect().Descriptor().ParentFile(),
)

// We can do proto2 and proto3 in the same way since they
// can't override feature values.
preEditionsTestCases := []struct {
syntax string
expectedValue bool
}{
{
syntax: "proto2",
expectedValue: true,
},
{
syntax: "proto3",
expectedValue: false,
},
}
for _, testCase := range preEditionsTestCases {
testCase := testCase
t.Run(testCase.syntax, func(t *testing.T) {
t.Parallel()
sourceResolver := &protocompile.SourceResolver{
Accessor: protocompile.SourceAccessorFromMap(map[string]string{
"test.proto": `
syntax = "` + testCase.syntax + `";
import "google/protobuf/go_features.proto";
enum Foo {
ZERO = 0;
}`,
}),
}
file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto, goFeaturesProto)
// First we resolve the feature with the given file.
// Then we'll do a second pass where we resolve the
// feature, but all extensions are unrecognized. Both
// ways should work.
for _, clearKnownExts := range []bool{false, true} {
if clearKnownExts {
clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file))
}

extType := gofeaturespb.E_Go
feature := gofeaturespb.E_Go.TypeDescriptor().Message().Fields().ByName("legacy_unmarshal_json_enum")
require.NotNil(t, feature)

// Default for edition
val, err := protoutil.ResolveCustomFeature(file, extType, feature)
require.NoError(t, err)
require.Equal(t, testCase.expectedValue, val.Bool())

// Same value for an element therein
elem := file.FindDescriptorByName("Foo")
require.NotNil(t, elem)
val, err = protoutil.ResolveCustomFeature(elem, extType, feature)
require.NoError(t, err)
require.Equal(t, testCase.expectedValue, val.Bool())
}
})
}

editionsTestCases := []struct {
name string
source string
exopectedValue bool
}{
{
name: "editions-2023-default",
source: `
edition = "2023";
import "google/protobuf/go_features.proto";
enum Foo {
ZERO = 0;
}`,
exopectedValue: false,
},
{
name: "editions-override",
source: `
edition = "2023";
import "google/protobuf/go_features.proto";
enum Foo {
option features.(pb.go).legacy_unmarshal_json_enum = true;
ZERO = 0;
}`,
exopectedValue: true,
},
}

for _, testCase := range editionsTestCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()

sourceResolver := &protocompile.SourceResolver{
Accessor: protocompile.SourceAccessorFromMap(map[string]string{
"test.proto": testCase.source,
}),
}
file, _ := compileFile(t, "test.proto", sourceResolver, descriptorProto, goFeaturesProto)
// First we resolve the feature with the given file.
// Then we'll do a second pass where we resolve the
// feature, but all extensions are unrecognized. Both
// ways should work.
for _, clearKnownExts := range []bool{false, true} {
if clearKnownExts {
clearKnownExtensionsFromFile(t, protoutil.ProtoFromFileDescriptor(file))
}

extType := gofeaturespb.E_Go
feature := gofeaturespb.E_Go.TypeDescriptor().Message().Fields().ByName("legacy_unmarshal_json_enum")
require.NotNil(t, feature)

val, err := protoutil.ResolveCustomFeature(file, extType, feature)
require.NoError(t, err)
// Edition default is false, and can't be overridden at the file level,
// so this should always be false.
require.False(t, val.Bool())

// Override
elem := file.FindDescriptorByName("Foo")
require.NotNil(t, elem)
val, err = protoutil.ResolveCustomFeature(elem, extType, feature)
require.NoError(t, err)
require.Equal(t, testCase.exopectedValue, val.Bool())
}
})
}
}

func compileFile(
t *testing.T,
filename string,
Expand Down
Loading