diff --git a/dhcp/conn.go b/dhcp/conn.go new file mode 100644 index 0000000..dbdf5bf --- /dev/null +++ b/dhcp/conn.go @@ -0,0 +1,122 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dhcp + +import ( + "errors" + "net" + "time" + + "golang.org/x/net/ipv4" +) + +// defined as a var so tests can override it. +var dhcpClientPort = 68 + +var platformConn func(string) (Conn, error) + +func NewConn(addr string) (Conn, error) { + if platformConn != nil { + c, err := platformConn(addr) + if err == nil { + return c, nil + } + } + // Always try falling back to the portable implementation + return newPortableConn(addr) +} + +type portableConn struct { + conn *ipv4.PacketConn +} + +func newPortableConn(addr string) (Conn, error) { + c, err := net.ListenPacket("udp4", addr) + if err != nil { + return nil, err + } + l := ipv4.NewPacketConn(c) + if err = l.SetControlMessage(ipv4.FlagInterface, true); err != nil { + l.Close() + return nil, err + } + return &portableConn{l}, nil +} + +func (c *portableConn) Close() error { + return c.conn.Close() +} + +func (c *portableConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *portableConn) RecvDHCP() (*Packet, *net.Interface, error) { + var buf [1500]byte + for { + n, cm, _, err := c.conn.ReadFrom(buf[:]) + if err != nil { + return nil, nil, err + } + pkt, err := Unmarshal(buf[:n]) + if err != nil { + continue + } + intf, err := net.InterfaceByIndex(cm.IfIndex) + if err != nil { + return nil, nil, err + } + // TODO: possibly more validation that the source lines up + // with what the packet. + return pkt, intf, nil + } +} + +func (c *portableConn) SendDHCP(pkt *Packet, intf *net.Interface) error { + b, err := pkt.Marshal() + if err != nil { + return err + } + + switch pkt.TxType() { + case TxBroadcast, TxHardwareAddr: + cm := ipv4.ControlMessage{ + IfIndex: intf.Index, + } + addr := net.UDPAddr{ + IP: net.IPv4bcast, + Port: dhcpClientPort, + } + _, err = c.conn.WriteTo(b, &cm, &addr) + return err + case TxRelayAddr: + // Send to the server port, not the client port. + addr := net.UDPAddr{ + IP: pkt.RelayAddr, + Port: 67, + } + _, err = c.conn.WriteTo(b, nil, &addr) + return err + case TxClientAddr: + addr := net.UDPAddr{ + IP: pkt.ClientAddr, + Port: dhcpClientPort, + } + _, err = c.conn.WriteTo(b, nil, &addr) + return err + default: + return errors.New("unknown TX type for packet") + } +} diff --git a/dhcp/conn_linux.go b/dhcp/conn_linux.go index 635c260..8816ec7 100644 --- a/dhcp/conn_linux.go +++ b/dhcp/conn_linux.go @@ -31,15 +31,16 @@ const ( udpProtocolNumber = 17 ) -// defined as a var so tests can override it. -var dhcpClientPort = uint16(68) - type linuxConn struct { port uint16 conn *ipv4.RawConn } -func NewLinuxConn(addr string) (Conn, error) { +func init() { + platformConn = newLinuxConn +} + +func newLinuxConn(addr string) (Conn, error) { if addr == "" { addr = ":67" } @@ -116,7 +117,6 @@ func (c *linuxConn) RecvDHCP() (*Packet, *net.Interface, error) { } pkt, err := Unmarshal(p[8:]) if err != nil { - fmt.Println(err) continue } intf, err := net.InterfaceByIndex(cm.IfIndex) @@ -139,7 +139,7 @@ func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error { // src port binary.BigEndian.PutUint16(raw[:2], c.port) // dst port - binary.BigEndian.PutUint16(raw[2:4], dhcpClientPort) + binary.BigEndian.PutUint16(raw[2:4], uint16(dhcpClientPort)) // length binary.BigEndian.PutUint16(raw[4:6], uint16(8+len(b))) copy(raw[8:], b) diff --git a/dhcp/conn_linux_test.go b/dhcp/conn_test.go similarity index 81% rename from dhcp/conn_linux_test.go rename to dhcp/conn_test.go index d1026b7..c541757 100644 --- a/dhcp/conn_linux_test.go +++ b/dhcp/conn_test.go @@ -23,26 +23,8 @@ import ( "time" ) -func TestConnLinux(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("not supported on %s", runtime.GOOS) - } - if os.Getuid() != 0 { - t.Skipf("must be root on %s", runtime.GOOS) - } - - // Use a listener to grab a free port, but we don't use it beyond - // that. - l, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - c, err := NewLinuxConn(l.LocalAddr().String()) - if err != nil { - t.Fatalf("creating the linuxconn: %s", err) - } - - s, err := net.Dial("udp4", l.LocalAddr().String()) +func testConn(t *testing.T, c Conn, addr string) { + s, err := net.Dial("udp4", addr) if err != nil { t.Fatal(err) } @@ -85,7 +67,7 @@ func TestConnLinux(t *testing.T) { // Test writing p.ClientAddr = net.IPv4(127, 0, 0, 1) - dhcpClientPort = uint16(s.LocalAddr().(*net.UDPAddr).Port) + dhcpClientPort = s.LocalAddr().(*net.UDPAddr).Port bs, err = p.Marshal() if err != nil { t.Fatalf("marshaling packet: %s", err) @@ -129,3 +111,43 @@ func TestConnLinux(t *testing.T) { t.Fatalf("DHCP packet not the same as when it was sent") } } + +func TestPortableConn(t *testing.T) { + // Use a listener to grab a free port, but we don't use it beyond + // that. + l, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addr := l.LocalAddr().String() + l.Close() + + c, err := newPortableConn(addr) + if err != nil { + t.Fatalf("creating the conn: %s", err) + } + + testConn(t, c, addr) +} + +func TestLinuxConn(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skipf("not supported on %s", runtime.GOOS) + } + if os.Getuid() != 0 { + t.Skipf("must be root on %s", runtime.GOOS) + } + + // Use a listener to grab a free port, but we don't use it beyond + // that. + l, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + c, err := newLinuxConn(l.LocalAddr().String()) + if err != nil { + t.Fatalf("creating the linuxconn: %s", err) + } + + testConn(t, c, l.LocalAddr().String()) +}