rtnetlink/rule_test.go
Florian Lehner a833fb5b68
add netlink/rule (#139)
* add netlink/rule

Signed-off-by: Florian Lehner <dev@der-flo.net>

* Add some fuzzing corpus

Signed-off-by: Jeroen Simonetti <jeroen@simonetti.nl>

Co-authored-by: Jeroen Simonetti <jeroen@simonetti.nl>
2022-04-12 09:00:30 +02:00

201 lines
5.6 KiB
Go

package rtnetlink
import (
"bytes"
"errors"
"net"
"reflect"
"testing"
)
// TestRuleMessage checks that RuleMessage values round-trip through
// MarshalBinary/UnmarshalBinary against hand-encoded fixtures, and exercises
// several decoder edge cases (short input, skipped and invalid attributes).
func TestRuleMessage(t *testing.T) {
	// The fixtures encode multi-byte values in little-endian order (e.g. the
	// table attribute value 29 is 0x1d, 0x00, 0x00, 0x00); skipBigEndian
	// presumably skips the test on hosts where that would not match.
	skipBigEndian(t)
	tests := map[string]struct {
		m            Message
		b            []byte
		marshalErr   error
		unmarshalErr error
	}{
		"empty": {
			m: &RuleMessage{},
			b: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
		},
		"no attributes": {
			m: &RuleMessage{
				Family:    1,
				DstLength: 2,
				SrcLength: 3,
				TOS:       4,
				Table:     5,
				Action:    6,
				Flags:     7,
			},
			b: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x00, 0x06, 0x07, 0x00, 0x00, 0x00},
		},
		"with attributes": {
			m: &RuleMessage{
				Family:    7,
				DstLength: 6,
				SrcLength: 5,
				TOS:       4,
				Table:     3,
				Action:    2,
				Flags:     1,
				Attributes: &RuleAttributes{
					Src:               netIPPtr(net.ParseIP("8.8.8.8")),
					Dst:               netIPPtr(net.ParseIP("1.1.1.1")),
					IIFName:           strPtr("eth0"),
					OIFName:           strPtr("br0"),
					Goto:              uint32Ptr(1),
					Priority:          uint32Ptr(2),
					FwMark:            uint32Ptr(3),
					FwMask:            uint32Ptr(5),
					L3MDev:            uint8Ptr(7),
					DstRealm:          uint16Ptr(11),
					SrcRealm:          uint16Ptr(13),
					TunID:             uint64Ptr(17),
					Protocol:          uint8Ptr(19),
					IPProto:           uint8Ptr(23),
					Table:             uint32Ptr(29),
					SuppressPrefixLen: uint32Ptr(31),
					SuppressIFGroup:   uint32Ptr(37),
					UIDRange: &RuleUIDRange{
						Start: 22,
						End:   25,
					},
					SPortRange: &RulePortRange{
						Start: 23,
						End:   26,
					},
					DPortRange: &RulePortRange{
						Start: 24,
						End:   27,
					},
				},
			},
			b: []byte{
				0x07, 0x06, 0x05, 0x04, 0x03, 0x00, 0x00, 0x02, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00,
				0x0f, 0x00, 0x1d, 0x00, 0x00, 0x00, 0x05, 0x00, 0x15, 0x00, 0x13, 0x00, 0x00, 0x00,
				0x08, 0x00, 0x02, 0x00, 0x08, 0x08, 0x08, 0x08, 0x08, 0x00, 0x01, 0x00, 0x01, 0x01,
				0x01, 0x01, 0x09, 0x00, 0x03, 0x00, 0x65, 0x74, 0x68, 0x30, 0x00, 0x00, 0x00, 0x00,
				0x08, 0x00, 0x11, 0x00, 0x62, 0x72, 0x30, 0x00, 0x08, 0x00, 0x04, 0x00, 0x01, 0x00,
				0x00, 0x00, 0x08, 0x00, 0x06, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0a, 0x00,
				0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x10, 0x00, 0x05, 0x00, 0x00, 0x00, 0x08, 0x00,
				0x0b, 0x00, 0x0b, 0x00, 0x0d, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x11, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x13, 0x00, 0x07, 0x00, 0x00, 0x00, 0x05, 0x00,
				0x16, 0x00, 0x17, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0d, 0x00, 0x25, 0x00, 0x00, 0x00,
				0x08, 0x00, 0x0e, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x08, 0x00, 0x14, 0x00, 0x16, 0x00,
				0x19, 0x00, 0x08, 0x00, 0x17, 0x00, 0x17, 0x00, 0x1a, 0x00, 0x08, 0x00, 0x18, 0x00,
				0x18, 0x00, 0x1b, 0x00,
			},
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			var b []byte
			// t.Run reports whether the subtest passed. Everything below
			// consumes the marshalled bytes, so stop here if marshalling
			// misbehaved, or if it failed as the table expected: there is
			// nothing meaningful left to compare or decode.
			marshalled := t.Run("marshal", func(t *testing.T) {
				var marshalErr error
				b, marshalErr = tt.m.MarshalBinary()
				if !errors.Is(marshalErr, tt.marshalErr) {
					t.Fatalf("Expected error '%v' but got '%v'", tt.marshalErr, marshalErr)
				}
			})
			if !marshalled || tt.marshalErr != nil {
				return
			}
			// A byte mismatch deliberately does not stop the round trip:
			// decoding the actual output still shows whether it is
			// self-consistent.
			t.Run("compare bytes", func(t *testing.T) {
				if want, got := tt.b, b; !bytes.Equal(want, got) {
					t.Fatalf("unexpected Message bytes:\n- want: [%# x]\n- got: [%# x]", want, got)
				}
			})
			m := &RuleMessage{}
			unmarshalled := t.Run("unmarshal", func(t *testing.T) {
				unmarshalErr := m.UnmarshalBinary(b)
				if !errors.Is(unmarshalErr, tt.unmarshalErr) {
					t.Fatalf("Expected error '%v' but got '%v'", tt.unmarshalErr, unmarshalErr)
				}
			})
			// A failed (or expectedly failing) decode leaves m in a state
			// that is not supposed to match tt.m.
			if !unmarshalled || tt.unmarshalErr != nil {
				return
			}
			t.Run("compare messages", func(t *testing.T) {
				if !reflect.DeepEqual(tt.m, m) {
					t.Fatalf("unexpected Message:\n- want: %#v\n- got: %#v", tt.m, m)
				}
			})
		})
	}
	t.Run("invalid length", func(t *testing.T) {
		// Four bytes are shorter than the 12-byte header used by the
		// "empty" fixture above.
		m := &RuleMessage{}
		unmarshalErr := m.UnmarshalBinary([]byte{0x00, 0x01, 0x02, 0x03})
		if !errors.Is(unmarshalErr, errInvalidRuleMessage) {
			t.Fatalf("Expected 'errInvalidRuleMessage' but got '%v'", unmarshalErr)
		}
	})
	t.Run("skipped attributes", func(t *testing.T) {
		// After the 12-byte header follow header-only (length 4)
		// attributes of types 0, 5, 7, 8, 9 and 18. They must be accepted
		// without populating any RuleAttributes field.
		m := &RuleMessage{}
		unmarshalErr := m.UnmarshalBinary([]byte{
			0x01, 0x00, 0x00, 0x02, 0x03, 0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00,
			0x00, 0x00, 0x04, 0x00, 0x05, 0x00, 0x04, 0x00, 0x07, 0x00, 0x04, 0x00, 0x08, 0x00,
			0x04, 0x00, 0x09, 0x00, 0x04, 0x00, 0x12, 0x00,
		})
		if unmarshalErr != nil {
			t.Fatalf("Expected no error but got '%v'", unmarshalErr)
		}
		expected := &RuleMessage{
			Family:     1,
			TOS:        2,
			Table:      3,
			Action:     4,
			Flags:      5,
			Attributes: &RuleAttributes{},
		}
		if !reflect.DeepEqual(expected, m) {
			t.Fatalf("unexpected Message:\n- want: %#v\n- got: %#v", expected, m)
		}
	})
	t.Run("invalid attribute", func(t *testing.T) {
		// A header-only attribute of type 0x2a is expected to be rejected.
		m := &RuleMessage{}
		unmarshalErr := m.UnmarshalBinary([]byte{
			0x01, 0x00, 0x00, 0x02, 0x03, 0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00,
			0x2a, 0x00,
		})
		if !errors.Is(unmarshalErr, errInvalidRuleAttribute) {
			t.Fatalf("Expected 'errInvalidRuleAttribute' error but got '%v'", unmarshalErr)
		}
	})
}
// uint64Ptr returns a pointer to a newly allocated uint64 initialised to v,
// for filling optional (pointer-typed) attribute fields in fixtures.
func uint64Ptr(v uint64) *uint64 {
	p := new(uint64)
	*p = v
	return p
}
// uint32Ptr returns a pointer to a newly allocated uint32 initialised to v,
// for filling optional (pointer-typed) attribute fields in fixtures.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}
// uint16Ptr returns a pointer to a newly allocated uint16 initialised to v,
// for filling optional (pointer-typed) attribute fields in fixtures.
func uint16Ptr(v uint16) *uint16 {
	p := new(uint16)
	*p = v
	return p
}
// uint8Ptr returns a pointer to a newly allocated uint8 initialised to v,
// for filling optional (pointer-typed) attribute fields in fixtures.
func uint8Ptr(v uint8) *uint8 {
	p := new(uint8)
	*p = v
	return p
}
// netIPPtr returns a pointer to v, shrinking IPv4 addresses to their 4-byte
// form first. net.ParseIP yields the 16-byte form even for IPv4, but the
// netlink fixtures carry IPv4 addresses as exactly 4 bytes. Other addresses
// (IPv6, nil) are pointed to unchanged.
func netIPPtr(v net.IP) *net.IP {
	ip := v
	if short := v.To4(); short != nil {
		ip = short
	}
	return &ip
}
// strPtr returns a pointer to a newly allocated string initialised to v,
// for filling optional (pointer-typed) attribute fields in fixtures.
func strPtr(v string) *string {
	s := new(string)
	*s = v
	return s
}