From 53e4bfd991fc2184f037e790c24aa0a7c6218c28 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Sun, 15 May 2016 14:56:09 -0700 Subject: [PATCH] dhcp4: refactor the low-level packet sending. This lets the implementations share the transmission strategy and decoding/sanity logic. --- dhcp4/conn.go | 181 +++++++++++++++++++++++++------------------- dhcp4/conn_linux.go | 71 ++++++----------- dhcp4/conn_test.go | 4 +- 3 files changed, 131 insertions(+), 125 deletions(-) diff --git a/dhcp4/conn.go b/dhcp4/conn.go index dfb6a72..b46b133 100644 --- a/dhcp4/conn.go +++ b/dhcp4/conn.go @@ -26,7 +26,7 @@ import ( // defined as a var so tests can override it. var ( dhcpClientPort = 68 - platformConn func(string) (Conn, error) + platformConn func(string) (conn, error) ) // txType describes how a Packet should be sent on the wire. @@ -51,44 +51,105 @@ const ( txHardwareAddr ) +type conn interface { + io.Closer + Recv([]byte) (b []byte, addr *net.UDPAddr, ifidx int, err error) + Send(b []byte, addr *net.UDPAddr, ifidx int) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error +} + // Conn is a DHCP-oriented packet socket. // // Multiple goroutines may invoke methods on a Conn simultaneously. -type Conn interface { - io.Closer - // RecvDHCP reads a Packet from the connection. It returns the - // packet and the interface it was received on, which may be nil - // if interface information cannot be obtained. - RecvDHCP() (pkt *Packet, intf *net.Interface, err error) - // SendDHCP sends pkt. The precise transmission mechanism depends - // on pkt.txType(). intf should be the net.Interface returned by - // RecvDHCP if responding to a DHCP client, or the interface for - // which configuration is desired if acting as a client. - SendDHCP(pkt *Packet, intf *net.Interface) error - // SetReadDeadline sets the deadline for future Read calls. - // If the deadline is reached, Read will fail with a timeout - // (see type Error) instead of blocking. - // A zero value for t means Read will not time out. - SetReadDeadline(t time.Time) error +type Conn struct { + conn conn } // NewConn creates a Conn bound to the given UDP ip:port. -func NewConn(addr string) (Conn, error) { +func NewConn(addr string) (*Conn, error) { if platformConn != nil { c, err := platformConn(addr) if err == nil { - return c, nil + return &Conn{c}, nil } } // Always try falling back to the portable implementation - return newPortableConn(addr) + c, err := newPortableConn(addr) + if err != nil { + return nil, err + } + return &Conn{c}, nil +} + +func (c *Conn) Close() error { + return c.conn.Close() +} + +func (c *Conn) RecvDHCP() (*Packet, *net.Interface, error) { + var buf [1500]byte + for { + b, _, ifidx, err := c.conn.Recv(buf[:]) + if err != nil { + return nil, nil, err + } + pkt, err := Unmarshal(b) + if err != nil { + continue + } + intf, err := net.InterfaceByIndex(ifidx) + if err != nil { + return nil, nil, err + } + // TODO: possibly more validation that the source lines up + // with what the packet says. + return pkt, intf, nil + } +} + +func (c *Conn) SendDHCP(pkt *Packet, intf *net.Interface) error { + b, err := pkt.Marshal() + if err != nil { + return err + } + + switch pkt.txType() { + case txBroadcast, txHardwareAddr: + addr := net.UDPAddr{ + IP: net.IPv4bcast, + Port: dhcpClientPort, + } + return c.conn.Send(b, &addr, intf.Index) + case txRelayAddr: + addr := net.UDPAddr{ + IP: pkt.RelayAddr, + Port: 67, + } + return c.conn.Send(b, &addr, 0) + case txClientAddr: + addr := net.UDPAddr{ + IP: pkt.ClientAddr, + Port: dhcpClientPort, + } + return c.conn.Send(b, &addr, 0) + default: + return errors.New("unknown TX type for packet") + } +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) } type portableConn struct { conn *ipv4.PacketConn } -func newPortableConn(addr string) (Conn, error) { +func newPortableConn(addr string) (conn, error) { c, err := net.ListenPacket("udp4", addr) if err != nil { return nil, err @@ -105,64 +166,30 @@ func (c *portableConn) Close() error { return c.conn.Close() } +func (c *portableConn) Recv(b []byte) (rb []byte, addr *net.UDPAddr, ifidx int, err error) { + n, cm, a, err := c.conn.ReadFrom(b) + if err != nil { + return nil, nil, 0, err + } + return b[:n], a.(*net.UDPAddr), cm.IfIndex, nil +} + +func (c *portableConn) Send(b []byte, addr *net.UDPAddr, ifidx int) error { + if ifidx > 0 { + _, err := c.conn.WriteTo(b, nil, addr) + return err + } + cm := ipv4.ControlMessage{ + IfIndex: ifidx, + } + _, err := c.conn.WriteTo(b, &cm, addr) + return err +} + func (c *portableConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *portableConn) RecvDHCP() (*Packet, *net.Interface, error) { - var buf [1500]byte - for { - n, cm, _, err := c.conn.ReadFrom(buf[:]) - if err != nil { - return nil, nil, err - } - pkt, err := Unmarshal(buf[:n]) - if err != nil { - continue - } - intf, err := net.InterfaceByIndex(cm.IfIndex) - if err != nil { - return nil, nil, err - } - // TODO: possibly more validation that the source lines up - // with what the packet. - return pkt, intf, nil - } -} - -func (c *portableConn) SendDHCP(pkt *Packet, intf *net.Interface) error { - b, err := pkt.Marshal() - if err != nil { - return err - } - - switch pkt.txType() { - case txBroadcast, txHardwareAddr: - cm := ipv4.ControlMessage{ - IfIndex: intf.Index, - } - addr := net.UDPAddr{ - IP: net.IPv4bcast, - Port: dhcpClientPort, - } - _, err = c.conn.WriteTo(b, &cm, &addr) - return err - case txRelayAddr: - // Send to the server port, not the client port. - addr := net.UDPAddr{ - IP: pkt.RelayAddr, - Port: 67, - } - _, err = c.conn.WriteTo(b, nil, &addr) - return err - case txClientAddr: - addr := net.UDPAddr{ - IP: pkt.ClientAddr, - Port: dhcpClientPort, - } - _, err = c.conn.WriteTo(b, nil, &addr) - return err - default: - return errors.New("unknown TX type for packet") - } +func (c *portableConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) } diff --git a/dhcp4/conn_linux.go b/dhcp4/conn_linux.go index 900f057..8537949 100644 --- a/dhcp4/conn_linux.go +++ b/dhcp4/conn_linux.go @@ -40,7 +40,7 @@ func init() { platformConn = newLinuxConn } -func newLinuxConn(addr string) (Conn, error) { +func newLinuxConn(addr string) (conn, error) { if addr == "" { addr = ":67" } @@ -101,45 +101,24 @@ func (c *linuxConn) Close() error { return c.conn.Close() } -func (c *linuxConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *linuxConn) RecvDHCP() (*Packet, *net.Interface, error) { - var buf [1500]byte - for { - _, p, cm, err := c.conn.ReadFrom(buf[:]) - if err != nil { - return nil, nil, err - } - if len(p) < 8 { - continue - } - pkt, err := Unmarshal(p[8:]) - if err != nil { - continue - } - intf, err := net.InterfaceByIndex(cm.IfIndex) - if err != nil { - return nil, nil, err - } - // TODO: possibly more validation that the IPv4 header lines - // up with what the packet. - return pkt, intf, nil - } -} - -func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error { - b, err := pkt.Marshal() +func (c *linuxConn) Recv(b []byte) (rb []byte, addr *net.UDPAddr, ifidx int, err error) { + hdr, p, cm, err := c.conn.ReadFrom(b) if err != nil { - return err + return nil, nil, 0, err } + if len(p) < 8 { + return nil, nil, 0, errors.New("not a UDP packet, too short") + } + sport := int(binary.BigEndian.Uint16(p[:2])) + return p[8:], &net.UDPAddr{IP: hdr.Src, Port: sport}, cm.IfIndex, nil +} +func (c *linuxConn) Send(b []byte, addr *net.UDPAddr, ifidx int) error { raw := make([]byte, 8+len(b)) // src port binary.BigEndian.PutUint16(raw[:2], c.port) // dst port - binary.BigEndian.PutUint16(raw[2:4], uint16(dhcpClientPort)) + binary.BigEndian.PutUint16(raw[2:4], uint16(addr.Port)) // length binary.BigEndian.PutUint16(raw[4:6], uint16(8+len(b))) copy(raw[8:], b) @@ -151,24 +130,22 @@ func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error { TotalLen: ipv4.HeaderLen + 8 + len(b), TTL: 64, Protocol: 17, + Dst: addr.IP, } - switch pkt.txType() { - case txBroadcast, txHardwareAddr: - hdr.Dst = net.IPv4bcast + if ifidx > 0 { cm := ipv4.ControlMessage{ - IfIndex: intf.Index, + IfIndex: ifidx, } return c.conn.WriteTo(&hdr, raw, &cm) - case txRelayAddr: - // Send to the server port, not the client port. - binary.BigEndian.PutUint16(raw[2:4], 67) - hdr.Dst = pkt.RelayAddr - return c.conn.WriteTo(&hdr, raw, nil) - case txClientAddr: - hdr.Dst = pkt.ClientAddr - return c.conn.WriteTo(&hdr, raw, nil) - default: - return errors.New("unknown TX type for packet") } + return c.conn.WriteTo(&hdr, raw, nil) +} + +func (c *linuxConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *linuxConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) } diff --git a/dhcp4/conn_test.go b/dhcp4/conn_test.go index c541757..087aeae 100644 --- a/dhcp4/conn_test.go +++ b/dhcp4/conn_test.go @@ -23,7 +23,9 @@ import ( "time" ) -func testConn(t *testing.T, c Conn, addr string) { +func testConn(t *testing.T, impl conn, addr string) { + c := &Conn{impl} + s, err := net.Dial("udp4", addr) if err != nil { t.Fatal(err)