dhcp: implement a portable Conn.

Unlike the linux Conn, this one cannot share a port with another
process, but it should work across platforms.
This commit is contained in:
David Anderson 2016-05-14 23:05:23 -07:00
parent c5449b945c
commit 45dd5dfe4d
3 changed files with 171 additions and 27 deletions

122
dhcp/conn.go Normal file
View File

@ -0,0 +1,122 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dhcp
import (
"errors"
"net"
"time"
"golang.org/x/net/ipv4"
)
// defined as a var so tests can override it.
var dhcpClientPort = 68
var platformConn func(string) (Conn, error)
func NewConn(addr string) (Conn, error) {
if platformConn != nil {
c, err := platformConn(addr)
if err == nil {
return c, nil
}
}
// Always try falling back to the portable implementation
return newPortableConn(addr)
}
type portableConn struct {
conn *ipv4.PacketConn
}
func newPortableConn(addr string) (Conn, error) {
c, err := net.ListenPacket("udp4", addr)
if err != nil {
return nil, err
}
l := ipv4.NewPacketConn(c)
if err = l.SetControlMessage(ipv4.FlagInterface, true); err != nil {
l.Close()
return nil, err
}
return &portableConn{l}, nil
}
func (c *portableConn) Close() error {
return c.conn.Close()
}
func (c *portableConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *portableConn) RecvDHCP() (*Packet, *net.Interface, error) {
var buf [1500]byte
for {
n, cm, _, err := c.conn.ReadFrom(buf[:])
if err != nil {
return nil, nil, err
}
pkt, err := Unmarshal(buf[:n])
if err != nil {
continue
}
intf, err := net.InterfaceByIndex(cm.IfIndex)
if err != nil {
return nil, nil, err
}
// TODO: possibly more validation that the source lines up
// with what the packet.
return pkt, intf, nil
}
}
func (c *portableConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
b, err := pkt.Marshal()
if err != nil {
return err
}
switch pkt.TxType() {
case TxBroadcast, TxHardwareAddr:
cm := ipv4.ControlMessage{
IfIndex: intf.Index,
}
addr := net.UDPAddr{
IP: net.IPv4bcast,
Port: dhcpClientPort,
}
_, err = c.conn.WriteTo(b, &cm, &addr)
return err
case TxRelayAddr:
// Send to the server port, not the client port.
addr := net.UDPAddr{
IP: pkt.RelayAddr,
Port: 67,
}
_, err = c.conn.WriteTo(b, nil, &addr)
return err
case TxClientAddr:
addr := net.UDPAddr{
IP: pkt.ClientAddr,
Port: dhcpClientPort,
}
_, err = c.conn.WriteTo(b, nil, &addr)
return err
default:
return errors.New("unknown TX type for packet")
}
}

View File

@ -31,15 +31,16 @@ const (
udpProtocolNumber = 17
)
// defined as a var so tests can override it.
var dhcpClientPort = uint16(68)
type linuxConn struct {
port uint16
conn *ipv4.RawConn
}
func NewLinuxConn(addr string) (Conn, error) {
func init() {
platformConn = newLinuxConn
}
func newLinuxConn(addr string) (Conn, error) {
if addr == "" {
addr = ":67"
}
@ -116,7 +117,6 @@ func (c *linuxConn) RecvDHCP() (*Packet, *net.Interface, error) {
}
pkt, err := Unmarshal(p[8:])
if err != nil {
fmt.Println(err)
continue
}
intf, err := net.InterfaceByIndex(cm.IfIndex)
@ -139,7 +139,7 @@ func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
// src port
binary.BigEndian.PutUint16(raw[:2], c.port)
// dst port
binary.BigEndian.PutUint16(raw[2:4], dhcpClientPort)
binary.BigEndian.PutUint16(raw[2:4], uint16(dhcpClientPort))
// length
binary.BigEndian.PutUint16(raw[4:6], uint16(8+len(b)))
copy(raw[8:], b)

View File

@ -23,26 +23,8 @@ import (
"time"
)
func TestConnLinux(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("not supported on %s", runtime.GOOS)
}
if os.Getuid() != 0 {
t.Skipf("must be root on %s", runtime.GOOS)
}
// Use a listener to grab a free port, but we don't use it beyond
// that.
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
c, err := NewLinuxConn(l.LocalAddr().String())
if err != nil {
t.Fatalf("creating the linuxconn: %s", err)
}
s, err := net.Dial("udp4", l.LocalAddr().String())
func testConn(t *testing.T, c Conn, addr string) {
s, err := net.Dial("udp4", addr)
if err != nil {
t.Fatal(err)
}
@ -85,7 +67,7 @@ func TestConnLinux(t *testing.T) {
// Test writing
p.ClientAddr = net.IPv4(127, 0, 0, 1)
dhcpClientPort = uint16(s.LocalAddr().(*net.UDPAddr).Port)
dhcpClientPort = s.LocalAddr().(*net.UDPAddr).Port
bs, err = p.Marshal()
if err != nil {
t.Fatalf("marshaling packet: %s", err)
@ -129,3 +111,43 @@ func TestConnLinux(t *testing.T) {
t.Fatalf("DHCP packet not the same as when it was sent")
}
}
func TestPortableConn(t *testing.T) {
// Use a listener to grab a free port, but we don't use it beyond
// that.
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
addr := l.LocalAddr().String()
l.Close()
c, err := newPortableConn(addr)
if err != nil {
t.Fatalf("creating the conn: %s", err)
}
testConn(t, c, addr)
}
func TestLinuxConn(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("not supported on %s", runtime.GOOS)
}
if os.Getuid() != 0 {
t.Skipf("must be root on %s", runtime.GOOS)
}
// Use a listener to grab a free port, but we don't use it beyond
// that.
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
c, err := newLinuxConn(l.LocalAddr().String())
if err != nil {
t.Fatalf("creating the linuxconn: %s", err)
}
testConn(t, c, l.LocalAddr().String())
}