mirror of
https://github.com/miekg/dns.git
synced 2025-12-15 16:51:48 +01:00
feat: add support for ReuseAddr (#1510)
* feat: add support for ReuseAddr * Update listen_reuseport.go * Update listen_reuseport.go * fixup! feat: add support for ReuseAddr --------- Co-authored-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
parent
3d593a6b1d
commit
257e89e9ba
@ -7,16 +7,18 @@ import "net"
|
||||
|
||||
const supportsReusePort = false
|
||||
|
||||
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
|
||||
if reuseport {
|
||||
func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) {
|
||||
if reuseport || reuseaddr {
|
||||
// TODO(tmthrgd): return an error?
|
||||
}
|
||||
|
||||
return net.Listen(network, addr)
|
||||
}
|
||||
|
||||
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
|
||||
if reuseport {
|
||||
const supportsReuseAddr = false
|
||||
|
||||
func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) {
|
||||
if reuseport || reuseaddr {
|
||||
// TODO(tmthrgd): return an error?
|
||||
}
|
||||
|
||||
|
||||
@ -25,19 +25,41 @@ func reuseportControl(network, address string, c syscall.RawConn) error {
|
||||
return opErr
|
||||
}
|
||||
|
||||
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
|
||||
const supportsReuseAddr = true
|
||||
|
||||
func reuseaddrControl(network, address string, c syscall.RawConn) error {
|
||||
var opErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return opErr
|
||||
}
|
||||
|
||||
func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) {
|
||||
var lc net.ListenConfig
|
||||
if reuseport {
|
||||
switch {
|
||||
case reuseaddr && reuseport:
|
||||
case reuseport:
|
||||
lc.Control = reuseportControl
|
||||
case reuseaddr:
|
||||
lc.Control = reuseaddrControl
|
||||
}
|
||||
|
||||
return lc.Listen(context.Background(), network, addr)
|
||||
}
|
||||
|
||||
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
|
||||
func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) {
|
||||
var lc net.ListenConfig
|
||||
if reuseport {
|
||||
switch {
|
||||
case reuseaddr && reuseport:
|
||||
case reuseport:
|
||||
lc.Control = reuseportControl
|
||||
case reuseaddr:
|
||||
lc.Control = reuseaddrControl
|
||||
}
|
||||
|
||||
return lc.ListenPacket(context.Background(), network, addr)
|
||||
|
||||
10
server.go
10
server.go
@ -226,6 +226,10 @@ type Server struct {
|
||||
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
|
||||
// It is only supported on certain GOOSes and when using ListenAndServe.
|
||||
ReusePort bool
|
||||
// Whether to set the SO_REUSEADDR socket option, allowing multiple listeners to be bound to a single address.
|
||||
// Crucially this allows binding when an existing server is listening on `0.0.0.0` or `::`.
|
||||
// It is only supported on certain GOOSes and when using ListenAndServe.
|
||||
ReuseAddr bool
|
||||
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
|
||||
// By default DefaultMsgAcceptFunc will be used.
|
||||
MsgAcceptFunc MsgAcceptFunc
|
||||
@ -304,7 +308,7 @@ func (srv *Server) ListenAndServe() error {
|
||||
|
||||
switch srv.Net {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
l, err := listenTCP(srv.Net, addr, srv.ReusePort)
|
||||
l, err := listenTCP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -317,7 +321,7 @@ func (srv *Server) ListenAndServe() error {
|
||||
return errors.New("dns: neither Certificates nor GetCertificate set in Config")
|
||||
}
|
||||
network := strings.TrimSuffix(srv.Net, "-tls")
|
||||
l, err := listenTCP(network, addr, srv.ReusePort)
|
||||
l, err := listenTCP(network, addr, srv.ReusePort, srv.ReuseAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -327,7 +331,7 @@ func (srv *Server) ListenAndServe() error {
|
||||
unlock()
|
||||
return srv.serveTCP(l)
|
||||
case "udp", "udp4", "udp6":
|
||||
l, err := listenUDP(srv.Net, addr, srv.ReusePort)
|
||||
l, err := listenUDP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
171
server_test.go
171
server_test.go
@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -1041,6 +1042,176 @@ func TestServerReuseport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerReuseaddr(t *testing.T) {
|
||||
startServerFn := func(t *testing.T, network, addr string, expectSuccess bool) (*Server, chan error) {
|
||||
t.Helper()
|
||||
wait := make(chan struct{})
|
||||
srv := &Server{
|
||||
Net: network,
|
||||
Addr: addr,
|
||||
NotifyStartedFunc: func() { close(wait) },
|
||||
ReuseAddr: true,
|
||||
}
|
||||
|
||||
fin := make(chan error, 1)
|
||||
go func() {
|
||||
fin <- srv.ListenAndServe()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-wait:
|
||||
case err := <-fin:
|
||||
switch {
|
||||
case expectSuccess:
|
||||
t.Fatalf("%s: failed to start server: %v", t.Name(), err)
|
||||
default:
|
||||
fin <- err
|
||||
return nil, fin
|
||||
}
|
||||
}
|
||||
return srv, fin
|
||||
}
|
||||
|
||||
externalIPFn := func(t *testing.T) (string, error) {
|
||||
t.Helper()
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue // interface down
|
||||
}
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
continue // loopback interface
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
}
|
||||
if ip == nil || ip.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
ip = ip.To4()
|
||||
if ip == nil {
|
||||
continue // not an ipv4 address
|
||||
}
|
||||
return ip.String(), nil
|
||||
}
|
||||
}
|
||||
return "", errors.New("are you connected to the network?")
|
||||
}
|
||||
|
||||
freePortFn := func(t *testing.T) int {
|
||||
t.Helper()
|
||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("unable resolve tcp addr: %s", err)
|
||||
}
|
||||
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("unable listen tcp: %s", err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
t.Run("should-fail-tcp", func(t *testing.T) {
|
||||
// ReuseAddr should fail if you try to bind to exactly the same
|
||||
// combination of source address and port.
|
||||
// This should fail whether or not ReuseAddr is supported on a
|
||||
// particular OS
|
||||
ip, err := externalIPFn(t)
|
||||
if err != nil {
|
||||
t.Skip("no external IPs found")
|
||||
return
|
||||
}
|
||||
port := freePortFn(t)
|
||||
srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true)
|
||||
srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), false)
|
||||
switch {
|
||||
case srv2 != nil && srv2.started:
|
||||
t.Fatalf("second ListenAndServe should not have started")
|
||||
default:
|
||||
if err := <-fin2; err == nil {
|
||||
t.Fatalf("second ListenAndServe should have returned a startup error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := srv1.Shutdown(); err != nil {
|
||||
t.Fatalf("failed to shutdown first server: %v", err)
|
||||
}
|
||||
if err := <-fin1; err != nil {
|
||||
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("should-succeed-tcp", func(t *testing.T) {
|
||||
if !supportsReuseAddr {
|
||||
t.Skip("reuseaddr is not supported")
|
||||
}
|
||||
ip, err := externalIPFn(t)
|
||||
if err != nil {
|
||||
t.Skip("no external IPs found")
|
||||
return
|
||||
}
|
||||
port := freePortFn(t)
|
||||
|
||||
// ReuseAddr should succeed if you try to bind to the same port but a different source address
|
||||
srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("localhost:%d", port), true)
|
||||
srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true)
|
||||
|
||||
if err := srv1.Shutdown(); err != nil {
|
||||
t.Fatalf("failed to shutdown first server: %v", err)
|
||||
}
|
||||
if err := srv2.Shutdown(); err != nil {
|
||||
t.Fatalf("failed to shutdown second server: %v", err)
|
||||
}
|
||||
if err := <-fin1; err != nil {
|
||||
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
|
||||
}
|
||||
if err := <-fin2; err != nil {
|
||||
t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("should-succeed-udp", func(t *testing.T) {
|
||||
if !supportsReuseAddr {
|
||||
t.Skip("reuseaddr is not supported")
|
||||
}
|
||||
ip, err := externalIPFn(t)
|
||||
if err != nil {
|
||||
t.Skip("no external IPs found")
|
||||
return
|
||||
}
|
||||
port := freePortFn(t)
|
||||
|
||||
// ReuseAddr should succeed if you try to bind to the same port but a different source address
|
||||
srv1, fin1 := startServerFn(t, "udp", fmt.Sprintf("localhost:%d", port), true)
|
||||
srv2, fin2 := startServerFn(t, "udp", fmt.Sprintf("%s:%d", ip, port), true)
|
||||
|
||||
if err := srv1.Shutdown(); err != nil {
|
||||
t.Fatalf("failed to shutdown first server: %v", err)
|
||||
}
|
||||
if err := srv2.Shutdown(); err != nil {
|
||||
t.Fatalf("failed to shutdown second server: %v", err)
|
||||
}
|
||||
if err := <-fin1; err != nil {
|
||||
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
|
||||
}
|
||||
if err := <-fin2; err != nil {
|
||||
t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerRoundtripTsig(t *testing.T) {
|
||||
secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user