From 029370a6e6e22c85652a0d67cf00bf9747a1627f Mon Sep 17 00:00:00 2001 From: Jeroen Simonetti Date: Wed, 10 Apr 2019 09:47:48 +0200 Subject: [PATCH 1/2] Use upstreams goroutine safe execution This will change the way messages are send and received to a goroutine concurrent safe netlink.Execute call. Signed-off-by: Jeroen Simonetti --- conn.go | 107 ++++++++++++++++++++++++++++++++------------------- conn_test.go | 14 +++++-- 2 files changed, 78 insertions(+), 43 deletions(-) diff --git a/conn.go b/conn.go index 8402297..77c673b 100644 --- a/conn.go +++ b/conn.go @@ -27,6 +27,7 @@ type conn interface { Close() error Send(m netlink.Message) (netlink.Message, error) Receive() ([]netlink.Message, error) + Execute(m netlink.Message) ([]netlink.Message, error) } // Dial dials a route netlink connection. Config specifies optional @@ -38,11 +39,15 @@ func Dial(config *netlink.Config) (*Conn, error) { return nil, err } - return newConn(c), nil + return NewConn(c), nil } -// newConn is the internal constructor for Conn, used in tests. -func newConn(c conn) *Conn { +// NewConn creates a Conn that wraps an existing *netlink.Conn for +// generic netlink communications. +// +// NewConn is primarily useful for tests. Most applications should use +// Dial instead. +func NewConn(c conn) *Conn { rtc := &Conn{ c: c, } @@ -92,11 +97,66 @@ func (c *Conn) Receive() ([]Message, []netlink.Message, error) { return nil, nil, err } - return messageUnmarshall(msgs) + rtmsgs, err := unpackMessages(msgs) + if err != nil { + return nil, nil, err + } + + return rtmsgs, msgs, nil } -// messageUnmarshall will unmarshal the message based on its type -func messageUnmarshall(msgs []netlink.Message) ([]Message, []netlink.Message, error) { +// Execute sends a single Message to netlink using Send, receives one or more +// replies using Receive, and then checks the validity of the replies against +// the request using netlink.Validate. +// +// Execute acquires a lock for the duration of the function call which blocks +// concurrent calls to Send and Receive, in order to ensure consistency between +// generic netlink request/reply messages. +// +// See the documentation of Send, Receive, and netlink.Validate for details +// about each function. +func (c *Conn) Execute(m Message, family uint16, flags netlink.HeaderFlags) ([]Message, error) { + nm, err := packMessage(m, family, flags) + if err != nil { + return nil, err + } + + msgs, err := c.c.Execute(nm) + if err != nil { + return nil, err + } + + return unpackMessages(msgs) +} + +//Message is the interface used for passing around different kinds of rtnetlink messages +type Message interface { + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler + rtMessage() +} + +// packMessage packs a rtnetlink Message into a netlink.Message with the +// appropriate rtnetlink family and netlink flags. +func packMessage(m Message, family uint16, flags netlink.HeaderFlags) (netlink.Message, error) { + nm := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(family), + Flags: flags, + }, + } + + mb, err := m.MarshalBinary() + if err != nil { + return netlink.Message{}, err + } + nm.Data = mb + + return nm, nil +} + +// unpackMessages unpacks rtnetlink Messages from a slice of netlink.Messages. +func unpackMessages(msgs []netlink.Message) ([]Message, error) { lmsgs := make([]Message, 0, len(msgs)) for _, nm := range msgs { @@ -125,41 +185,10 @@ func messageUnmarshall(msgs []netlink.Message) ([]Message, []netlink.Message, er } if err := (m).UnmarshalBinary(nm.Data); err != nil { - return nil, nil, err + return nil, err } lmsgs = append(lmsgs, m) } - return lmsgs, msgs, nil -} - -// Execute sends a single Message to netlink using Conn.Send, receives one or -// more replies using Conn.Receive, and then checks the validity of the replies -// against the request using netlink.Validate. -// -// See the documentation of Conn.Send, Conn.Receive, and netlink.Validate for -// details about each function. -func (c *Conn) Execute(m Message, family uint16, flags netlink.HeaderFlags) ([]Message, error) { - req, err := c.Send(m, family, flags) - if err != nil { - return nil, err - } - - msgs, replies, err := c.Receive() - if err != nil { - return nil, err - } - - if err := netlink.Validate(req, replies); err != nil { - return nil, err - } - - return msgs, nil -} - -//Message is the interface used for passing around different kinds of rtnetlink messages -type Message interface { - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler - rtMessage() + return lmsgs, nil } diff --git a/conn_test.go b/conn_test.go index 35984c9..6afda46 100644 --- a/conn_test.go +++ b/conn_test.go @@ -171,7 +171,7 @@ func TestConnReceive(t *testing.T) { func testConn(t *testing.T) (*Conn, *testNetlinkConn) { c := &testNetlinkConn{} - return newConn(c), c + return NewConn(c), c } type testNetlinkConn struct { @@ -190,11 +190,17 @@ func (c *testNetlinkConn) Receive() ([]netlink.Message, error) { return c.receive, nil } +func (c *testNetlinkConn) Execute(m netlink.Message) ([]netlink.Message, error) { + c.send = m + return c.receive, nil +} + type noopConn struct{} -func (c *noopConn) Close() error { return nil } -func (c *noopConn) Send(m netlink.Message) (netlink.Message, error) { return netlink.Message{}, nil } -func (c *noopConn) Receive() ([]netlink.Message, error) { return nil, nil } +func (c *noopConn) Close() error { return nil } +func (c *noopConn) Send(_ netlink.Message) (netlink.Message, error) { return netlink.Message{}, nil } +func (c *noopConn) Receive() ([]netlink.Message, error) { return nil, nil } +func (c *noopConn) Execute(m netlink.Message) ([]netlink.Message, error) { return nil, nil } func mustMarshal(m encoding.BinaryMarshaler) []byte { b, err := m.MarshalBinary() From 57c0b0b853cdd8ebf8df6aff69103acaeb0440d3 Mon Sep 17 00:00:00 2001 From: Jeroen Simonetti Date: Wed, 10 Apr 2019 13:04:55 +0200 Subject: [PATCH 2/2] Use Execute on all services --- address.go | 8 ++++---- route.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/address.go b/address.go index 02bbb61..b140919 100644 --- a/address.go +++ b/address.go @@ -110,8 +110,8 @@ const ( // New creates a new address using the AddressMessage information. func (a *AddressService) New(req *AddressMessage) error { - flags := netlink.Request - _, err := a.c.Send(req, RTM_NEWADDR, flags) + flags := netlink.Request | netlink.Create | netlink.Acknowledge | netlink.Excl + _, err := a.c.Execute(req, RTM_NEWADDR, flags) if err != nil { return err } @@ -128,8 +128,8 @@ func (a *AddressService) Delete(address net.IP, index uint32) error { }, } - flags := netlink.Request - _, err := a.c.Send(req, RTM_DELADDR, flags) + flags := netlink.Request | netlink.Acknowledge + _, err := a.c.Execute(req, RTM_DELADDR, flags) if err != nil { return err } diff --git a/route.go b/route.go index 0391f99..7377bcf 100644 --- a/route.go +++ b/route.go @@ -109,8 +109,8 @@ func (r *RouteService) Add(req *RouteMessage) error { // Delete existing route func (r *RouteService) Delete(req *RouteMessage) error { - flags := netlink.Request - _, err := r.c.Send(req, RTM_DELROUTE, flags) + flags := netlink.Request | netlink.Acknowledge + _, err := r.c.Execute(req, RTM_DELROUTE, flags) if err != nil { return err }