diff --git a/dhcp4/conn.go b/dhcp4/conn.go index a68e08a..134affc 100644 --- a/dhcp4/conn.go +++ b/dhcp4/conn.go @@ -16,6 +16,7 @@ package dhcp4 import ( "errors" + "fmt" "io" "net" "time" @@ -62,16 +63,66 @@ type conn interface { // // Multiple goroutines may invoke methods on a Conn simultaneously. type Conn struct { - conn conn + conn conn + ifIndex int } // NewConn creates a Conn bound to the given UDP ip:port. func NewConn(addr string) (*Conn, error) { - c, err := newPortableConn(addr) + return newConn(addr, newPortableConn) +} + +func newConn(addr string, n func(int) (conn, error)) (*Conn, error) { + if addr == "" { + addr = "0.0.0.0:67" + } + + ifIndex := 0 + udpAddr, err := net.ResolveUDPAddr("udp4", addr) if err != nil { return nil, err } - return &Conn{c}, nil + if !udpAddr.IP.To4().Equal(net.IPv4zero) { + // Caller wants to listen only on one address. However, DHCP + // packets are frequently broadcast, so we can't just listen + // on the given address. Instead, we need to translate it to + // an interface, and then filter incoming packets based on + // their received interface. + ifIndex, err = ipToIfindex(udpAddr.IP) + if err != nil { + return nil, err + } + } + + c, err := n(udpAddr.Port) + if err != nil { + return nil, err + } + return &Conn{ + conn: c, + ifIndex: ifIndex, + }, nil +} + +func ipToIfindex(ip net.IP) (int, error) { + intfs, err := net.Interfaces() + if err != nil { + return 0, err + } + for _, intf := range intfs { + addrs, err := intf.Addrs() + if err != nil { + return 0, err + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.Equal(ip) { + return intf.Index, nil + } + } + } + } + return 0, fmt.Errorf("IP %s not found on any local interface", ip) } // Close closes the DHCP socket. @@ -89,6 +140,9 @@ func (c *Conn) RecvDHCP() (*Packet, *net.Interface, error) { if err != nil { return nil, nil, err } + if c.ifIndex != 0 && ifidx != c.ifIndex { + continue + } pkt, err := Unmarshal(b) if err != nil { continue @@ -97,6 +151,7 @@ func (c *Conn) RecvDHCP() (*Packet, *net.Interface, error) { if err != nil { return nil, nil, err } + // TODO: possibly more validation that the source lines up // with what the packet says. return pkt, intf, nil @@ -157,11 +212,8 @@ type portableConn struct { conn *ipv4.PacketConn } -func newPortableConn(addr string) (conn, error) { - if addr == "" { - addr = ":67" - } - c, err := net.ListenPacket("udp4", addr) +func newPortableConn(port int) (conn, error) { + c, err := net.ListenPacket("udp4", fmt.Sprintf(":%d", port)) if err != nil { return nil, err } diff --git a/dhcp4/conn_linux.go b/dhcp4/conn_linux.go index 1ec97e2..094d508 100644 --- a/dhcp4/conn_linux.go +++ b/dhcp4/conn_linux.go @@ -41,27 +41,12 @@ type linuxConn struct { // Unlike NewConn, NewSnooperConn does not bind to the ip:port, // enabling the Conn to coexist with other services on the machine. func NewSnooperConn(addr string) (*Conn, error) { - c, err := newLinuxConn(addr) - if err != nil { - return nil, err - } - return &Conn{c}, nil + return newConn(addr, newLinuxConn) } -func newLinuxConn(addr string) (conn, error) { - if addr == "" { - addr = "0.0.0.0:67" - } - udpAddr, err := net.ResolveUDPAddr("udp4", addr) - if err != nil { - return nil, err - } - if udpAddr.IP != nil && udpAddr.IP.To4() == nil { - return nil, fmt.Errorf("%s is not an IPv4 address", addr) - } - udpAddr.IP = udpAddr.IP.To4() - if udpAddr.Port == 0 { - return nil, fmt.Errorf("%s must specify a listen port", addr) +func newLinuxConn(port int) (conn, error) { + if port == 0 { + return nil, errors.New("must specify a listen port") } filter, err := bpf.Assemble([]bpf.Instruction{ @@ -70,7 +55,7 @@ func newLinuxConn(addr string) (conn, error) { // Get UDP dport bpf.LoadIndirect{Off: 2, Size: 2}, // Correct dport? - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(udpAddr.Port), SkipFalse: 1}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(port), SkipFalse: 1}, // Accept bpf.RetConstant{Val: 1500}, // Ignore @@ -80,7 +65,7 @@ func newLinuxConn(addr string) (conn, error) { return nil, err } - c, err := net.ListenPacket("ip4:17", udpAddr.IP.String()) + c, err := net.ListenPacket("ip4:17", "0.0.0.0") if err != nil { return nil, err } @@ -99,7 +84,7 @@ func newLinuxConn(addr string) (conn, error) { } ret := &linuxConn{ - port: uint16(udpAddr.Port), + port: uint16(port), conn: r, } return ret, nil diff --git a/dhcp4/conn_test.go b/dhcp4/conn_test.go index 560fe3b..bd09d3c 100644 --- a/dhcp4/conn_test.go +++ b/dhcp4/conn_test.go @@ -24,7 +24,7 @@ import ( ) func testConn(t *testing.T, impl conn, addr string) { - c := &Conn{impl} + c := &Conn{impl, 0} s, err := net.Dial("udp4", addr) if err != nil { @@ -121,10 +121,11 @@ func TestPortableConn(t *testing.T) { if err != nil { t.Fatal(err) } + port := l.LocalAddr().(*net.UDPAddr).Port addr := l.LocalAddr().String() l.Close() - c, err := newPortableConn(addr) + c, err := newPortableConn(port) if err != nil { t.Fatalf("creating the conn: %s", err) } @@ -146,7 +147,7 @@ func TestLinuxConn(t *testing.T) { if err != nil { t.Fatal(err) } - c, err := newLinuxConn(l.LocalAddr().String()) + c, err := newLinuxConn(l.LocalAddr().(*net.UDPAddr).Port) if err != nil { t.Fatalf("creating the linuxconn: %s", err) }