diff --git a/attribute_types.go b/attribute_types.go index a0d0831..739beb5 100644 --- a/attribute_types.go +++ b/attribute_types.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/mdlayher/netlink" "github.com/pkg/errors" "github.com/ti-mo/netfilter" ) @@ -21,101 +22,9 @@ const ( opUnSynProxy = "SynProxy unmarshal" ) -var ( - ctaCountersOrigReplyCat = fmt.Sprintf("%s/%s", ctaCountersOrig, ctaCountersReply) - ctaSeqAdjOrigReplyCat = fmt.Sprintf("%s/%s", ctaSeqAdjOrig, ctaSeqAdjReply) -) - -// num16 is a generic numeric attribute. It is represented by a uint32 -// and holds its own AttributeType. -type num16 struct { - Type attributeType - Value uint16 -} - -// Filled returns true if the Num16's type is non-zero. -func (i num16) filled() bool { - return i.Type != 0 || i.Value != 0 -} - -func (i num16) String() string { - return fmt.Sprintf("%d", i.Value) -} - -// unmarshal unmarshals a netfilter.Attribute into a Num16. -func (i *num16) unmarshal(attr netfilter.Attribute) error { - - if len(attr.Data) != 2 { - return errIncorrectSize - } - - i.Type = attributeType(attr.Type) - i.Value = attr.Uint16() - - return nil -} - -// marshal marshals a Num16 into a netfilter.Attribute. If the AttributeType parameter is non-zero, -// it is used as Attribute's type; otherwise, the Num16's Type field is used. -func (i num16) marshal(t attributeType) netfilter.Attribute { - - var nfa netfilter.Attribute - - if t == 0 { - nfa.Type = uint16(i.Type) - } else { - nfa.Type = uint16(t) - } - - nfa.PutUint16(i.Value) - - return nfa -} - -// num32 is a generic numeric attribute. It is represented by a uint32 -// and holds its own AttributeType. -type num32 struct { - Type attributeType - Value uint32 -} - -// Filled returns true if the Num32's type is non-zero. -func (i num32) filled() bool { - return i.Type != 0 || i.Value != 0 -} - -func (i num32) String() string { - return fmt.Sprintf("%d", i.Value) -} - -// unmarshal unmarshals a netfilter.Attribute into a Num32. -func (i *num32) unmarshal(attr netfilter.Attribute) error { - - if len(attr.Data) != 4 { - return errIncorrectSize - } - - i.Type = attributeType(attr.Type) - i.Value = attr.Uint32() - - return nil -} - -// marshal marshals a Num32 into a netfilter.Attribute. If the AttributeType parameter is non-zero, -// it is used as Attribute's type; otherwise, the Num32's Type field is used. -func (i num32) marshal(t attributeType) netfilter.Attribute { - - var nfa netfilter.Attribute - - if t == 0 { - nfa.Type = uint16(i.Type) - } else { - nfa.Type = uint16(t) - } - - nfa.PutUint32(i.Value) - - return nfa +// nestedFlag returns true if the NLA_F_NESTED flag is set on typ. +func nestedFlag(typ uint16) bool { + return typ&netlink.Nested != 0 } // A Helper holds the name and info the helper that creates a related connection. @@ -130,28 +39,20 @@ func (hlp Helper) filled() bool { } // unmarshal unmarshals a netfilter.Attribute into a Helper. -func (hlp *Helper) unmarshal(attr netfilter.Attribute) error { - - if attributeType(attr.Type) != ctaHelp { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaHelp) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnHelper) - } +func (hlp *Helper) unmarshal(ad *netlink.AttributeDecoder) error { - for _, iattr := range attr.Children { - switch helperType(iattr.Type) { + for ad.Next() { + switch helperType(ad.Type()) { case ctaHelpName: - hlp.Name = string(iattr.Data) + hlp.Name = ad.String() case ctaHelpInfo: - hlp.Info = iattr.Data + hlp.Info = ad.Bytes() default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaHelp) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // marshal marshals a Helper into a netfilter.Attribute. @@ -183,52 +84,40 @@ func (pi ProtoInfo) filled() bool { // unmarshal unmarshals a netfilter.Attribute into a ProtoInfo structure. // one of three ProtoInfo types; TCP, DCCP or SCTP. -func (pi *ProtoInfo) unmarshal(attr netfilter.Attribute) error { +func (pi *ProtoInfo) unmarshal(ad *netlink.AttributeDecoder) error { // Make sure we don't unmarshal into the same ProtoInfo twice. if pi.filled() { return errReusedProtoInfo } - if attributeType(attr.Type) != ctaProtoInfo { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaProtoInfo) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnProtoInfo) - } - - if len(attr.Children) != 1 { + if ad.Len() != 1 { return errors.Wrap(errNeedSingleChild, opUnProtoInfo) } - // Step into the single nested child - iattr := attr.Children[0] + // Step into the single nested child, return on error. + if !ad.Next() { + return ad.Err() + } - switch protoInfoType(iattr.Type) { + switch protoInfoType(ad.Type()) { case ctaProtoInfoTCP: var tpi ProtoInfoTCP - if err := tpi.unmarshal(iattr); err != nil { - return err - } + ad.Nested(tpi.unmarshal) pi.TCP = &tpi case ctaProtoInfoDCCP: var dpi ProtoInfoDCCP - if err := dpi.unmarshal(iattr); err != nil { - return err - } + ad.Nested(dpi.unmarshal) pi.DCCP = &dpi case ctaProtoInfoSCTP: var spi ProtoInfoSCTP - if err := spi.unmarshal(iattr); err != nil { - return err - } + ad.Nested(spi.unmarshal) pi.SCTP = &spi default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaProtoInfo) + return fmt.Errorf(errAttributeChild, ad.Type()) } - return nil + return ad.Err() } // marshal marshals a ProtoInfo into a netfilter.Attribute. @@ -258,39 +147,31 @@ type ProtoInfoTCP struct { } // unmarshal unmarshals a netfilter.Attribute into a ProtoInfoTCP. -func (tpi *ProtoInfoTCP) unmarshal(attr netfilter.Attribute) error { - - if protoInfoType(attr.Type) != ctaProtoInfoTCP { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaProtoInfoTCP) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnProtoInfoTCP) - } +func (tpi *ProtoInfoTCP) unmarshal(ad *netlink.AttributeDecoder) error { // A ProtoInfoTCP has at least 3 members, TCP_STATE and TCP_FLAGS_ORIG/REPLY. - if len(attr.Children) < 3 { + if ad.Len() < 3 { return errors.Wrap(errNeedChildren, opUnProtoInfoTCP) } - for _, iattr := range attr.Children { - switch protoInfoTCPType(iattr.Type) { + for ad.Next() { + switch protoInfoTCPType(ad.Type()) { case ctaProtoInfoTCPState: - tpi.State = iattr.Data[0] + tpi.State = ad.Uint8() case ctaProtoInfoTCPWScaleOriginal: - tpi.OriginalWindowScale = iattr.Data[0] + tpi.OriginalWindowScale = ad.Uint8() case ctaProtoInfoTCPWScaleReply: - tpi.ReplyWindowScale = iattr.Data[0] + tpi.ReplyWindowScale = ad.Uint8() case ctaProtoInfoTCPFlagsOriginal: - tpi.OriginalFlags = iattr.Uint16() + tpi.OriginalFlags = ad.Uint16() case ctaProtoInfoTCPFlagsReply: - tpi.ReplyFlags = iattr.Uint16() + tpi.ReplyFlags = ad.Uint16() default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaProtoInfoTCP) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // marshal marshals a ProtoInfoTCP into a netfilter.Attribute. @@ -319,34 +200,26 @@ type ProtoInfoDCCP struct { } // unmarshal unmarshals a netfilter.Attribute into a ProtoInfoTCP. -func (dpi *ProtoInfoDCCP) unmarshal(attr netfilter.Attribute) error { - - if protoInfoType(attr.Type) != ctaProtoInfoDCCP { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaProtoInfoDCCP) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnProtoInfoDCCP) - } +func (dpi *ProtoInfoDCCP) unmarshal(ad *netlink.AttributeDecoder) error { - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedChildren, opUnProtoInfoDCCP) } - for _, iattr := range attr.Children { - switch protoInfoDCCPType(iattr.Type) { + for ad.Next() { + switch protoInfoDCCPType(ad.Type()) { case ctaProtoInfoDCCPState: - dpi.State = iattr.Data[0] + dpi.State = ad.Uint8() case ctaProtoInfoDCCPRole: - dpi.Role = iattr.Data[0] + dpi.Role = ad.Uint8() case ctaProtoInfoDCCPHandshakeSeq: - dpi.HandshakeSeq = iattr.Uint64() + dpi.HandshakeSeq = ad.Uint64() default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaProtoInfoDCCP) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // marshal marshals a ProtoInfoDCCP into a netfilter.Attribute. @@ -368,34 +241,26 @@ type ProtoInfoSCTP struct { } // unmarshal unmarshals a netfilter.Attribute into a ProtoInfoSCTP. -func (spi *ProtoInfoSCTP) unmarshal(attr netfilter.Attribute) error { +func (spi *ProtoInfoSCTP) unmarshal(ad *netlink.AttributeDecoder) error { - if protoInfoType(attr.Type) != ctaProtoInfoSCTP { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaProtoInfoSCTP) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnProtoInfoSCTP) - } - - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedChildren, opUnProtoInfoSCTP) } - for _, iattr := range attr.Children { - switch protoInfoSCTPType(iattr.Type) { + for ad.Next() { + switch protoInfoSCTPType(ad.Type()) { case ctaProtoInfoSCTPState: - spi.State = iattr.Data[0] + spi.State = ad.Uint8() case ctaProtoInfoSCTPVTagOriginal: - spi.VTagOriginal = iattr.Uint32() + spi.VTagOriginal = ad.Uint32() case ctaProtoInfoSCTPVtagReply: - spi.VTagReply = iattr.Uint32() + spi.VTagReply = ad.Uint32() default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaProtoInfoSCTP) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // marshal marshals a ProtoInfoSCTP into a netfilter.Attribute. @@ -438,41 +303,29 @@ func (ctr Counter) filled() bool { } // unmarshal unmarshals a nested counter attribute into a Counter structure. -func (ctr *Counter) unmarshal(attr netfilter.Attribute) error { - - if attributeType(attr.Type) != ctaCountersOrig && - attributeType(attr.Type) != ctaCountersReply { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaCountersOrigReplyCat) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnCounter) - } +func (ctr *Counter) unmarshal(ad *netlink.AttributeDecoder) error { // A Counter consists of packet and byte attributes but may have // help attributes as well if nf_conntrack_helper enabled - if len(attr.Children) < 2 { + if ad.Len() < 2 { return errors.Wrap(errNeedChildren, opUnCounter) } - // Set Direction to true if it's a reply counter - ctr.Direction = attributeType(attr.Type) == ctaCountersReply - - for _, iattr := range attr.Children { - switch counterType(iattr.Type) { + for ad.Next() { + switch counterType(ad.Type()) { case ctaCountersPackets: - ctr.Packets = iattr.Uint64() + ctr.Packets = ad.Uint64() case ctaCountersBytes: - ctr.Bytes = iattr.Uint64() + ctr.Bytes = ad.Uint64() case ctaCountersPad: // Ignore padding attributes that show up if nf_conntrack_helper is enabled. continue default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaCountersOrigReplyCat) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // A Timestamp represents the start and end time of a flow. @@ -484,33 +337,25 @@ type Timestamp struct { } // unmarshal unmarshals a nested timestamp attribute into a conntrack.Timestamp structure. -func (ts *Timestamp) unmarshal(attr netfilter.Attribute) error { - - if attributeType(attr.Type) != ctaTimestamp { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaTimestamp) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnTimestamp) - } +func (ts *Timestamp) unmarshal(ad *netlink.AttributeDecoder) error { // A Timestamp will always have at least a start time - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedSingleChild, opUnTimestamp) } - for _, iattr := range attr.Children { - switch timestampType(iattr.Type) { + for ad.Next() { + switch timestampType(ad.Type()) { case ctaTimestampStart: - ts.Start = time.Unix(0, iattr.Int64()) + ts.Start = time.Unix(0, int64(ad.Uint64())) case ctaTimestampStop: - ts.Stop = time.Unix(0, iattr.Int64()) + ts.Stop = time.Unix(0, int64(ad.Uint64())) default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaTimestamp) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // A Security structure holds the security info belonging to a connection. @@ -519,31 +364,23 @@ func (ts *Timestamp) unmarshal(attr netfilter.Attribute) error { type Security string // unmarshal unmarshals a nested security attribute into a conntrack.Security structure. -func (sec *Security) unmarshal(attr netfilter.Attribute) error { - - if attributeType(attr.Type) != ctaSecCtx { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaSecCtx) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnSecurity) - } +func (sec *Security) unmarshal(ad *netlink.AttributeDecoder) error { // A SecurityContext has at least a name - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedChildren, opUnSecurity) } - for _, iattr := range attr.Children { - switch securityType(iattr.Type) { + for ad.Next() { + switch securityType(ad.Type()) { case ctaSecCtxName: - *sec = Security(iattr.Data) + *sec = Security(ad.Bytes()) default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaSecCtx) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // SequenceAdjust represents a TCP sequence number adjustment event. @@ -575,39 +412,27 @@ func (seq SequenceAdjust) filled() bool { // unmarshal unmarshals a nested sequence adjustment attribute into a // conntrack.SequenceAdjust structure. -func (seq *SequenceAdjust) unmarshal(attr netfilter.Attribute) error { - - if attributeType(attr.Type) != ctaSeqAdjOrig && - attributeType(attr.Type) != ctaSeqAdjReply { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaSeqAdjOrigReplyCat) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnSeqAdj) - } +func (seq *SequenceAdjust) unmarshal(ad *netlink.AttributeDecoder) error { // A SequenceAdjust message should come with at least 1 child. - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedSingleChild, opUnSeqAdj) } - // Set Direction to true if it's a reply adjustment - seq.Direction = attributeType(attr.Type) == ctaSeqAdjReply - - for _, iattr := range attr.Children { - switch seqAdjType(iattr.Type) { + for ad.Next() { + switch seqAdjType(ad.Type()) { case ctaSeqAdjCorrectionPos: - seq.Position = iattr.Uint32() + seq.Position = ad.Uint32() case ctaSeqAdjOffsetBefore: - seq.OffsetBefore = iattr.Uint32() + seq.OffsetBefore = ad.Uint32() case ctaSeqAdjOffsetAfter: - seq.OffsetAfter = iattr.Uint32() + seq.OffsetAfter = ad.Uint32() default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaSeqAdjOrigReplyCat) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // marshal marshals a SequenceAdjust into a netfilter.Attribute. @@ -642,34 +467,26 @@ func (sp SynProxy) filled() bool { } // unmarshal unmarshals a SYN proxy attribute into a SynProxy structure. -func (sp *SynProxy) unmarshal(attr netfilter.Attribute) error { - - if attributeType(attr.Type) != ctaSynProxy { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaSynProxy) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnSynProxy) - } +func (sp *SynProxy) unmarshal(ad *netlink.AttributeDecoder) error { - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedSingleChild, opUnSynProxy) } - for _, iattr := range attr.Children { - switch synProxyType(iattr.Type) { + for ad.Next() { + switch synProxyType(ad.Type()) { case ctaSynProxyISN: - sp.ISN = iattr.Uint32() + sp.ISN = ad.Uint32() case ctaSynProxyITS: - sp.ITS = iattr.Uint32() + sp.ITS = ad.Uint32() case ctaSynProxyTSOff: - sp.TSOff = iattr.Uint32() + sp.TSOff = ad.Uint32() default: - return fmt.Errorf(errAttributeChild, iattr.Type, ctaSynProxy) + return fmt.Errorf(errAttributeChild, ad.Type()) } } - return nil + return ad.Err() } // marshal marshals a SynProxy into a netfilter.Attribute. diff --git a/attribute_types_test.go b/attribute_types_test.go index 9476371..9f8d1d5 100644 --- a/attribute_types_test.go +++ b/attribute_types_test.go @@ -6,62 +6,44 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/assert" + + "github.com/mdlayher/netlink" "github.com/ti-mo/netfilter" ) var ( - nfaBadType = netfilter.Attribute{Type: uint16(ctaUnspec)} - nfaTooShort = netfilter.Attribute{} + adEmpty, _ = netfilter.NewAttributeDecoder([]byte{}) + adOneUnknown = *mustDecodeAttribute(netfilter.Attribute{Type: uint16(ctaUnspec)}) + adTwoUnknown = *mustDecodeAttributes([]netfilter.Attribute{{Type: uint16(ctaUnspec)}, {Type: uint16(ctaUnspec)}}) + adThreeUnknown = *mustDecodeAttributes([]netfilter.Attribute{{Type: uint16(ctaUnspec)}, {Type: uint16(ctaUnspec)}, {Type: uint16(ctaUnspec)}}) ) -func TestAttributeTypeString(t *testing.T) { - if attributeType(255).String() == "" { - t.Fatal("AttributeType string representation empty - did you run `go generate`?") - } +// mustDecodeAttribute wraps attr in a list of netfilter.Attributes and calls +// mustDecodeAttributes. +func mustDecodeAttribute(attr netfilter.Attribute) *netlink.AttributeDecoder { + return mustDecodeAttributes([]netfilter.Attribute{attr}) } -func TestAttributeNum16(t *testing.T) { - - n16 := num16{} - assert.Equal(t, false, n16.filled()) - assert.Equal(t, true, num16{Type: 1}.filled()) - assert.Equal(t, true, num16{Value: 1}.filled()) - - assert.EqualError(t, n16.unmarshal(nfaTooShort), errIncorrectSize.Error()) +// mustDecodeAttributes marshals a list of netfilter.Attributes and returns +// an AttributeDecoder holding the binary output of the unmarshal. +func mustDecodeAttributes(attrs []netfilter.Attribute) *netlink.AttributeDecoder { + ba, err := netfilter.MarshalAttributes(attrs) + if err != nil { + panic(err) + } - nfa := netfilter.Attribute{ - Type: uint16(ctaZone), - Data: []byte{0, 1}, + ad, err := netfilter.NewAttributeDecoder(ba) + if err != nil { + panic(err) } - assert.Nil(t, n16.unmarshal(nfa)) - assert.Equal(t, n16.String(), "1") - // Marshal with zero type (auto-fill from struct) - assert.EqualValues(t, netfilter.Attribute{Type: uint16(ctaZone), Data: []byte{0, 1}}, n16.marshal(0)) - // Marshal with explicit type parameter - assert.EqualValues(t, netfilter.Attribute{Type: uint16(ctaZone), Data: []byte{0, 1}}, n16.marshal(ctaZone)) + return ad } -func TestAttributeNum32(t *testing.T) { - - n32 := num32{} - assert.Equal(t, false, n32.filled()) - assert.Equal(t, true, num32{Type: 1}.filled()) - assert.Equal(t, true, num32{Value: 1}.filled()) - - assert.EqualError(t, n32.unmarshal(nfaTooShort), errIncorrectSize.Error()) - - nfa := netfilter.Attribute{ - Type: uint16(ctaMark), - Data: []byte{0, 1, 2, 3}, +func TestAttributeTypeString(t *testing.T) { + if attributeType(255).String() == "" { + t.Fatal("AttributeType string representation empty - did you run `go generate`?") } - assert.Nil(t, n32.unmarshal(nfa)) - assert.Equal(t, n32.String(), "66051") - - // Marshal with zero type (auto-fill from struct) - assert.EqualValues(t, netfilter.Attribute{Type: uint16(ctaMark), Data: []byte{0, 1, 2, 3}}, n32.marshal(0)) - // Marshal with explicit type parameter - assert.EqualValues(t, netfilter.Attribute{Type: uint16(ctaMark), Data: []byte{0, 1, 2, 3}}, n32.marshal(ctaMark)) } func TestAttributeHelper(t *testing.T) { @@ -71,11 +53,6 @@ func TestAttributeHelper(t *testing.T) { assert.Equal(t, true, Helper{Info: []byte{1}}.filled()) assert.Equal(t, true, Helper{Name: "1"}.filled()) - nfaNotNested := netfilter.Attribute{Type: uint16(ctaHelp)} - - assert.EqualError(t, hlp.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaHelp)) - assert.EqualError(t, hlp.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnHelper).Error()) - nfaNameInfo := netfilter.Attribute{ Type: uint16(ctaHelp), Nested: true, @@ -90,20 +67,12 @@ func TestAttributeHelper(t *testing.T) { }, }, } - assert.Nil(t, hlp.unmarshal(nfaNameInfo)) + assert.Nil(t, hlp.unmarshal(mustDecodeAttributes(nfaNameInfo.Children))) assert.EqualValues(t, hlp.marshal(), nfaNameInfo) - nfaUnknownChild := netfilter.Attribute{ - Type: uint16(ctaHelp), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: uint16(ctaHelpUnspec), - }, - }, - } - assert.EqualError(t, hlp.unmarshal(nfaUnknownChild), fmt.Sprintf(errAttributeChild, ctaHelpUnspec, ctaHelp)) + ad := adOneUnknown + assert.EqualError(t, hlp.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) } func TestAttributeProtoInfo(t *testing.T) { @@ -114,14 +83,17 @@ func TestAttributeProtoInfo(t *testing.T) { assert.Equal(t, true, ProtoInfo{TCP: &ProtoInfoTCP{}}.filled()) assert.Equal(t, true, ProtoInfo{SCTP: &ProtoInfoSCTP{}}.filled()) - nfaNotNested := netfilter.Attribute{Type: uint16(ctaProtoInfo)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(ctaProtoInfo), Nested: true} + assert.EqualError(t, pi.unmarshal(adEmpty), errors.Wrap(errNeedSingleChild, opUnProtoInfo).Error()) + + // Exhaust the AttributeDecoder before passing to unmarshal. + ead := mustDecodeAttribute(nfaUnspecU16) + ead.Next() + assert.NoError(t, pi.unmarshal(ead)) - assert.EqualError(t, pi.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaProtoInfo)) - assert.EqualError(t, pi.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnProtoInfo).Error()) - assert.EqualError(t, pi.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedSingleChild, opUnProtoInfo).Error()) + ad := adOneUnknown + assert.EqualError(t, pi.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) - // Attempt marshal of empty ProtoInfo, expect attribute with zero children + // Attempt marshal of empty ProtoInfo, expect attribute with zero children. assert.Len(t, pi.marshal().Children, 0) // TCP protocol info @@ -158,27 +130,13 @@ func TestAttributeProtoInfo(t *testing.T) { }, } - // Full ProtoInfoTCP unmarshal + // Full ProtoInfoTCP unmarshal. var tpi ProtoInfo - assert.Nil(t, tpi.unmarshal(nfaInfoTCP)) + assert.NoError(t, tpi.unmarshal(mustDecodeAttributes(nfaInfoTCP.Children))) // Re-marshal into netfilter Attribute assert.EqualValues(t, nfaInfoTCP, tpi.marshal()) - // Error during ProtoInfoTCP unmarshal - nfaInfoTCPError := netfilter.Attribute{ - Type: uint16(ctaProtoInfo), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: uint16(ctaProtoInfoTCP), - Nested: false, - }, - }, - } - - assert.EqualError(t, pi.unmarshal(nfaInfoTCPError), errors.Wrap(errNotNested, opUnProtoInfoTCP).Error()) - // DCCP protocol info nfaInfoDCCP := netfilter.Attribute{ Type: uint16(ctaProtoInfo), @@ -205,23 +163,9 @@ func TestAttributeProtoInfo(t *testing.T) { }, } - // Error during ProtoInfoDCCP unmarshal - nfaInfoDCCPError := netfilter.Attribute{ - Type: uint16(ctaProtoInfo), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: uint16(ctaProtoInfoDCCP), - Nested: false, - }, - }, - } - - assert.EqualError(t, pi.unmarshal(nfaInfoDCCPError), errors.Wrap(errNotNested, opUnProtoInfoDCCP).Error()) - // Full ProtoInfoDCCP unmarshal var dpi ProtoInfo - assert.Nil(t, dpi.unmarshal(nfaInfoDCCP)) + assert.Nil(t, dpi.unmarshal(mustDecodeAttributes(nfaInfoDCCP.Children))) // Re-marshal into netfilter Attribute assert.EqualValues(t, nfaInfoDCCP, dpi.marshal()) @@ -253,41 +197,14 @@ func TestAttributeProtoInfo(t *testing.T) { // Full ProtoInfoSCTP unmarshal var spi ProtoInfo - assert.Nil(t, spi.unmarshal(nfaInfoSCTP)) + assert.Nil(t, spi.unmarshal(mustDecodeAttributes(nfaInfoSCTP.Children))) // Re-marshal into netfilter Attribute assert.EqualValues(t, nfaInfoSCTP, spi.marshal()) - // Error during ProtoInfoSCTP unmarshal - nfaInfoSCTPError := netfilter.Attribute{ - Type: uint16(ctaProtoInfo), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: uint16(ctaProtoInfoSCTP), - Nested: false, - }, - }, - } - - assert.EqualError(t, pi.unmarshal(nfaInfoSCTPError), errors.Wrap(errNotNested, opUnProtoInfoSCTP).Error()) - - // Unknown child attribute type - nfaUnknownChild := netfilter.Attribute{ - Type: uint16(ctaProtoInfo), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: uint16(ctaProtoInfoUnspec), - }, - }, - } - - assert.EqualError(t, pi.unmarshal(nfaUnknownChild), fmt.Sprintf(errAttributeChild, ctaProtoInfoUnspec, ctaProtoInfo)) - // Attempt to unmarshal into re-used ProtoInfo pi.TCP = &ProtoInfoTCP{} - assert.EqualError(t, pi.unmarshal(nfaInfoTCP), errReusedProtoInfo.Error()) + assert.EqualError(t, pi.unmarshal(mustDecodeAttribute(nfaInfoTCP)), errReusedProtoInfo.Error()) } func TestProtoInfoTypeString(t *testing.T) { @@ -304,12 +221,10 @@ func TestAttributeProtoInfoTCP(t *testing.T) { pit := ProtoInfoTCP{} - nfaNotNested := netfilter.Attribute{Type: uint16(ctaProtoInfoTCP)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(ctaProtoInfoTCP), Nested: true} + assert.EqualError(t, pit.unmarshal(adEmpty), errors.Wrap(errNeedChildren, opUnProtoInfoTCP).Error()) - assert.EqualError(t, pit.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaProtoInfoTCP)) - assert.EqualError(t, pit.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnProtoInfoTCP).Error()) - assert.EqualError(t, pit.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedChildren, opUnProtoInfoTCP).Error()) + ad := adThreeUnknown + assert.EqualError(t, pit.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) nfaProtoInfoTCP := netfilter.Attribute{ Type: uint16(ctaProtoInfoTCP), @@ -329,40 +244,25 @@ func TestAttributeProtoInfoTCP(t *testing.T) { }, { Type: uint16(ctaProtoInfoTCPWScaleOriginal), - Data: []byte{0, 4}, + Data: []byte{4}, }, { Type: uint16(ctaProtoInfoTCPWScaleReply), - Data: []byte{0, 5}, + Data: []byte{5}, }, }, } - - nfaProtoInfoTCPError := netfilter.Attribute{ - Type: uint16(ctaProtoInfoTCP), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaProtoInfoTCPUnspec)}, - {Type: uint16(ctaProtoInfoTCPUnspec)}, - {Type: uint16(ctaProtoInfoTCPUnspec)}, - }, - } - - assert.Nil(t, pit.unmarshal(nfaProtoInfoTCP)) - assert.EqualError(t, pit.unmarshal(nfaProtoInfoTCPError), fmt.Sprintf(errAttributeChild, ctaProtoInfoTCPUnspec, ctaProtoInfoTCP)) - + assert.NoError(t, pit.unmarshal(mustDecodeAttributes(nfaProtoInfoTCP.Children))) } func TestAttributeProtoInfoDCCP(t *testing.T) { pid := ProtoInfoDCCP{} - nfaNotNested := netfilter.Attribute{Type: uint16(ctaProtoInfoDCCP)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(ctaProtoInfoDCCP), Nested: true} + assert.EqualError(t, pid.unmarshal(adEmpty), errors.Wrap(errNeedChildren, opUnProtoInfoDCCP).Error()) - assert.EqualError(t, pid.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaProtoInfoDCCP)) - assert.EqualError(t, pid.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnProtoInfoDCCP).Error()) - assert.EqualError(t, pid.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedChildren, opUnProtoInfoDCCP).Error()) + ad := adThreeUnknown + assert.EqualError(t, pid.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) nfaProtoInfoDCCP := netfilter.Attribute{ Type: uint16(ctaProtoInfoDCCP), @@ -382,32 +282,17 @@ func TestAttributeProtoInfoDCCP(t *testing.T) { }, }, } - - nfaProtoInfoDCCPError := netfilter.Attribute{ - Type: uint16(ctaProtoInfoDCCP), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaProtoInfoDCCPUnspec)}, - {Type: uint16(ctaProtoInfoDCCPUnspec)}, - {Type: uint16(ctaProtoInfoDCCPUnspec)}, - }, - } - - assert.Nil(t, pid.unmarshal(nfaProtoInfoDCCP)) - assert.EqualError(t, pid.unmarshal(nfaProtoInfoDCCPError), fmt.Sprintf(errAttributeChild, ctaProtoInfoTCPUnspec, ctaProtoInfoDCCP)) - + assert.NoError(t, pid.unmarshal(mustDecodeAttributes(nfaProtoInfoDCCP.Children))) } func TestAttributeProtoInfoSCTP(t *testing.T) { pid := ProtoInfoSCTP{} - nfaNotNested := netfilter.Attribute{Type: uint16(ctaProtoInfoSCTP)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(ctaProtoInfoSCTP), Nested: true} + assert.EqualError(t, pid.unmarshal(adEmpty), errors.Wrap(errNeedChildren, opUnProtoInfoSCTP).Error()) - assert.EqualError(t, pid.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaProtoInfoSCTP)) - assert.EqualError(t, pid.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnProtoInfoSCTP).Error()) - assert.EqualError(t, pid.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedChildren, opUnProtoInfoSCTP).Error()) + ad := adOneUnknown + assert.EqualError(t, pid.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) nfaProtoInfoSCTP := netfilter.Attribute{ Type: uint16(ctaProtoInfoSCTP), @@ -427,20 +312,7 @@ func TestAttributeProtoInfoSCTP(t *testing.T) { }, }, } - - nfaProtoInfoSCTPError := netfilter.Attribute{ - Type: uint16(ctaProtoInfoSCTP), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaProtoInfoSCTPUnspec)}, - {Type: uint16(ctaProtoInfoSCTPUnspec)}, - {Type: uint16(ctaProtoInfoSCTPUnspec)}, - }, - } - - assert.Nil(t, pid.unmarshal(nfaProtoInfoSCTP)) - assert.EqualError(t, pid.unmarshal(nfaProtoInfoSCTPError), fmt.Sprintf(errAttributeChild, ctaProtoInfoTCPUnspec, ctaProtoInfoSCTP)) - + assert.NoError(t, pid.unmarshal(mustDecodeAttributes(nfaProtoInfoSCTP.Children))) } func TestAttributeCounters(t *testing.T) { @@ -455,12 +327,8 @@ func TestAttributeCounters(t *testing.T) { for _, at := range attrTypes { t.Run(at.String(), func(t *testing.T) { - nfaNotNested := netfilter.Attribute{Type: uint16(at)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(at), Nested: true} - assert.EqualError(t, ctr.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaCountersOrigReplyCat)) - assert.EqualError(t, ctr.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnCounter).Error()) - assert.EqualError(t, ctr.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedChildren, opUnCounter).Error()) + assert.EqualError(t, ctr.unmarshal(adEmpty), errors.Wrap(errNeedChildren, opUnCounter).Error()) nfaCounter := netfilter.Attribute{ Type: uint16(at), @@ -480,24 +348,10 @@ func TestAttributeCounters(t *testing.T) { }, }, } + assert.NoError(t, ctr.unmarshal(mustDecodeAttributes(nfaCounter.Children))) - nfaCounterError := netfilter.Attribute{ - Type: uint16(at), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaCountersUnspec)}, - {Type: uint16(ctaCountersUnspec)}, - }, - } - - assert.Nil(t, ctr.unmarshal(nfaCounter)) - assert.EqualError(t, ctr.unmarshal(nfaCounterError), fmt.Sprintf(errAttributeChild, ctaCountersUnspec, ctaCountersOrigReplyCat)) - - if at == ctaCountersOrig { - assert.Equal(t, "[orig: 0 pkts/0 B]", ctr.String()) - } else { - assert.Equal(t, "[reply: 0 pkts/0 B]", ctr.String()) - } + ad := adTwoUnknown + assert.EqualError(t, ctr.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) }) } } @@ -506,12 +360,10 @@ func TestAttributeTimestamp(t *testing.T) { ts := Timestamp{} - nfaNotNested := netfilter.Attribute{Type: uint16(ctaTimestamp)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(ctaTimestamp), Nested: true} + assert.EqualError(t, ts.unmarshal(adEmpty), errors.Wrap(errNeedSingleChild, opUnTimestamp).Error()) - assert.EqualError(t, ts.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaTimestamp)) - assert.EqualError(t, ts.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnTimestamp).Error()) - assert.EqualError(t, ts.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedSingleChild, opUnTimestamp).Error()) + ad := adOneUnknown + assert.EqualError(t, ts.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) nfaTimestamp := netfilter.Attribute{ Type: uint16(ctaTimestamp), @@ -527,30 +379,17 @@ func TestAttributeTimestamp(t *testing.T) { }, }, } - - nfaTimestampError := netfilter.Attribute{ - Type: uint16(ctaTimestamp), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaTimestampUnspec)}, - }, - } - - assert.Nil(t, ts.unmarshal(nfaTimestamp)) - assert.EqualError(t, ts.unmarshal(nfaTimestampError), fmt.Sprintf(errAttributeChild, ctaTimestampUnspec, ctaTimestamp)) - + assert.NoError(t, ts.unmarshal(mustDecodeAttributes(nfaTimestamp.Children))) } func TestAttributeSecCtx(t *testing.T) { var sc Security - nfaNotNested := netfilter.Attribute{Type: uint16(ctaSecCtx)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(ctaSecCtx), Nested: true} + assert.EqualError(t, sc.unmarshal(adEmpty), errors.Wrap(errNeedChildren, opUnSecurity).Error()) - assert.EqualError(t, sc.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaSecCtx)) - assert.EqualError(t, sc.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnSecurity).Error()) - assert.EqualError(t, sc.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedChildren, opUnSecurity).Error()) + ad := adOneUnknown + assert.EqualError(t, sc.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) nfaSecurity := netfilter.Attribute{ Type: uint16(ctaSecCtx), @@ -562,18 +401,7 @@ func TestAttributeSecCtx(t *testing.T) { }, }, } - - nfaSecurityError := netfilter.Attribute{ - Type: uint16(ctaSecCtx), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaSecCtxUnspec)}, - }, - } - - assert.Nil(t, sc.unmarshal(nfaSecurity)) - assert.EqualError(t, sc.unmarshal(nfaSecurityError), fmt.Sprintf(errAttributeChild, ctaSecCtxUnspec, ctaSecCtx)) - + assert.NoError(t, sc.unmarshal(mustDecodeAttributes(nfaSecurity.Children))) } func TestAttributeSeqAdj(t *testing.T) { @@ -588,12 +416,11 @@ func TestAttributeSeqAdj(t *testing.T) { for _, at := range attrTypes { t.Run(at.String(), func(t *testing.T) { - nfaNotNested := netfilter.Attribute{Type: uint16(at)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(at), Nested: true} - assert.EqualError(t, sa.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaSeqAdjOrigReplyCat)) - assert.EqualError(t, sa.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnSeqAdj).Error()) - assert.EqualError(t, sa.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedSingleChild, opUnSeqAdj).Error()) + assert.EqualError(t, sa.unmarshal(adEmpty), errors.Wrap(errNeedSingleChild, opUnSeqAdj).Error()) + + ad := adOneUnknown + assert.EqualError(t, sa.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) nfaSeqAdj := netfilter.Attribute{ Type: uint16(at), @@ -613,26 +440,17 @@ func TestAttributeSeqAdj(t *testing.T) { }, }, } + assert.NoError(t, sa.unmarshal(mustDecodeAttributes(nfaSeqAdj.Children))) - nfaSeqAdjError := netfilter.Attribute{ - Type: uint16(at), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaSeqAdjUnspec)}, - {Type: uint16(ctaSeqAdjUnspec)}, - }, + // The AttributeDecoder unmarshal() no longer has the tuple direction, set it manually. + // TODO: Remove when marshal() switches to AttributeEncoder. + if at == ctaSeqAdjReply { + sa.Direction = true + } else { + sa.Direction = false } - assert.Nil(t, sa.unmarshal(nfaSeqAdj)) - assert.EqualError(t, sa.unmarshal(nfaSeqAdjError), fmt.Sprintf(errAttributeChild, ctaSeqAdjUnspec, ctaSeqAdjOrigReplyCat)) - assert.EqualValues(t, nfaSeqAdj, sa.marshal()) - - if at == ctaSeqAdjOrig { - assert.Equal(t, "[dir: orig, pos: 0, before: 0, after: 0]", sa.String()) - } else { - assert.Equal(t, "[dir: reply, pos: 0, before: 0, after: 0]", sa.String()) - } }) } } @@ -645,12 +463,10 @@ func TestAttributeSynProxy(t *testing.T) { assert.Equal(t, true, SynProxy{ITS: 1}.filled()) assert.Equal(t, true, SynProxy{TSOff: 1}.filled()) - nfaNotNested := netfilter.Attribute{Type: uint16(ctaSynProxy)} - nfaNestedNoChildren := netfilter.Attribute{Type: uint16(ctaSynProxy), Nested: true} + assert.EqualError(t, sp.unmarshal(adEmpty), errors.Wrap(errNeedSingleChild, opUnSynProxy).Error()) - assert.EqualError(t, sp.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaSynProxy)) - assert.EqualError(t, sp.unmarshal(nfaNotNested), errors.Wrap(errNotNested, opUnSynProxy).Error()) - assert.EqualError(t, sp.unmarshal(nfaNestedNoChildren), errors.Wrap(errNeedSingleChild, opUnSynProxy).Error()) + ad := adOneUnknown + assert.EqualError(t, sp.unmarshal(&ad), fmt.Errorf(errAttributeChild, ctaUnspec).Error()) nfaSynProxy := netfilter.Attribute{ Type: uint16(ctaSynProxy), @@ -670,17 +486,7 @@ func TestAttributeSynProxy(t *testing.T) { }, }, } - - nfaSynProxyError := netfilter.Attribute{ - Type: uint16(ctaSynProxy), - Nested: true, - Children: []netfilter.Attribute{ - {Type: uint16(ctaSynProxyUnspec)}, - }, - } - - assert.Nil(t, sp.unmarshal(nfaSynProxy)) - assert.EqualError(t, sp.unmarshal(nfaSynProxyError), fmt.Sprintf(errAttributeChild, ctaSynProxyUnspec, ctaSynProxy)) + assert.NoError(t, sp.unmarshal(mustDecodeAttributes(nfaSynProxy.Children))) assert.EqualValues(t, nfaSynProxy, sp.marshal()) } diff --git a/enum.go b/enum.go index 3d77588..23668d1 100644 --- a/enum.go +++ b/enum.go @@ -295,3 +295,10 @@ const ( ) // enum ctattr_natseq is unused in the kernel source + +// Unused unspec constants. +var _ = []uint8{ + uint8(ctaHelpUnspec), uint8(ctaCountersUnspec), uint8(ctaTimestampUnspec), + uint8(ctaSecCtxUnspec), uint8(ctaProtoInfoTCPUnspec), uint8(ctaProtoInfoDCCPUnspec), + uint8(ctaProtoInfoSCTPUnspec), uint8(ctaSeqAdjUnspec), uint8(ctaSynProxyUnspec), +} diff --git a/errors.go b/errors.go index b737525..e2717c2 100644 --- a/errors.go +++ b/errors.go @@ -7,7 +7,6 @@ var ( errConnHasListeners = errors.New("Conn has existing listeners, open another to listen on more groups") errMultipartEvent = errors.New("received multicast event with more than one Netlink message") - errNested = errors.New("unexpected Nested attribute") errNotNested = errors.New("need a Nested attribute to decode this structure") errNeedSingleChild = errors.New("need (at least) 1 child attribute") errNeedChildren = errors.New("need (at least) 2 child attributes") @@ -27,9 +26,8 @@ var ( ) const ( - errUnknownEventType = "unknown event type %d" - errWorkerCount = "invalid worker count %d" - errWorkerReceive = "netlink.Receive error in listenWorker %d, exiting" - errAttributeWrongType = "attribute type '%d' is not a %s" - errAttributeChild = "child Type '%d' unknown for attribute type %s" + errUnknownEventType = "unknown event type %d" + errWorkerCount = "invalid worker count %d" + errWorkerReceive = "netlink.Receive error in listenWorker %d, exiting" + errAttributeChild = "unknown attribute child Type '%d'" ) diff --git a/event.go b/event.go index 979c424..3f110bc 100644 --- a/event.go +++ b/event.go @@ -74,25 +74,25 @@ func (e *Event) unmarshal(nlmsg netlink.Message) error { var err error - // Unmarshal a netlink.Message into netfilter.Attributes and Header - h, attrs, err := netfilter.UnmarshalNetlink(nlmsg) + // Obtain the nlmsg's Netfilter header and AttributeDecoder. + h, ad, err := netfilter.DecodeNetlink(nlmsg) if err != nil { return err } - // Decode the header to make sure we're dealing with a Conntrack event + // Decode the header to make sure we're dealing with a Conntrack event. err = e.Type.unmarshal(h) if err != nil { return err } - // Unmarshal Netfilter attributes into the event's Flow or Expect entry + // Unmarshal Netfilter attributes into the event's Flow or Expect entry. if h.SubsystemID == netfilter.NFSubsysCTNetlink { e.Flow = new(Flow) - err = e.Flow.unmarshal(attrs) + err = e.Flow.unmarshal(ad) } else if h.SubsystemID == netfilter.NFSubsysCTNetlinkExp { e.Expect = new(Expect) - err = e.Expect.unmarshal(attrs) + err = e.Expect.unmarshal(ad) } if err != nil { diff --git a/event_test.go b/event_test.go index 7d780c9..425e479 100644 --- a/event_test.go +++ b/event_test.go @@ -163,7 +163,7 @@ func TestEventUnmarshalError(t *testing.T) { // Netlink unmarshal error emptyEvent := Event{} - assert.EqualError(t, emptyEvent.unmarshal(netlink.Message{}), "expected at least 4 bytes in netlink message payload") + assert.EqualError(t, emptyEvent.unmarshal(netlink.Message{}), "unmarshaling netfilter header: expected at least 4 bytes in netlink message payload") // EventType unmarshal error, blank SubsystemID assert.EqualError(t, emptyEvent.unmarshal(netlink.Message{ diff --git a/expect.go b/expect.go index cc761ad..4c4aab7 100644 --- a/expect.go +++ b/expect.go @@ -35,34 +35,24 @@ type ExpectNAT struct { } // unmarshal unmarshals a netfilter.Attribute into an ExpectNAT. -func (en *ExpectNAT) unmarshal(attr netfilter.Attribute) error { +func (en *ExpectNAT) unmarshal(ad *netlink.AttributeDecoder) error { - if expectType(attr.Type) != ctaExpectNAT { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaExpectNAT) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnExpectNAT) - } - - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedSingleChild, opUnExpectNAT) } - for _, iattr := range attr.Children { - switch expectNATType(iattr.Type) { + for ad.Next() { + switch expectNATType(ad.Type()) { case ctaExpectNATDir: - en.Direction = iattr.Uint32() == 1 + en.Direction = ad.Uint32() == 1 case ctaExpectNATTuple: - if err := en.Tuple.unmarshal(iattr); err != nil { - return err - } + ad.Nested(en.Tuple.unmarshal) default: - return errors.Wrap(fmt.Errorf(errAttributeChild, iattr.Type, ctaExpectNAT), opUnExpectNAT) + return errors.Wrap(fmt.Errorf(errAttributeChild, ad.Type()), opUnExpectNAT) } } - return nil + return ad.Err() } func (en ExpectNAT) marshal() (netfilter.Attribute, error) { @@ -86,46 +76,48 @@ func (en ExpectNAT) marshal() (netfilter.Attribute, error) { } // unmarshal unmarshals a list of netfilter.Attributes into an Expect structure. -func (ex *Expect) unmarshal(attrs []netfilter.Attribute) error { - - for _, attr := range attrs { - - switch at := expectType(attr.Type); at { +func (ex *Expect) unmarshal(ad *netlink.AttributeDecoder) error { + for ad.Next() { + switch at := expectType(ad.Type()); at { case ctaExpectMaster: - if err := ex.TupleMaster.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnTup) } + ad.Nested(ex.TupleMaster.unmarshal) case ctaExpectTuple: - if err := ex.Tuple.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnTup) } + ad.Nested(ex.Tuple.unmarshal) case ctaExpectMask: - if err := ex.Mask.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnTup) } + ad.Nested(ex.Mask.unmarshal) case ctaExpectTimeout: - ex.Timeout = attr.Uint32() + ex.Timeout = ad.Uint32() case ctaExpectID: - ex.ID = attr.Uint32() + ex.ID = ad.Uint32() case ctaExpectHelpName: - ex.HelpName = string(attr.Data) + ex.HelpName = ad.String() case ctaExpectZone: - ex.Zone = attr.Uint16() + ex.Zone = ad.Uint16() case ctaExpectFlags: - ex.Flags = attr.Uint32() + ex.Flags = ad.Uint32() case ctaExpectClass: - ex.Class = attr.Uint32() + ex.Class = ad.Uint32() case ctaExpectNAT: - if err := ex.NAT.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnExpectNAT) } + ad.Nested(ex.NAT.unmarshal) case ctaExpectFN: - ex.Function = string(attr.Data) + ex.Function = ad.String() } } - return nil + return ad.Err() } func (ex Expect) marshal() ([]netfilter.Attribute, error) { @@ -194,12 +186,12 @@ func unmarshalExpect(nlm netlink.Message) (Expect, error) { var ex Expect - _, nfa, err := netfilter.UnmarshalNetlink(nlm) + _, ad, err := netfilter.DecodeNetlink(nlm) if err != nil { return ex, err } - err = ex.unmarshal(nfa) + err = ex.unmarshal(ad) if err != nil { return ex, err } diff --git a/expect_test.go b/expect_test.go index b6f4bda..560f281 100644 --- a/expect_test.go +++ b/expect_test.go @@ -273,15 +273,16 @@ func TestExpectUnmarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var ex Expect - err := ex.unmarshal(tt.attrs) + err := ex.unmarshal(mustDecodeAttributes(tt.attrs)) - if err != nil || tt.err != nil { + if tt.err != nil { require.Error(t, err) - require.Error(t, tt.err) require.EqualError(t, err, tt.err.Error()) return } + require.NoError(t, err) + if diff := cmp.Diff(tt.exp, ex); diff != "" { t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) } @@ -291,7 +292,7 @@ func TestExpectUnmarshal(t *testing.T) { for _, tt := range corpusExpectUnmarshalError { t.Run(tt.name, func(t *testing.T) { var ex Expect - assert.EqualError(t, ex.unmarshal([]netfilter.Attribute{tt.nfa}), tt.errStr) + assert.EqualError(t, ex.unmarshal(mustDecodeAttributes([]netfilter.Attribute{tt.nfa})), tt.errStr) }) } } @@ -395,25 +396,21 @@ func TestExpectMarshal(t *testing.T) { var corpusExpectNAT = []struct { name string - attr netfilter.Attribute + attr []netfilter.Attribute enat ExpectNAT err error }{ { name: "simple direction, tuple unmarshal", - attr: netfilter.Attribute{ - Type: uint16(ctaExpectNAT), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: uint16(ctaExpectNATDir), - Data: []byte{0x00, 0x00, 0x00, 0x01}, - }, - { - Type: uint16(ctaExpectNATTuple), - Nested: true, - Children: nfaIPPT, - }, + attr: []netfilter.Attribute{ + { + Type: uint16(ctaExpectNATDir), + Data: []byte{0x00, 0x00, 0x00, 0x01}, + }, + { + Type: uint16(ctaExpectNATTuple), + Nested: true, + Children: nfaIPPT, }, }, enat: ExpectNAT{ @@ -421,50 +418,10 @@ var corpusExpectNAT = []struct { Tuple: flowIPPT, }, }, - { - name: "error bad tuple", - attr: netfilter.Attribute{ - Type: uint16(ctaExpectNAT), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: uint16(ctaExpectNATDir), - Data: []byte{0x00, 0x00, 0x00, 0x00}, - }, - { - Type: uint16(ctaExpectNATTuple), - }, - }, - }, - err: errors.New("Tuple unmarshal: need a Nested attribute to decode this structure"), - }, { name: "error unknown type", - attr: netfilter.Attribute{Type: 255}, - err: fmt.Errorf(errAttributeWrongType, 255, ctaExpectNAT), - }, - { - name: "error not nested", - attr: netfilter.Attribute{Type: uint16(ctaExpectNAT)}, - err: errors.Wrap(errNotNested, opUnExpectNAT), - }, - { - name: "error no children", - attr: netfilter.Attribute{Type: uint16(ctaExpectNAT), Nested: true}, - err: errors.Wrap(errNeedSingleChild, opUnExpectNAT), - }, - { - name: "error unknown child type", - attr: netfilter.Attribute{ - Type: uint16(ctaExpectNAT), - Nested: true, - Children: []netfilter.Attribute{ - { - Type: 255, - }, - }, - }, - err: errors.Wrap(fmt.Errorf(errAttributeChild, 255, ctaExpectNAT), opUnExpectNAT), + attr: []netfilter.Attribute{{Type: 255}}, + err: errors.Wrap(fmt.Errorf(errAttributeChild, 255), opUnExpectNAT), }, } @@ -474,15 +431,16 @@ func TestExpectNATUnmarshal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var enat ExpectNAT - err := enat.unmarshal(tt.attr) + err := enat.unmarshal(mustDecodeAttributes(tt.attr)) - if err != nil || tt.err != nil { + if tt.err != nil { require.Error(t, err) - require.Error(t, tt.err) require.EqualError(t, err, tt.err.Error()) return } + require.NoError(t, err) + if diff := cmp.Diff(tt.enat, enat); diff != "" { t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) } @@ -534,7 +492,6 @@ func BenchmarkExpectUnmarshal(b *testing.B) { b.ReportAllocs() var tests []netfilter.Attribute - var ex Expect // Collect all tests from corpus that aren't expected to fail for _, test := range corpusExpect { @@ -543,7 +500,14 @@ func BenchmarkExpectUnmarshal(b *testing.B) { } } + // Marshal these netfilter attributes and return netlink.AttributeDecoder. + ad := mustDecodeAttributes(tests) + for n := 0; n < b.N; n++ { - _ = ex.unmarshal(tests) + // Make a new copy of the AD to avoid reinstantiation. + iad := ad + + var ex Expect + _ = ex.unmarshal(iad) } } diff --git a/flow.go b/flow.go index 72c7fd1..07617f1 100644 --- a/flow.go +++ b/flow.go @@ -4,6 +4,7 @@ import ( "net" "github.com/mdlayher/netlink" + "github.com/pkg/errors" "github.com/ti-mo/netfilter" ) @@ -68,109 +69,124 @@ func NewFlow(proto uint8, status StatusFlag, srcAddr, destAddr net.IP, srcPort, } // unmarshal unmarshals a list of netfilter.Attributes into a Flow structure. -func (f *Flow) unmarshal(attrs []netfilter.Attribute) error { +func (f *Flow) unmarshal(ad *netlink.AttributeDecoder) error { - for _, attr := range attrs { + var at attributeType - switch at := attributeType(attr.Type); at { + for ad.Next() { + at = attributeType(ad.Type()) + + switch at { // CTA_TIMEOUT is the time until the Conntrack entry is automatically destroyed. case ctaTimeout: - f.Timeout = attr.Uint32() + f.Timeout = ad.Uint32() // CTA_ID is the tuple hash value generated by the kernel. It can be relied on for flow identification. case ctaID: - f.ID = attr.Uint32() + f.ID = ad.Uint32() // CTA_USE is the flow's kernel-internal refcount. case ctaUse: - f.Use = attr.Uint32() + f.Use = ad.Uint32() // CTA_MARK is the connection's connmark case ctaMark: - f.Mark = attr.Uint32() + f.Mark = ad.Uint32() // CTA_ZONE describes the Conntrack zone the flow is placed in. This can be combined with a CTA_TUPLE_ZONE // to specify which zone an event originates from. case ctaZone: - f.Zone = attr.Uint16() + f.Zone = ad.Uint16() // CTA_LABELS is a binary bitfield attached to a connection that is sent in // events when changed, as well as in response to dump queries. case ctaLabels: - f.Labels = attr.Data + f.Labels = ad.Bytes() // CTA_LABELS_MASK is never sent by the kernel, but it can be used // in set / update queries to mask label operations on the kernel state table. // it needs to be exactly as wide as the CTA_LABELS field it intends to mask. case ctaLabelsMask: - f.LabelsMask = attr.Data + f.LabelsMask = ad.Bytes() + // CTA_STATUS is a bitfield of the state of the connection + // (eg. if packets are seen in both directions, etc.) + case ctaStatus: + f.Status.Value = StatusFlag(ad.Uint32()) // CTA_TUPLE_* attributes are nested and contain source and destination values for: // - the IPv4/IPv6 addresses involved // - ports used in the connection // - (optional) the Conntrack Zone of the originating/replying side of the flow case ctaTupleOrig: - if err := f.TupleOrig.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnTup) } + ad.Nested(f.TupleOrig.unmarshal) case ctaTupleReply: - if err := f.TupleReply.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnTup) } + ad.Nested(f.TupleReply.unmarshal) case ctaTupleMaster: - if err := f.TupleMaster.unmarshal(attr); err != nil { - return err - } - // CTA_STATUS is a bitfield of the state of the connection - // (eg. if packets are seen in both directions, etc.) - case ctaStatus: - if err := f.Status.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnTup) } + ad.Nested(f.TupleMaster.unmarshal) // CTA_PROTOINFO is sent for TCP, DCCP and SCTP protocols only. It conveys extra metadata // about the state flags seen on the wire. Update events are sent when these change. case ctaProtoInfo: - if err := f.ProtoInfo.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnProtoInfo) } + ad.Nested(f.ProtoInfo.unmarshal) case ctaHelp: - if err := f.Helper.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnHelper) } + ad.Nested(f.Helper.unmarshal) // CTA_COUNTERS_* attributes are nested and contain byte and packet counters for flows in either direction. case ctaCountersOrig: - if err := f.CountersOrig.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnCounter) } + ad.Nested(f.CountersOrig.unmarshal) case ctaCountersReply: - if err := f.CountersReply.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnCounter) } + f.CountersReply.Direction = true + ad.Nested(f.CountersReply.unmarshal) // CTA_SECCTX is the SELinux security context of a Conntrack entry. case ctaSecCtx: - if err := f.SecurityContext.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnSecurity) } + ad.Nested(f.SecurityContext.unmarshal) // CTA_TIMESTAMP is a nested attribute that describes the start and end timestamp of a flow. // It is sent by the kernel with dumps and DESTROY events. case ctaTimestamp: - if err := f.Timestamp.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnTimestamp) } + ad.Nested(f.Timestamp.unmarshal) // CTA_SEQADJ_* is generalized TCP window adjustment metadata. It is not (yet) emitted in Conntrack events. // The reason for its introduction is outlined in https://lwn.net/Articles/563151. // Patch set is at http://www.spinics.net/lists/netdev/msg245785.html. case ctaSeqAdjOrig: - if err := f.SeqAdjOrig.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnSeqAdj) } + ad.Nested(f.SeqAdjOrig.unmarshal) case ctaSeqAdjReply: - if err := f.SeqAdjReply.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnSeqAdj) } + f.SeqAdjReply.Direction = true + ad.Nested(f.SeqAdjReply.unmarshal) // CTA_SYNPROXY are the connection's SYN proxy parameters case ctaSynProxy: - if err := f.SynProxy.unmarshal(attr); err != nil { - return err + if !nestedFlag(ad.TypeFlags()) { + return errors.Wrap(errNotNested, opUnSynProxy) } + ad.Nested(f.SynProxy.unmarshal) } } - return nil + return ad.Err() } // marshal marshals a Flow object into a list of netfilter.Attributes. @@ -202,7 +218,9 @@ func (f Flow) marshal() ([]netfilter.Attribute, error) { // Optional attributes appended to the list when filled if f.Timeout != 0 { - attrs = append(attrs, num32{Value: f.Timeout}.marshal(ctaTimeout)) + a := netfilter.Attribute{Type: uint16(ctaTimeout)} + a.PutUint32(f.Timeout) + attrs = append(attrs, a) } if f.Status.Value != 0 { @@ -210,11 +228,15 @@ func (f Flow) marshal() ([]netfilter.Attribute, error) { } if f.Mark != 0 { - attrs = append(attrs, num32{Value: f.Mark}.marshal(ctaMark)) + a := netfilter.Attribute{Type: uint16(ctaMark)} + a.PutUint32(f.Mark) + attrs = append(attrs, a) } if f.Zone != 0 { - attrs = append(attrs, num16{Value: f.Zone}.marshal(ctaZone)) + a := netfilter.Attribute{Type: uint16(ctaZone)} + a.PutUint16(f.Zone) + attrs = append(attrs, a) } if f.ProtoInfo.filled() { @@ -254,12 +276,12 @@ func unmarshalFlow(nlm netlink.Message) (Flow, error) { var f Flow - _, nfa, err := netfilter.UnmarshalNetlink(nlm) + _, ad, err := netfilter.DecodeNetlink(nlm) if err != nil { return f, err } - err = f.unmarshal(nfa) + err = f.unmarshal(ad) if err != nil { return f, err } diff --git a/flow_test.go b/flow_test.go index 9feb275..62f6422 100644 --- a/flow_test.go +++ b/flow_test.go @@ -365,11 +365,6 @@ var ( nfa: netfilter.Attribute{Type: uint16(ctaTupleMaster)}, errStr: "Tuple unmarshal: need a Nested attribute to decode this structure", }, - { - name: "error unmarshal status", - nfa: netfilter.Attribute{Type: uint16(ctaStatus), Nested: true}, - errStr: "Status unmarshal: unexpected Nested attribute", - }, { name: "error unmarshal protoinfo", nfa: netfilter.Attribute{Type: uint16(ctaProtoInfo)}, @@ -422,14 +417,16 @@ func TestFlowUnmarshal(t *testing.T) { for _, tt := range corpusFlow { t.Run(tt.name, func(t *testing.T) { var f Flow - err := f.unmarshal(tt.attrs) + err := f.unmarshal(mustDecodeAttributes(tt.attrs)) - if err != nil || tt.err != nil { + if tt.err != nil { require.Error(t, err) - require.EqualError(t, tt.err, err.Error()) + require.EqualError(t, err, tt.err.Error()) return } + require.NoError(t, err) + if diff := cmp.Diff(tt.flow, f); diff != "" { t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) } @@ -439,7 +436,7 @@ func TestFlowUnmarshal(t *testing.T) { for _, tt := range corpusFlowUnmarshalError { t.Run(tt.name, func(t *testing.T) { var f Flow - assert.EqualError(t, f.unmarshal([]netfilter.Attribute{tt.nfa}), tt.errStr) + assert.EqualError(t, f.unmarshal(mustDecodeAttributes([]netfilter.Attribute{tt.nfa})), tt.errStr) }) } } @@ -480,7 +477,7 @@ func TestFlowMarshal(t *testing.T) { func TestUnmarshalFlowsError(t *testing.T) { _, err := unmarshalFlows([]netlink.Message{{}}) - assert.EqualError(t, err, "expected at least 4 bytes in netlink message payload") + assert.EqualError(t, err, "unmarshaling netfilter header: expected at least 4 bytes in netlink message payload") // Use netfilter.MarshalNetlink to assemble a Netlink message with a single attribute with empty data. // Cause a random error in unmarshalFlows to cover error return. @@ -544,8 +541,14 @@ func BenchmarkFlowUnmarshal(b *testing.B) { } } + // Marshal these netfilter attributes and return netlink.AttributeDecoder. + ad := mustDecodeAttributes(tests) + for n := 0; n < b.N; n++ { + // Make a new copy of the AD to avoid reinstantiation. + iad := ad + var f Flow - _ = f.unmarshal(tests) + _ = f.unmarshal(iad) } } diff --git a/go.mod b/go.mod index f6c83eb..60288e3 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/ti-mo/conntrack go 1.12 require ( - github.com/google/go-cmp v0.2.0 - github.com/mdlayher/netlink v0.0.0-20190313131330-258ea9dff42c + github.com/google/go-cmp v0.3.1 + github.com/mdlayher/netlink v1.0.1-0.20191210152442-a1644773bc99 github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.4.0 - github.com/ti-mo/netfilter v0.2.0 + github.com/ti-mo/netfilter v0.3.0 github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc - golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc + golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 ) diff --git a/go.sum b/go.sum index 4b333cc..4b47d25 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,13 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/mdlayher/netlink v0.0.0-20190313131330-258ea9dff42c h1:qYXI+3AN4zBWsTF5drEu1akWPu2juaXPs58tZ4/GaCg= -github.com/mdlayher/netlink v0.0.0-20190313131330-258ea9dff42c/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a h1:84IpUNXj4mCR9CuCEvSiCArMbzr/TMbuPIadKDwypkI= +github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw= +github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= +github.com/mdlayher/netlink v1.0.1-0.20191210152442-a1644773bc99 h1:j14xqbiblLsxSSBc6uvABovvqAIr8mHnwaXCKRAtlkk= +github.com/mdlayher/netlink v1.0.1-0.20191210152442-a1644773bc99/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -13,18 +18,25 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/ti-mo/netfilter v0.2.0 h1:mMZ70vvHTlY9y8ElWflp5nVN5kkUDvm6D1JXRgartKI= -github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU= +github.com/ti-mo/netfilter v0.3.0 h1:T+KLhuAYx6u7p/aCO/hCti1IwILERru8mXLMWKNocA4= +github.com/ti-mo/netfilter v0.3.0/go.mod h1:jkASCo4ZNGAEBjy0giTFiPV59LT5A9OXMtkHpHI8xpw= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8= +golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc h1:4gbWbmmPFp4ySWICouJl6emP0MyS31yy9SrTlAGFT+g= -golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 h1:gSbV7h1NRL2G1xTg/owz62CST1oJBmxy4QpMMregXVQ= +golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/stats_test.go b/stats_test.go index 503f619..2aea487 100644 --- a/stats_test.go +++ b/stats_test.go @@ -78,7 +78,7 @@ func TestStatsUnmarshal(t *testing.T) { func TestUnmarshalStatsError(t *testing.T) { _, err := unmarshalStats([]netlink.Message{{}}) - assert.EqualError(t, err, "expected at least 4 bytes in netlink message payload") + assert.EqualError(t, err, "unmarshaling netfilter header: expected at least 4 bytes in netlink message payload") } func TestStatsExpectUnmarshal(t *testing.T) { @@ -115,7 +115,7 @@ func TestStatsExpectUnmarshal(t *testing.T) { func TestUnmarshalStatsExpectError(t *testing.T) { _, err := unmarshalStatsExpect([]netlink.Message{{}}) - assert.EqualError(t, err, "expected at least 4 bytes in netlink message payload") + assert.EqualError(t, err, "unmarshaling netfilter header: expected at least 4 bytes in netlink message payload") } func TestStatsGlobalUnmarshal(t *testing.T) { @@ -147,5 +147,5 @@ func TestStatsGlobalUnmarshal(t *testing.T) { func TestUnmarshalStatsGlobalError(t *testing.T) { _, err := unmarshalStatsGlobal(netlink.Message{}) - assert.EqualError(t, err, "expected at least 4 bytes in netlink message payload") + assert.EqualError(t, err, "unmarshaling netfilter header: expected at least 4 bytes in netlink message payload") } diff --git a/status.go b/status.go index 6dbd903..6c14548 100644 --- a/status.go +++ b/status.go @@ -1,8 +1,7 @@ package conntrack import ( - "fmt" - + "github.com/mdlayher/netlink" "github.com/pkg/errors" "github.com/ti-mo/netfilter" ) @@ -17,23 +16,23 @@ type Status struct { } // unmarshal unmarshals a netfilter.Attribute into a Status structure. -func (s *Status) unmarshal(attr netfilter.Attribute) error { +func (s *Status) unmarshal(ad *netlink.AttributeDecoder) error { - if attributeType(attr.Type) != ctaStatus { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaStatus) + if ad.Len() != 1 { + return errors.Wrap(errNeedSingleChild, opUnStatus) } - if attr.Nested { - return errors.Wrap(errNested, opUnStatus) + if !ad.Next() { + return ad.Err() } - if len(attr.Data) != 4 { + if len(ad.Bytes()) != 4 { return errors.Wrap(errIncorrectSize, opUnStatus) } - s.Value = StatusFlag(attr.Uint32()) + s.Value = StatusFlag(ad.Uint32()) - return nil + return ad.Err() } // marshal marshals a Status into a netfilter.Attribute. diff --git a/status_test.go b/status_test.go index 103b037..08b70b6 100644 --- a/status_test.go +++ b/status_test.go @@ -1,24 +1,31 @@ package conntrack import ( - "fmt" "testing" "github.com/google/go-cmp/cmp" + "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nlenc" + "github.com/mdlayher/netlink/nltest" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ti-mo/netfilter" ) -func TestStatusError(t *testing.T) { +var nfaUnspecU16 = netfilter.Attribute{Type: uint16(ctaUnspec), Data: []byte{0, 0}} - nfaNested := netfilter.Attribute{Type: uint16(ctaStatus), Nested: true} +func TestStatusError(t *testing.T) { var s Status - assert.EqualError(t, s.unmarshal(nfaBadType), fmt.Sprintf(errAttributeWrongType, ctaUnspec, ctaStatus)) - assert.EqualError(t, s.unmarshal(nfaNested), errors.Wrap(errNested, opUnStatus).Error()) + assert.EqualError(t, s.unmarshal(adEmpty), errors.Wrap(errNeedSingleChild, opUnStatus).Error()) + assert.EqualError(t, s.unmarshal(mustDecodeAttribute(nfaUnspecU16)), errors.Wrap(errIncorrectSize, opUnStatus).Error()) + + // Exhaust the AttributeDecoder before passing to unmarshal. + ad := mustDecodeAttribute(nfaUnspecU16) + ad.Next() + assert.NoError(t, s.unmarshal(ad)) } func TestStatusMarshalTwoWay(t *testing.T) { @@ -63,7 +70,7 @@ func TestStatusMarshalTwoWay(t *testing.T) { var s Status - err := s.unmarshal(nfa) + err := s.unmarshal(mustDecodeAttribute(nfa)) if err != nil || tt.err != nil { require.Error(t, err) require.EqualError(t, tt.err, err.Error()) @@ -147,18 +154,25 @@ func TestStatusString(t *testing.T) { } func BenchmarkStatusUnmarshalAttribute(b *testing.B) { - inputs := [][]byte{ - {0x00, 0x00, 0x00, 0x01}, {0x00, 0x00, 0x00, 0x02}, {0x00, 0x00, 0x00, 0x03}, {0x00, 0x00, 0x00, 0x04}, - {0x00, 0x00, 0x00, 0x05}, {0x00, 0x00, 0x00, 0x06}, {0x00, 0x00, 0x00, 0x07}, {0x00, 0x00, 0x00, 0x08}, + + var ads []netlink.AttributeDecoder + for i := 1; i <= 8; i++ { + nla := netlink.Attribute{Data: nlenc.Uint32Bytes(uint32(i))} + ad, err := netfilter.NewAttributeDecoder(nltest.MustMarshalAttributes([]netlink.Attribute{nla})) + if err != nil { + b.Error(err) + } + ads = append(ads, *ad) } var ss Status - var nfa netfilter.Attribute - nfa.Type = uint16(ctaStatus) + var ad netlink.AttributeDecoder + adl := len(ads) for n := 0; n < b.N; n++ { - nfa.Data = inputs[n%len(inputs)] - if err := ss.unmarshal(nfa); err != nil { + // Make a fresh copy of the AttributeDecoder. + ad = ads[n%adl] + if err := ss.unmarshal(&ad); err != nil { b.Fatal(err) } } diff --git a/string_test.go b/string_test.go index 97db0a2..93e255f 100644 --- a/string_test.go +++ b/string_test.go @@ -48,18 +48,20 @@ func TestEventString(t *testing.T) { ef.Flow.CountersOrig.Bytes = 42 ef.Flow.CountersOrig.Packets = 1 + ef.Flow.CountersReply.Direction = true + ef.Flow.Labels = []byte{0xf0, 0xf0} ef.Flow.LabelsMask = []byte{0xff, 0xff} ef.Flow.Mark = 0xf000baaa ef.Flow.SeqAdjOrig = SequenceAdjust{OffsetBefore: 80, OffsetAfter: 747811, Position: 42} - ef.Flow.SeqAdjReply = SequenceAdjust{OffsetBefore: 123, OffsetAfter: 456, Position: 889999} + ef.Flow.SeqAdjReply = SequenceAdjust{Direction: true, OffsetBefore: 123, OffsetAfter: 456, Position: 889999} ef.Flow.SecurityContext = "selinux_t" assert.Equal(t, - "[EventUnknown] (Unreplied) Timeout: 0, <0, Src: 1.2.3.4:54321, Dst: [fe80::1]:80>, Zone 0, Acct: [orig: 1 pkts/42 B] [orig: 0 pkts/0 B], Label: <0xf0f0/0xffff>, Mark: <0xf000baaa>, SeqAdjOrig: [dir: orig, pos: 42, before: 80, after: 747811], SeqAdjReply: [dir: orig, pos: 889999, before: 123, after: 456], SecCtx: selinux_t", + "[EventUnknown] (Unreplied) Timeout: 0, <0, Src: 1.2.3.4:54321, Dst: [fe80::1]:80>, Zone 0, Acct: [orig: 1 pkts/42 B] [reply: 0 pkts/0 B], Label: <0xf0f0/0xffff>, Mark: <0xf000baaa>, SeqAdjOrig: [dir: orig, pos: 42, before: 80, after: 747811], SeqAdjReply: [dir: reply, pos: 889999, before: 123, after: 456], SecCtx: selinux_t", ef.String()) // Event with Expect diff --git a/tuple.go b/tuple.go index fe8398b..8889eb3 100644 --- a/tuple.go +++ b/tuple.go @@ -6,6 +6,7 @@ import ( "strconv" "syscall" + "github.com/mdlayher/netlink" "github.com/pkg/errors" "golang.org/x/sys/unix" @@ -41,41 +42,30 @@ func (t Tuple) String() string { } // unmarshal unmarshals a netfilter.Attribute into a Tuple. -func (t *Tuple) unmarshal(attr netfilter.Attribute) error { +func (t *Tuple) unmarshal(ad *netlink.AttributeDecoder) error { - if !attr.Nested { - return errors.Wrap(errNotNested, opUnTup) - } - - if len(attr.Children) < 2 { + if ad.Len() < 2 { return errors.Wrap(errNeedChildren, opUnTup) } - for _, iattr := range attr.Children { - switch tupleType(iattr.Type) { + for ad.Next() { + switch tupleType(ad.Type()) { case ctaTupleIP: var ti IPTuple - if err := ti.unmarshal(iattr); err != nil { - return err - } + ad.Nested(ti.unmarshal) t.IP = ti case ctaTupleProto: var tp ProtoTuple - if err := tp.unmarshal(iattr); err != nil { - return err - } + ad.Nested(tp.unmarshal) t.Proto = tp case ctaTupleZone: - if len(iattr.Data) != 2 { - return errIncorrectSize - } - t.Zone = iattr.Uint16() + t.Zone = ad.Uint16() default: - return errors.Wrap(fmt.Errorf(errAttributeChild, iattr.Type, attributeType(attr.Type)), opUnTup) + return errors.Wrap(fmt.Errorf(errAttributeChild, ad.Type()), opUnTup) } } - return nil + return ad.Err() } // marshal marshals a Tuple to a netfilter.Attribute. @@ -114,37 +104,31 @@ func (ipt IPTuple) filled() bool { // IPv4 addresses will be represented by a 4-byte net.IP, IPv6 addresses by 16-byte. // The net.IP object is created with the raw bytes, NOT with net.ParseIP(). // Use IP.Equal() to compare addresses in implementations and tests. -func (ipt *IPTuple) unmarshal(attr netfilter.Attribute) error { - - if tupleType(attr.Type) != ctaTupleIP { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaTupleIP) - } +func (ipt *IPTuple) unmarshal(ad *netlink.AttributeDecoder) error { - if !attr.Nested { - return errors.Wrap(errNotNested, opUnIPTup) - } - - if len(attr.Children) != 2 { + if ad.Len() != 2 { return errors.Wrap(errNeedChildren, opUnIPTup) } - for _, iattr := range attr.Children { + for ad.Next() { - if len(iattr.Data) != 4 && len(iattr.Data) != 16 { + b := ad.Bytes() + + if len(b) != 4 && len(b) != 16 { return errIncorrectSize } - switch ipTupleType(iattr.Type) { + switch ipTupleType(ad.Type()) { case ctaIPv4Src: - ipt.SourceAddress = net.IPv4(iattr.Data[0], iattr.Data[1], iattr.Data[2], iattr.Data[3]) + ipt.SourceAddress = net.IPv4(b[0], b[1], b[2], b[3]) case ctaIPv6Src: - ipt.SourceAddress = net.IP(iattr.Data) + ipt.SourceAddress = net.IP(b) case ctaIPv4Dst: - ipt.DestinationAddress = net.IPv4(iattr.Data[0], iattr.Data[1], iattr.Data[2], iattr.Data[3]) + ipt.DestinationAddress = net.IPv4(b[0], b[1], b[2], b[3]) case ctaIPv6Dst: - ipt.DestinationAddress = net.IP(iattr.Data) + ipt.DestinationAddress = net.IP(b) default: - return errors.Wrap(fmt.Errorf(errAttributeChild, iattr.Type, ctaTupleIP), opUnIPTup) + return errors.Wrap(fmt.Errorf(errAttributeChild, ad.Type()), opUnIPTup) } } @@ -208,24 +192,16 @@ func (pt ProtoTuple) filled() bool { } // unmarshal unmarshals a netfilter.Attribute into a ProtoTuple. -func (pt *ProtoTuple) unmarshal(attr netfilter.Attribute) error { - - if tupleType(attr.Type) != ctaTupleProto { - return fmt.Errorf(errAttributeWrongType, attr.Type, ctaTupleProto) - } - - if !attr.Nested { - return errors.Wrap(errNotNested, opUnPTup) - } +func (pt *ProtoTuple) unmarshal(ad *netlink.AttributeDecoder) error { - if len(attr.Children) == 0 { + if ad.Len() == 0 { return errors.Wrap(errNeedSingleChild, opUnPTup) } - for _, iattr := range attr.Children { - switch protoTupleType(iattr.Type) { + for ad.Next() { + switch protoTupleType(ad.Type()) { case ctaProtoNum: - pt.Protocol = iattr.Data[0] + pt.Protocol = ad.Uint8() if pt.Protocol == syscall.IPPROTO_ICMP { pt.ICMPv4 = true @@ -233,17 +209,17 @@ func (pt *ProtoTuple) unmarshal(attr netfilter.Attribute) error { pt.ICMPv6 = true } case ctaProtoSrcPort: - pt.SourcePort = iattr.Uint16() + pt.SourcePort = ad.Uint16() case ctaProtoDstPort: - pt.DestinationPort = iattr.Uint16() + pt.DestinationPort = ad.Uint16() case ctaProtoICMPID, ctaProtoICMPv6ID: - pt.ICMPID = iattr.Uint16() + pt.ICMPID = ad.Uint16() case ctaProtoICMPType, ctaProtoICMPv6Type: - pt.ICMPType = iattr.Data[0] + pt.ICMPType = ad.Uint8() case ctaProtoICMPCode, ctaProtoICMPv6Code: - pt.ICMPCode = iattr.Data[0] + pt.ICMPCode = ad.Uint8() default: - return errors.Wrap(fmt.Errorf(errAttributeChild, iattr.Type, ctaTupleProto), opUnPTup) + return errors.Wrap(fmt.Errorf(errAttributeChild, ad.Type()), opUnPTup) } } diff --git a/tuple_test.go b/tuple_test.go index 24d8027..b1d8234 100644 --- a/tuple_test.go +++ b/tuple_test.go @@ -18,13 +18,11 @@ import ( var ( // Template attribute with Nested disabled attrDefault = netfilter.Attribute{Nested: false} - // Attribute with random, unused type 65535 - attrUnknown = netfilter.Attribute{Type: 0xFFFF} + // Attribute with random, unused type 16383 + attrUnknown = netfilter.Attribute{Type: 0x3FFF} // Nested structure of attributes with random, unused type 65535 attrTupleUnknownNested = netfilter.Attribute{Type: uint16(ctaTupleOrig), Nested: true, Children: []netfilter.Attribute{attrUnknown, attrUnknown}} - // Tuple attribute without Nested flag - attrTupleNotNested = netfilter.Attribute{Type: uint16(ctaTupleOrig)} // Tuple attribute with Nested flag attrTupleNestedOneChild = netfilter.Attribute{Type: uint16(ctaTupleOrig), Nested: true, Children: []netfilter.Attribute{attrDefault}} ) @@ -38,7 +36,7 @@ var ipTupleTests = []struct { { name: "correct ipv4 tuple", nfa: netfilter.Attribute{ - Type: 0x1, + Type: uint16(ctaTupleIP), Nested: true, Children: []netfilter.Attribute{ { @@ -61,7 +59,7 @@ var ipTupleTests = []struct { { name: "correct ipv6 tuple", nfa: netfilter.Attribute{ - Type: 0x1, + Type: uint16(ctaTupleIP), Nested: true, Children: []netfilter.Attribute{ { @@ -87,18 +85,10 @@ var ipTupleTests = []struct { DestinationAddress: net.ParseIP("4:4:3:3:2:2:1:1"), }, }, - { - name: "error nested flag not set on attribute", - nfa: netfilter.Attribute{ - Type: 0x1, - Nested: false, - }, - err: errors.Wrap(errNotNested, opUnIPTup), - }, { name: "error incorrect amount of children", nfa: netfilter.Attribute{ - Type: 0x1, + Type: uint16(ctaTupleIP), Nested: true, Children: []netfilter.Attribute{attrDefault}, }, @@ -107,7 +97,7 @@ var ipTupleTests = []struct { { name: "error child incorrect length", nfa: netfilter.Attribute{ - Type: 0x1, + Type: uint16(ctaTupleIP), Nested: true, Children: []netfilter.Attribute{ { @@ -120,21 +110,15 @@ var ipTupleTests = []struct { }, err: errIncorrectSize, }, - { - name: "error iptuple unmarshal with wrong type", - nfa: attrUnknown, - err: fmt.Errorf(errAttributeWrongType, attrUnknown.Type, ctaTupleIP), - }, { name: "error iptuple unmarshal with unknown IPTupleType", nfa: netfilter.Attribute{ - // CTA_TUPLE_IP - Type: 0x1, + Type: uint16(ctaTupleIP), Nested: true, Children: []netfilter.Attribute{ { // Unknown type - Type: 0xFFFF, + Type: 0x3FFF, // Correct IP address length Data: []byte{0, 0, 0, 0}, }, @@ -142,7 +126,7 @@ var ipTupleTests = []struct { attrDefault, }, }, - err: errors.Wrap(fmt.Errorf(errAttributeChild, 0xFFFF, ctaTupleIP), opUnIPTup), + err: errors.Wrap(fmt.Errorf(errAttributeChild, 0x3FFF), opUnIPTup), }, } @@ -153,13 +137,16 @@ func TestIPTupleMarshalTwoWay(t *testing.T) { var ipt IPTuple - err := ipt.unmarshal(tt.nfa) - if err != nil || tt.err != nil { + err := ipt.unmarshal(mustDecodeAttributes(tt.nfa.Children)) + + if tt.err != nil { require.Error(t, err) - require.EqualError(t, tt.err, err.Error()) + require.EqualError(t, err, tt.err.Error()) return } + require.NoError(t, err) + if diff := cmp.Diff(tt.cta, ipt); diff != "" { t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) } @@ -193,8 +180,12 @@ var protoTupleTests = []struct { }{ { name: "error unmarshal with wrong type", - nfa: attrUnknown, - err: fmt.Errorf(errAttributeWrongType, attrUnknown.Type, ctaTupleProto), + nfa: netfilter.Attribute{ + Type: uint16(ctaTupleProto), + Nested: true, + Children: []netfilter.Attribute{attrUnknown}, + }, + err: errors.Wrap(fmt.Errorf(errAttributeChild, attrUnknown.Type), opUnPTup), }, { name: "error unmarshal with incorrect amount of children", @@ -215,7 +206,7 @@ var protoTupleTests = []struct { attrDefault, }, }, - err: errors.Wrap(fmt.Errorf(errAttributeChild, attrUnknown.Type, ctaTupleProto), opUnPTup), + err: errors.Wrap(fmt.Errorf(errAttributeChild, attrUnknown.Type), opUnPTup), }, { name: "correct icmpv4 prototuple", @@ -290,13 +281,16 @@ func TestProtoTupleMarshalTwoWay(t *testing.T) { var pt ProtoTuple - err := pt.unmarshal(tt.nfa) - if err != nil || tt.err != nil { + err := pt.unmarshal(mustDecodeAttributes(tt.nfa.Children)) + + if tt.err != nil { require.Error(t, err) - require.EqualError(t, tt.err, err.Error()) + require.EqualError(t, err, tt.err.Error()) return } + require.NoError(t, err) + if diff := cmp.Diff(tt.cta, pt); diff != "" { t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) } @@ -400,46 +394,7 @@ var tupleTests = []struct { attrDefault, }, }, - err: errIncorrectSize, - }, - { - name: "error returned from iptuple unmarshal", - nfa: netfilter.Attribute{ - // CTA_TUPLE_ORIG - Type: 0x1, - Nested: true, - Children: []netfilter.Attribute{ - { - // CTA_TUPLE_IP - Type: 0x1, - }, - // Padding element - attrDefault, - }, - }, - err: errors.Wrap(errNotNested, opUnIPTup), - }, - { - name: "error returned from prototuple unmarshal", - nfa: netfilter.Attribute{ - // CTA_TUPLE_ORIG - Type: 0x1, - Nested: true, - Children: []netfilter.Attribute{ - { - // CTA_TUPLE_PROTO - Type: 0x2, - }, - // Padding element - attrDefault, - }, - }, - err: errors.Wrap(errNotNested, opUnPTup), - }, - { - name: "error nested flag not set on tuple", - nfa: attrTupleNotNested, - err: errors.Wrap(errNotNested, opUnTup), + err: errors.New("netlink: attribute 3 is not a uint16; length: 4"), }, { name: "error too few children", @@ -449,7 +404,7 @@ var tupleTests = []struct { { name: "error unknown nested tuple type", nfa: attrTupleUnknownNested, - err: errors.Wrap(fmt.Errorf(errAttributeChild, attrTupleUnknownNested.Children[0].Type, ctaTupleOrig), opUnTup), + err: errors.Wrap(fmt.Errorf(errAttributeChild, attrTupleUnknownNested.Children[0].Type), opUnTup), }, } @@ -460,13 +415,16 @@ func TestTupleMarshalTwoWay(t *testing.T) { var tpl Tuple - err := tpl.unmarshal(tt.nfa) - if err != nil || tt.err != nil { + err := tpl.unmarshal(mustDecodeAttributes(tt.nfa.Children)) + + if tt.err != nil { require.Error(t, err) - require.EqualError(t, tt.err, err.Error()) + require.EqualError(t, err, tt.err.Error()) return } + require.NoError(t, err) + if diff := cmp.Diff(tt.cta, tpl); diff != "" { t.Fatalf("unexpected unmarshal (-want +got):\n%s", diff) }