mirror of
https://github.com/danderson/netboot.git
synced 2025-12-05 17:41:44 +01:00
dhcp: implement a portable Conn.
Unlike the linux Conn, this one cannot share a port with another process, but it should work across platforms.
This commit is contained in:
parent
c5449b945c
commit
45dd5dfe4d
122
dhcp/conn.go
Normal file
122
dhcp/conn.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
// Copyright 2016 Google Inc.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package dhcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defined as a var so tests can override it.
|
||||||
|
var dhcpClientPort = 68
|
||||||
|
|
||||||
|
var platformConn func(string) (Conn, error)
|
||||||
|
|
||||||
|
func NewConn(addr string) (Conn, error) {
|
||||||
|
if platformConn != nil {
|
||||||
|
c, err := platformConn(addr)
|
||||||
|
if err == nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Always try falling back to the portable implementation
|
||||||
|
return newPortableConn(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type portableConn struct {
|
||||||
|
conn *ipv4.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPortableConn(addr string) (Conn, error) {
|
||||||
|
c, err := net.ListenPacket("udp4", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l := ipv4.NewPacketConn(c)
|
||||||
|
if err = l.SetControlMessage(ipv4.FlagInterface, true); err != nil {
|
||||||
|
l.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &portableConn{l}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *portableConn) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *portableConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *portableConn) RecvDHCP() (*Packet, *net.Interface, error) {
|
||||||
|
var buf [1500]byte
|
||||||
|
for {
|
||||||
|
n, cm, _, err := c.conn.ReadFrom(buf[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
pkt, err := Unmarshal(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
intf, err := net.InterfaceByIndex(cm.IfIndex)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
// TODO: possibly more validation that the source lines up
|
||||||
|
// with what the packet.
|
||||||
|
return pkt, intf, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *portableConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
||||||
|
b, err := pkt.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch pkt.TxType() {
|
||||||
|
case TxBroadcast, TxHardwareAddr:
|
||||||
|
cm := ipv4.ControlMessage{
|
||||||
|
IfIndex: intf.Index,
|
||||||
|
}
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: net.IPv4bcast,
|
||||||
|
Port: dhcpClientPort,
|
||||||
|
}
|
||||||
|
_, err = c.conn.WriteTo(b, &cm, &addr)
|
||||||
|
return err
|
||||||
|
case TxRelayAddr:
|
||||||
|
// Send to the server port, not the client port.
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: pkt.RelayAddr,
|
||||||
|
Port: 67,
|
||||||
|
}
|
||||||
|
_, err = c.conn.WriteTo(b, nil, &addr)
|
||||||
|
return err
|
||||||
|
case TxClientAddr:
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: pkt.ClientAddr,
|
||||||
|
Port: dhcpClientPort,
|
||||||
|
}
|
||||||
|
_, err = c.conn.WriteTo(b, nil, &addr)
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
return errors.New("unknown TX type for packet")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -31,15 +31,16 @@ const (
|
|||||||
udpProtocolNumber = 17
|
udpProtocolNumber = 17
|
||||||
)
|
)
|
||||||
|
|
||||||
// defined as a var so tests can override it.
|
|
||||||
var dhcpClientPort = uint16(68)
|
|
||||||
|
|
||||||
type linuxConn struct {
|
type linuxConn struct {
|
||||||
port uint16
|
port uint16
|
||||||
conn *ipv4.RawConn
|
conn *ipv4.RawConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLinuxConn(addr string) (Conn, error) {
|
func init() {
|
||||||
|
platformConn = newLinuxConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLinuxConn(addr string) (Conn, error) {
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
addr = ":67"
|
addr = ":67"
|
||||||
}
|
}
|
||||||
@ -116,7 +117,6 @@ func (c *linuxConn) RecvDHCP() (*Packet, *net.Interface, error) {
|
|||||||
}
|
}
|
||||||
pkt, err := Unmarshal(p[8:])
|
pkt, err := Unmarshal(p[8:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
intf, err := net.InterfaceByIndex(cm.IfIndex)
|
intf, err := net.InterfaceByIndex(cm.IfIndex)
|
||||||
@ -139,7 +139,7 @@ func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
|||||||
// src port
|
// src port
|
||||||
binary.BigEndian.PutUint16(raw[:2], c.port)
|
binary.BigEndian.PutUint16(raw[:2], c.port)
|
||||||
// dst port
|
// dst port
|
||||||
binary.BigEndian.PutUint16(raw[2:4], dhcpClientPort)
|
binary.BigEndian.PutUint16(raw[2:4], uint16(dhcpClientPort))
|
||||||
// length
|
// length
|
||||||
binary.BigEndian.PutUint16(raw[4:6], uint16(8+len(b)))
|
binary.BigEndian.PutUint16(raw[4:6], uint16(8+len(b)))
|
||||||
copy(raw[8:], b)
|
copy(raw[8:], b)
|
||||||
|
|||||||
@ -23,26 +23,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnLinux(t *testing.T) {
|
func testConn(t *testing.T, c Conn, addr string) {
|
||||||
if runtime.GOOS != "linux" {
|
s, err := net.Dial("udp4", addr)
|
||||||
t.Skipf("not supported on %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
if os.Getuid() != 0 {
|
|
||||||
t.Skipf("must be root on %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use a listener to grab a free port, but we don't use it beyond
|
|
||||||
// that.
|
|
||||||
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
c, err := NewLinuxConn(l.LocalAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("creating the linuxconn: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := net.Dial("udp4", l.LocalAddr().String())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -85,7 +67,7 @@ func TestConnLinux(t *testing.T) {
|
|||||||
|
|
||||||
// Test writing
|
// Test writing
|
||||||
p.ClientAddr = net.IPv4(127, 0, 0, 1)
|
p.ClientAddr = net.IPv4(127, 0, 0, 1)
|
||||||
dhcpClientPort = uint16(s.LocalAddr().(*net.UDPAddr).Port)
|
dhcpClientPort = s.LocalAddr().(*net.UDPAddr).Port
|
||||||
bs, err = p.Marshal()
|
bs, err = p.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("marshaling packet: %s", err)
|
t.Fatalf("marshaling packet: %s", err)
|
||||||
@ -129,3 +111,43 @@ func TestConnLinux(t *testing.T) {
|
|||||||
t.Fatalf("DHCP packet not the same as when it was sent")
|
t.Fatalf("DHCP packet not the same as when it was sent")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPortableConn(t *testing.T) {
|
||||||
|
// Use a listener to grab a free port, but we don't use it beyond
|
||||||
|
// that.
|
||||||
|
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addr := l.LocalAddr().String()
|
||||||
|
l.Close()
|
||||||
|
|
||||||
|
c, err := newPortableConn(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating the conn: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testConn(t, c, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinuxConn(t *testing.T) {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
t.Skipf("not supported on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
if os.Getuid() != 0 {
|
||||||
|
t.Skipf("must be root on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a listener to grab a free port, but we don't use it beyond
|
||||||
|
// that.
|
||||||
|
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
c, err := newLinuxConn(l.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating the linuxconn: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testConn(t, c, l.LocalAddr().String())
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user