netboot/dhcp4/conn_test.go
David Anderson e9e2bb6f2d dhcp4: avoid reusing bs between the writing and reading tests.
Although it looks like the sequencing of writing bs to the socket
and reusing the slice for the read test is guaranteed by there being
a receive from the socket in between, the race detector disagrees, and
found a race.
2017-06-01 20:07:28 -07:00

157 lines
3.5 KiB
Go

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dhcp4
import (
"net"
"os"
"reflect"
"runtime"
"testing"
"time"
)
// testConn exercises a Conn wrapping impl in both directions, using a plain
// UDP socket dialed to addr as the remote peer.
//
// Receive half: a marshaled MsgDiscover packet is written from the peer
// socket and RecvDHCP must return a packet that is DeepEqual to it.
//
// Send half: the packet (with ClientAddr set to 127.0.0.1) is passed to
// SendDHCP and the peer socket must read back a DeepEqual packet. For this
// half the package-level dhcpClientPort is set to the peer socket's local
// port (presumably so SendDHCP targets the peer; confirm against Conn) and
// is reset to 68 when testConn returns.
func testConn(t *testing.T, impl conn, addr string) {
	c := &Conn{impl, 0}

	s, err := net.Dial("udp4", addr)
	if err != nil {
		t.Fatal(err)
	}
	// Release the peer socket on every exit path, including t.Fatal.
	defer s.Close()

	mac, err := net.ParseMAC("ce:e7:7b:ef:45:f7")
	if err != nil {
		t.Fatal(err)
	}
	p := &Packet{
		Type:          MsgDiscover,
		TransactionID: "1234",
		Broadcast:     true,
		HardwareAddr:  mac,
	}
	bs, err := p.Marshal()
	if err != nil {
		t.Fatalf("marshaling packet: %s", err)
	}
	// Unmarshal the packet again, to smooth out representation
	// differences (e.g. nil IP vs. IP set to 0.0.0.0).
	p, err = Unmarshal(bs)
	if err != nil {
		t.Fatal(err)
	}

	// Test reading. The write runs in its own goroutine; its error comes
	// back over a buffered channel so it is neither dropped nor reported
	// from a goroutine that might outlive the test.
	writeErr := make(chan error, 1)
	go func() {
		_, werr := s.Write(bs)
		writeErr <- werr
	}()
	if err = c.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatal(err)
	}
	rpkt, intf, err := c.RecvDHCP()
	// Report a failed write first: it is the root cause of a read timeout.
	if werr := <-writeErr; werr != nil {
		t.Fatalf("writing DHCP packet to conn: %s", werr)
	}
	if err != nil {
		t.Fatalf("reading DHCP packet: %s", err)
	}
	if !reflect.DeepEqual(p, rpkt) {
		t.Fatalf("DHCP packet not the same as when it was sent")
	}

	// Test writing
	p.ClientAddr = net.IPv4(127, 0, 0, 1)
	dhcpClientPort = s.LocalAddr().(*net.UDPAddr).Port
	// Register the reset immediately, so that a failure in any step below
	// cannot leave the package-level port modified for later tests.
	defer func() { dhcpClientPort = 68 }()

	bs2, err := p.Marshal()
	if err != nil {
		t.Fatalf("marshaling packet: %s", err)
	}
	// Unmarshal the packet again, to smooth out representation
	// differences (e.g. nil IP vs. IP set to 0.0.0.0).
	p, err = Unmarshal(bs2)
	if err != nil {
		t.Fatal(err)
	}

	// peerResult carries the outcome of the peer-side read back to the test
	// goroutine, which is the only one that touches t.
	type peerResult struct {
		pkt  *Packet
		step string // what the peer was doing when err occurred
		err  error
	}
	// Buffered so the reader goroutine can always finish, even if the test
	// has already bailed out and nobody receives.
	ch := make(chan peerResult, 1)
	go func() {
		if err := s.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
			ch <- peerResult{step: "setting read deadline on peer socket", err: err}
			return
		}
		var buf [1500]byte
		n, err := s.Read(buf[:])
		if err != nil {
			ch <- peerResult{step: "reading DHCP packet sent by conn", err: err}
			return
		}
		pkt, err := Unmarshal(buf[:n])
		if err != nil {
			ch <- peerResult{step: "decoding DHCP packet", err: err}
			return
		}
		ch <- peerResult{pkt: pkt}
	}()

	if err = c.SendDHCP(p, intf); err != nil {
		t.Fatalf("sending DHCP packet: %s", err)
	}
	res := <-ch
	if res.err != nil {
		t.Fatalf("%s: %s", res.step, res.err)
	}
	if !reflect.DeepEqual(p, res.pkt) {
		t.Fatalf("DHCP packet not the same as when it was sent")
	}
}
// TestPortableConn runs the shared testConn checks against the conn returned
// by newPortableConn, bound to a free loopback UDP port.
func TestPortableConn(t *testing.T) {
	// A throwaway listener is opened only to learn a free port number; it
	// is closed again before the conn under test is created on that port.
	probe, err := net.ListenPacket("udp4", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	local := probe.LocalAddr()
	freePort := local.(*net.UDPAddr).Port
	target := local.String()
	probe.Close()

	pc, err := newPortableConn(freePort)
	if err != nil {
		t.Fatalf("creating the conn: %s", err)
	}

	testConn(t, pc, target)
}
// TestLinuxConn runs the shared testConn checks against the conn returned by
// newLinuxConn. It is skipped unless running as root on Linux.
func TestLinuxConn(t *testing.T) {
	if runtime.GOOS != "linux" {
		t.Skipf("not supported on %s", runtime.GOOS)
	}
	if os.Getuid() != 0 {
		t.Skipf("must be root on %s", runtime.GOOS)
	}

	// Use a listener to grab a free port. Unlike TestPortableConn, the
	// listener stays open while the test runs, so close it only once the
	// test is done instead of leaking the socket.
	// NOTE(review): presumably it is kept open so the port stays reserved
	// while the Linux conn is in use; confirm against newLinuxConn.
	l, err := net.ListenPacket("udp4", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer l.Close()

	local := l.LocalAddr()
	c, err := newLinuxConn(local.(*net.UDPAddr).Port)
	if err != nil {
		t.Fatalf("creating the linuxconn: %s", err)
	}
	testConn(t, c, local.String())
}