diff --git a/route_test.go b/route_test.go index 2ed217e..5c2bca5 100644 --- a/route_test.go +++ b/route_test.go @@ -10,15 +10,14 @@ import ( // Tests will only pass on little endian machines -func TestRouteMessageMarshalBinary(t *testing.T) { +func TestRouteMessageMarshalUnmarshalBinary(t *testing.T) { skipBigEndian(t) timeout := uint32(255) tests := []struct { name string - m Message + m *RouteMessage b []byte - err error }{ { name: "empty", @@ -41,83 +40,31 @@ func TestRouteMessageMarshalBinary(t *testing.T) { }, }, { - name: "attributes", + name: "full", m: &RouteMessage{ + Family: 2, + DstLength: 8, + Table: unix.RT_TABLE_MAIN, + Protocol: unix.RTPROT_STATIC, + Scope: unix.RT_SCOPE_UNIVERSE, + Type: unix.RTN_UNICAST, Attributes: RouteAttributes{ - Dst: net.ParseIP("10.0.0.0"), - Gateway: net.ParseIP("10.10.10.10"), - OutIface: 4, + Dst: net.IPv4(10, 0, 0, 0), + Src: net.IPv4(10, 100, 10, 1), + Gateway: net.IPv4(10, 0, 0, 1), + OutIface: 5, + Priority: 1, + Table: 2, + Mark: 3, Expires: &timeout, Metrics: &RouteMetrics{ - MTU: 1500, + AdvMSS: 1, + Features: 0xffffffff, + InitCwnd: 2, + MTU: 1500, }, }, }, - b: []byte{ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x01, 0x00, - 0x0a, 0x00, 0x00, 0x00, 0x08, 0x00, 0x05, 0x00, - 0x0a, 0x0a, 0x0a, 0x0a, 0x08, 0x00, 0x04, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x17, 0x00, - 0xff, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x08, 0x80, - 0x08, 0x00, 0x02, 0x00, 0xdc, 0x05, 0x00, 0x00, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b, err := tt.m.MarshalBinary() - - if want, got := tt.err, err; want != got { - t.Fatalf("unexpected error:\n- want: %v\n- got: %v", want, got) - } - if err != nil { - return - } - - if diff := cmp.Diff(tt.b, b); diff != "" { - t.Fatalf("unexpected RouteMessage bytes (-want +got):\n%s", diff) - } - }) - } -} - -func TestRouteMessageUnmarshalBinary(t *testing.T) { - skipBigEndian(t) - - timeout := uint32(1000) - tests := []struct { - name string - b []byte - m Message - err error - }{ - { - name: "empty", - err: errInvalidRouteMessage, - }, - { - name: "short", - b: make([]byte, 3), - err: errInvalidRouteMessage, - }, - { - name: "invalid attr", - b: []byte{ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x01, 0x00, 0x04, 0x00, 0x02, 0x00, - 0x05, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - err: errInvalidRouteMessageAttr, - }, - { - name: "full", b: []byte{ // RouteMessage struct literal // @@ -165,7 +112,7 @@ func TestRouteMessageUnmarshalBinary(t *testing.T) { 0x03, 0x00, 0x00, 0x00, // Expires 0x08, 0x00, 0x17, 0x00, - 0xe8, 0x03, 0x00, 0x00, + 0xff, 0x00, 0x00, 0x00, // RouteMetrics // Length must be manually adjusted as more fields are added. 0x24, 0x00, 0x08, 0x80, @@ -182,48 +129,94 @@ func TestRouteMessageUnmarshalBinary(t *testing.T) { 0x08, 0x00, 0x02, 0x00, 0xdc, 0x05, 0x00, 0x00, }, - m: &RouteMessage{ - Family: 2, - DstLength: 8, - Table: unix.RT_TABLE_MAIN, - Protocol: unix.RTPROT_STATIC, - Scope: unix.RT_SCOPE_UNIVERSE, - Type: unix.RTN_UNICAST, - Attributes: RouteAttributes{ - Dst: net.IPv4(10, 0, 0, 0), - Src: net.IPv4(10, 100, 10, 1), - Gateway: net.IPv4(10, 0, 0, 1), - OutIface: 5, - Priority: 1, - Table: 2, - Mark: 3, - Expires: &timeout, - Metrics: &RouteMetrics{ - AdvMSS: 1, - Features: 0xffffffff, - InitCwnd: 2, - MTU: 1500, - }, - }, - }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := &RouteMessage{} - err := (m).UnmarshalBinary(tt.b) - - if want, got := tt.err, err; want != got { - t.Fatalf("unexpected error:\n- want: %v\n- got: %v", want, got) + // It's important to be able to parse raw bytes into valid + // structures so we start with that step first. After, we'll do a + // marshaling round-trip to ensure that the structure's byte output + // and parsed form match what is expected, while also comparing + // against the expected fixtures throughout. + var m1 RouteMessage + if err := m1.UnmarshalBinary(tt.b); err != nil { + t.Fatalf("failed to unmarshal first message from binary: %v", err) } + + if diff := cmp.Diff(tt.m, &m1); diff != "" { + t.Fatalf("unexpected first message (-want +got):\n%s", diff) + } + + b, err := m1.MarshalBinary() if err != nil { - return + t.Fatalf("failed to marshal first message binary: %v", err) } - if diff := cmp.Diff(tt.m, m); diff != "" { - t.Fatalf("unexpected RouteMessage (-want +got):\n%s", diff) + if diff := cmp.Diff(tt.b, b); diff != "" { + t.Fatalf("unexpected first message bytes (-want +got):\n%s", diff) + } + + var m2 RouteMessage + if err := m2.UnmarshalBinary(b); err != nil { + t.Fatalf("failed to unmarshal second message from binary: %v", err) + } + + if diff := cmp.Diff(&m1, &m2); diff != "" { + t.Fatalf("unexpected parsed messages (-want +got):\n%s", diff) } }) } } + +func TestRouteMessageUnmarshalBinaryErrors(t *testing.T) { + skipBigEndian(t) + + tests := []struct { + name string + b []byte + m Message + err error + }{ + { + name: "empty", + err: errInvalidRouteMessage, + }, + { + name: "short", + b: make([]byte, 3), + err: errInvalidRouteMessage, + }, + { + name: "invalid attr", + b: []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x01, 0x00, 0x04, 0x00, 0x02, 0x00, + 0x05, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + err: errInvalidRouteMessageAttr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m RouteMessage + err := m.UnmarshalBinary(tt.b) + + if diff := cmp.Diff(tt.err, err, cmp.Comparer(compareErrors)); diff != "" { + t.Fatalf("unexpected error (-want +got):\n%s", diff) + } + }) + } +} + +func compareErrors(x, y error) bool { + // This is lazy but should be sufficient for the typical stringified errors + // returned by this package. + return x.Error() == y.Error() +}