mirror of
https://github.com/danderson/netboot.git
synced 2025-10-16 18:11:21 +02:00
dhcp: implement a portable Conn.
Unlike the linux Conn, this one cannot share a port with another process, but it should work across platforms.
This commit is contained in:
parent
c5449b945c
commit
45dd5dfe4d
122
dhcp/conn.go
Normal file
122
dhcp/conn.go
Normal file
@ -0,0 +1,122 @@
|
||||
// Copyright 2016 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dhcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
// defined as a var so tests can override it.
|
||||
var dhcpClientPort = 68
|
||||
|
||||
var platformConn func(string) (Conn, error)
|
||||
|
||||
func NewConn(addr string) (Conn, error) {
|
||||
if platformConn != nil {
|
||||
c, err := platformConn(addr)
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
// Always try falling back to the portable implementation
|
||||
return newPortableConn(addr)
|
||||
}
|
||||
|
||||
type portableConn struct {
|
||||
conn *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func newPortableConn(addr string) (Conn, error) {
|
||||
c, err := net.ListenPacket("udp4", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := ipv4.NewPacketConn(c)
|
||||
if err = l.SetControlMessage(ipv4.FlagInterface, true); err != nil {
|
||||
l.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &portableConn{l}, nil
|
||||
}
|
||||
|
||||
func (c *portableConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *portableConn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *portableConn) RecvDHCP() (*Packet, *net.Interface, error) {
|
||||
var buf [1500]byte
|
||||
for {
|
||||
n, cm, _, err := c.conn.ReadFrom(buf[:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pkt, err := Unmarshal(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
intf, err := net.InterfaceByIndex(cm.IfIndex)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// TODO: possibly more validation that the source lines up
|
||||
// with what the packet.
|
||||
return pkt, intf, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *portableConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
||||
b, err := pkt.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch pkt.TxType() {
|
||||
case TxBroadcast, TxHardwareAddr:
|
||||
cm := ipv4.ControlMessage{
|
||||
IfIndex: intf.Index,
|
||||
}
|
||||
addr := net.UDPAddr{
|
||||
IP: net.IPv4bcast,
|
||||
Port: dhcpClientPort,
|
||||
}
|
||||
_, err = c.conn.WriteTo(b, &cm, &addr)
|
||||
return err
|
||||
case TxRelayAddr:
|
||||
// Send to the server port, not the client port.
|
||||
addr := net.UDPAddr{
|
||||
IP: pkt.RelayAddr,
|
||||
Port: 67,
|
||||
}
|
||||
_, err = c.conn.WriteTo(b, nil, &addr)
|
||||
return err
|
||||
case TxClientAddr:
|
||||
addr := net.UDPAddr{
|
||||
IP: pkt.ClientAddr,
|
||||
Port: dhcpClientPort,
|
||||
}
|
||||
_, err = c.conn.WriteTo(b, nil, &addr)
|
||||
return err
|
||||
default:
|
||||
return errors.New("unknown TX type for packet")
|
||||
}
|
||||
}
|
@ -31,15 +31,16 @@ const (
|
||||
udpProtocolNumber = 17
|
||||
)
|
||||
|
||||
// defined as a var so tests can override it.
|
||||
var dhcpClientPort = uint16(68)
|
||||
|
||||
type linuxConn struct {
|
||||
port uint16
|
||||
conn *ipv4.RawConn
|
||||
}
|
||||
|
||||
func NewLinuxConn(addr string) (Conn, error) {
|
||||
func init() {
|
||||
platformConn = newLinuxConn
|
||||
}
|
||||
|
||||
func newLinuxConn(addr string) (Conn, error) {
|
||||
if addr == "" {
|
||||
addr = ":67"
|
||||
}
|
||||
@ -116,7 +117,6 @@ func (c *linuxConn) RecvDHCP() (*Packet, *net.Interface, error) {
|
||||
}
|
||||
pkt, err := Unmarshal(p[8:])
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
continue
|
||||
}
|
||||
intf, err := net.InterfaceByIndex(cm.IfIndex)
|
||||
@ -139,7 +139,7 @@ func (c *linuxConn) SendDHCP(pkt *Packet, intf *net.Interface) error {
|
||||
// src port
|
||||
binary.BigEndian.PutUint16(raw[:2], c.port)
|
||||
// dst port
|
||||
binary.BigEndian.PutUint16(raw[2:4], dhcpClientPort)
|
||||
binary.BigEndian.PutUint16(raw[2:4], uint16(dhcpClientPort))
|
||||
// length
|
||||
binary.BigEndian.PutUint16(raw[4:6], uint16(8+len(b)))
|
||||
copy(raw[8:], b)
|
||||
|
@ -23,26 +23,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConnLinux(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("not supported on %s", runtime.GOOS)
|
||||
}
|
||||
if os.Getuid() != 0 {
|
||||
t.Skipf("must be root on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// Use a listener to grab a free port, but we don't use it beyond
|
||||
// that.
|
||||
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c, err := NewLinuxConn(l.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("creating the linuxconn: %s", err)
|
||||
}
|
||||
|
||||
s, err := net.Dial("udp4", l.LocalAddr().String())
|
||||
func testConn(t *testing.T, c Conn, addr string) {
|
||||
s, err := net.Dial("udp4", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -85,7 +67,7 @@ func TestConnLinux(t *testing.T) {
|
||||
|
||||
// Test writing
|
||||
p.ClientAddr = net.IPv4(127, 0, 0, 1)
|
||||
dhcpClientPort = uint16(s.LocalAddr().(*net.UDPAddr).Port)
|
||||
dhcpClientPort = s.LocalAddr().(*net.UDPAddr).Port
|
||||
bs, err = p.Marshal()
|
||||
if err != nil {
|
||||
t.Fatalf("marshaling packet: %s", err)
|
||||
@ -129,3 +111,43 @@ func TestConnLinux(t *testing.T) {
|
||||
t.Fatalf("DHCP packet not the same as when it was sent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortableConn(t *testing.T) {
|
||||
// Use a listener to grab a free port, but we don't use it beyond
|
||||
// that.
|
||||
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addr := l.LocalAddr().String()
|
||||
l.Close()
|
||||
|
||||
c, err := newPortableConn(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("creating the conn: %s", err)
|
||||
}
|
||||
|
||||
testConn(t, c, addr)
|
||||
}
|
||||
|
||||
func TestLinuxConn(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("not supported on %s", runtime.GOOS)
|
||||
}
|
||||
if os.Getuid() != 0 {
|
||||
t.Skipf("must be root on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// Use a listener to grab a free port, but we don't use it beyond
|
||||
// that.
|
||||
l, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c, err := newLinuxConn(l.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("creating the linuxconn: %s", err)
|
||||
}
|
||||
|
||||
testConn(t, c, l.LocalAddr().String())
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user