mirror of
https://github.com/danderson/netboot.git
synced 2025-10-17 10:31:28 +02:00
dhcp4: refactor the low-level packet sending.
This lets the implementations share the transmission strategy and decoding/sanity logic.
This commit is contained in:
parent
6a00ec441b
commit
53e4bfd991
181
dhcp4/conn.go
181
dhcp4/conn.go
@ -26,7 +26,7 @@ import (
|
||||
// defined as a var so tests can override it.
|
||||
var (
|
||||
dhcpClientPort = 68
|
||||
platformConn func(string) (Conn, error)
|
||||
platformConn func(string) (conn, error)
|
||||
)
|
||||
|
||||
// txType describes how a Packet should be sent on the wire.
|
||||
@ -51,44 +51,105 @@ const (
|
||||
txHardwareAddr
|
||||
)
|
||||
|
||||
type conn interface {
|
||||
io.Closer
|
||||
Recv([]byte) (b []byte, addr *net.UDPAddr, ifidx int, err error)
|
||||
Send(b []byte, addr *net.UDPAddr, ifidx int) error
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// Conn is a DHCP-oriented packet socket.
|
||||
//
|
||||
// Multiple goroutines may invoke methods on a Conn simultaneously.
|
||||
type Conn interface {
|
||||
io.Closer
|
||||
// RecvDHCP reads a Packet from the connection. It returns the
|
||||
// packet and the interface it was received on, which may be nil
|
||||
// if interface information cannot be obtained.
|
||||
RecvDHCP() (pkt *Packet, intf *net.Interface, err error)
|
||||
// SendDHCP sends pkt. The precise transmission mechanism depends
|
||||
// on pkt.txType(). intf should be the net.Interface returned by
|
||||
// RecvDHCP if responding to a DHCP client, or the interface for
|
||||
// which configuration is desired if acting as a client.
|
||||
SendDHCP(pkt *Packet, intf *net.Interface) error
|
||||
// SetReadDeadline sets the deadline for future Read calls.
|
||||
// If the deadline is reached, Read will fail with a timeout
|
||||
// (see type Error) instead of blocking.
|
||||
// A zero value for t means Read will not time out.
|
||||
SetReadDeadline(t time.Time) error
|
||||
type Conn struct {
|
||||
conn conn
|
||||
}
|
||||
|
||||
// NewConn creates a Conn bound to the given UDP ip:port.
|
||||
func NewConn(addr string) (Conn, error) {
|
||||
func NewConn(addr string) (*Conn, error) {
|
||||
if platformConn != nil {
|
||||
c, err := platformConn(addr)
|
||||
if err == nil {
|
||||
return c, nil
|
||||
return &Conn{c}, nil
|
||||
}
|
||||
}
|
||||
// Always try falling back to the portable implementation
|
||||
return newPortableConn(addr)
|
||||
c, err := newPortableConn(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{c}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) RecvDHCP() (*Packet, *net.Interface, error) {
|
||||
var buf [1500]byte
|
||||
for {
|
||||
b, _, ifidx, err := c.conn.Recv(buf[:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pkt, err := Unmarshal(b)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
intf, err := net.InterfaceByIndex(ifidx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// TODO: possibly more validation that the source lines up
|
||||
// with what the packet says.
|
||||
return pkt, intf, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
||||
b, err := pkt.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch pkt.txType() {
|
||||
case txBroadcast, txHardwareAddr:
|
||||
addr := net.UDPAddr{
|
||||
IP: net.IPv4bcast,
|
||||
Port: dhcpClientPort,
|
||||
}
|
||||
return c.conn.Send(b, &addr, intf.Index)
|
||||
case txRelayAddr:
|
||||
addr := net.UDPAddr{
|
||||
IP: pkt.RelayAddr,
|
||||
Port: 67,
|
||||
}
|
||||
return c.conn.Send(b, &addr, 0)
|
||||
case txClientAddr:
|
||||
addr := net.UDPAddr{
|
||||
IP: pkt.ClientAddr,
|
||||
Port: dhcpClientPort,
|
||||
}
|
||||
return c.conn.Send(b, &addr, 0)
|
||||
default:
|
||||
return errors.New("unknown TX type for packet")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
type portableConn struct {
|
||||
conn *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func newPortableConn(addr string) (Conn, error) {
|
||||
func newPortableConn(addr string) (conn, error) {
|
||||
c, err := net.ListenPacket("udp4", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -105,64 +166,30 @@ func (c *portableConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *portableConn) Recv(b []byte) (rb []byte, addr *net.UDPAddr, ifidx int, err error) {
|
||||
n, cm, a, err := c.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
return b[:n], a.(*net.UDPAddr), cm.IfIndex, nil
|
||||
}
|
||||
|
||||
func (c *portableConn) Send(b []byte, addr *net.UDPAddr, ifidx int) error {
|
||||
if ifidx > 0 {
|
||||
_, err := c.conn.WriteTo(b, nil, addr)
|
||||
return err
|
||||
}
|
||||
cm := ipv4.ControlMessage{
|
||||
IfIndex: ifidx,
|
||||
}
|
||||
_, err := c.conn.WriteTo(b, &cm, addr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *portableConn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *portableConn) RecvDHCP() (*Packet, *net.Interface, error) {
|
||||
var buf [1500]byte
|
||||
for {
|
||||
n, cm, _, err := c.conn.ReadFrom(buf[:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pkt, err := Unmarshal(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
intf, err := net.InterfaceByIndex(cm.IfIndex)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// TODO: possibly more validation that the source lines up
|
||||
// with what the packet.
|
||||
return pkt, intf, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *portableConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
||||
b, err := pkt.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch pkt.txType() {
|
||||
case txBroadcast, txHardwareAddr:
|
||||
cm := ipv4.ControlMessage{
|
||||
IfIndex: intf.Index,
|
||||
}
|
||||
addr := net.UDPAddr{
|
||||
IP: net.IPv4bcast,
|
||||
Port: dhcpClientPort,
|
||||
}
|
||||
_, err = c.conn.WriteTo(b, &cm, &addr)
|
||||
return err
|
||||
case txRelayAddr:
|
||||
// Send to the server port, not the client port.
|
||||
addr := net.UDPAddr{
|
||||
IP: pkt.RelayAddr,
|
||||
Port: 67,
|
||||
}
|
||||
_, err = c.conn.WriteTo(b, nil, &addr)
|
||||
return err
|
||||
case txClientAddr:
|
||||
addr := net.UDPAddr{
|
||||
IP: pkt.ClientAddr,
|
||||
Port: dhcpClientPort,
|
||||
}
|
||||
_, err = c.conn.WriteTo(b, nil, &addr)
|
||||
return err
|
||||
default:
|
||||
return errors.New("unknown TX type for packet")
|
||||
}
|
||||
func (c *portableConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ func init() {
|
||||
platformConn = newLinuxConn
|
||||
}
|
||||
|
||||
func newLinuxConn(addr string) (Conn, error) {
|
||||
func newLinuxConn(addr string) (conn, error) {
|
||||
if addr == "" {
|
||||
addr = ":67"
|
||||
}
|
||||
@ -101,45 +101,24 @@ func (c *linuxConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *linuxConn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *linuxConn) RecvDHCP() (*Packet, *net.Interface, error) {
|
||||
var buf [1500]byte
|
||||
for {
|
||||
_, p, cm, err := c.conn.ReadFrom(buf[:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(p) < 8 {
|
||||
continue
|
||||
}
|
||||
pkt, err := Unmarshal(p[8:])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
intf, err := net.InterfaceByIndex(cm.IfIndex)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// TODO: possibly more validation that the IPv4 header lines
|
||||
// up with what the packet.
|
||||
return pkt, intf, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
||||
b, err := pkt.Marshal()
|
||||
func (c *linuxConn) Recv(b []byte) (rb []byte, addr *net.UDPAddr, ifidx int, err error) {
|
||||
hdr, p, cm, err := c.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
if len(p) < 8 {
|
||||
return nil, nil, 0, errors.New("not a UDP packet, too short")
|
||||
}
|
||||
sport := int(binary.BigEndian.Uint16(p[:2]))
|
||||
return p[8:], &net.UDPAddr{IP: hdr.Src, Port: sport}, cm.IfIndex, nil
|
||||
}
|
||||
|
||||
func (c *linuxConn) Send(b []byte, addr *net.UDPAddr, ifidx int) error {
|
||||
raw := make([]byte, 8+len(b))
|
||||
// src port
|
||||
binary.BigEndian.PutUint16(raw[:2], c.port)
|
||||
// dst port
|
||||
binary.BigEndian.PutUint16(raw[2:4], uint16(dhcpClientPort))
|
||||
binary.BigEndian.PutUint16(raw[2:4], uint16(addr.Port))
|
||||
// length
|
||||
binary.BigEndian.PutUint16(raw[4:6], uint16(8+len(b)))
|
||||
copy(raw[8:], b)
|
||||
@ -151,24 +130,22 @@ func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
||||
TotalLen: ipv4.HeaderLen + 8 + len(b),
|
||||
TTL: 64,
|
||||
Protocol: 17,
|
||||
Dst: addr.IP,
|
||||
}
|
||||
|
||||
switch pkt.txType() {
|
||||
case txBroadcast, txHardwareAddr:
|
||||
hdr.Dst = net.IPv4bcast
|
||||
if ifidx > 0 {
|
||||
cm := ipv4.ControlMessage{
|
||||
IfIndex: intf.Index,
|
||||
IfIndex: ifidx,
|
||||
}
|
||||
return c.conn.WriteTo(&hdr, raw, &cm)
|
||||
case txRelayAddr:
|
||||
// Send to the server port, not the client port.
|
||||
binary.BigEndian.PutUint16(raw[2:4], 67)
|
||||
hdr.Dst = pkt.RelayAddr
|
||||
return c.conn.WriteTo(&hdr, raw, nil)
|
||||
case txClientAddr:
|
||||
hdr.Dst = pkt.ClientAddr
|
||||
return c.conn.WriteTo(&hdr, raw, nil)
|
||||
default:
|
||||
return errors.New("unknown TX type for packet")
|
||||
}
|
||||
return c.conn.WriteTo(&hdr, raw, nil)
|
||||
}
|
||||
|
||||
func (c *linuxConn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *linuxConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
@ -23,7 +23,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func testConn(t *testing.T, c Conn, addr string) {
|
||||
func testConn(t *testing.T, impl conn, addr string) {
|
||||
c := &Conn{impl}
|
||||
|
||||
s, err := net.Dial("udp4", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
Loading…
x
Reference in New Issue
Block a user