diff --git a/proto/all_test.go b/proto/all_test.go index a68d91d2aa..1bea4b6e8e 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -2324,6 +2324,28 @@ func TestInvalidUTF8(t *testing.T) { } } +func TestRequired(t *testing.T) { + // The F_BoolRequired field appears after all of the required fields. + // It should still be handled even after multiple required field violations. + m := &GoTest{F_BoolRequired: Bool(true)} + got, err := Marshal(m) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("Marshal() = %v, want RequiredNotSetError error", err) + } + if want := []byte{0x50, 0x01}; !bytes.Equal(got, want) { + t.Errorf("Marshal() = %x, want %x", got, want) + } + + m = new(GoTest) + err = Unmarshal(got, m) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("Marshal() = %v, want RequiredNotSetError error", err) + } + if !m.GetF_BoolRequired() { + t.Error("m.F_BoolRequired = false, want true") + } +} + // Benchmarks func testMsg() *GoTest { diff --git a/proto/table_marshal.go b/proto/table_marshal.go index eafe04d14b..b16794496f 100644 --- a/proto/table_marshal.go +++ b/proto/table_marshal.go @@ -252,11 +252,13 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte } } for _, f := range u.fields { - if f.required && errLater == nil { + if f.required { if ptr.offset(f.field).getPointer().isNil() { // Required field is not set. // We record the error but keep going, to give a complete marshaling. - errLater = &RequiredNotSetError{f.name} + if errLater == nil { + errLater = &RequiredNotSetError{f.name} + } continue } } @@ -2592,7 +2594,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic) b = append(b, 1<<3|WireEndGroup) - if nerr.Merge(err) { + if !nerr.Merge(err) { return b, err } } diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index de868ae927..ebf1caa56a 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -175,10 +175,12 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { reqMask |= f.reqMask continue } - if r, ok := err.(*RequiredNotSetError); ok && errLater == nil { + if r, ok := err.(*RequiredNotSetError); ok { // Remember this error, but keep parsing. We need to produce // a full parse even if a required field is missing. - errLater = r + if errLater == nil { + errLater = r + } reqMask |= f.reqMask continue }