mirror of
https://github.com/miekg/dns.git
synced 2025-12-16 09:11:34 +01:00
feat: add support for ReuseAddr (#1510)
* feat: add support for ReuseAddr * Update listen_reuseport.go * Update listen_reuseport.go * fixup! feat: add support for ReuseAddr --------- Co-authored-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
parent
3d593a6b1d
commit
257e89e9ba
@ -7,16 +7,18 @@ import "net"
|
|||||||
|
|
||||||
const supportsReusePort = false
|
const supportsReusePort = false
|
||||||
|
|
||||||
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
|
func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) {
|
||||||
if reuseport {
|
if reuseport || reuseaddr {
|
||||||
// TODO(tmthrgd): return an error?
|
// TODO(tmthrgd): return an error?
|
||||||
}
|
}
|
||||||
|
|
||||||
return net.Listen(network, addr)
|
return net.Listen(network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
|
const supportsReuseAddr = false
|
||||||
if reuseport {
|
|
||||||
|
func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) {
|
||||||
|
if reuseport || reuseaddr {
|
||||||
// TODO(tmthrgd): return an error?
|
// TODO(tmthrgd): return an error?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -25,19 +25,41 @@ func reuseportControl(network, address string, c syscall.RawConn) error {
|
|||||||
return opErr
|
return opErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
|
const supportsReuseAddr = true
|
||||||
|
|
||||||
|
func reuseaddrControl(network, address string, c syscall.RawConn) error {
|
||||||
|
var opErr error
|
||||||
|
err := c.Control(func(fd uintptr) {
|
||||||
|
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return opErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) {
|
||||||
var lc net.ListenConfig
|
var lc net.ListenConfig
|
||||||
if reuseport {
|
switch {
|
||||||
|
case reuseaddr && reuseport:
|
||||||
|
case reuseport:
|
||||||
lc.Control = reuseportControl
|
lc.Control = reuseportControl
|
||||||
|
case reuseaddr:
|
||||||
|
lc.Control = reuseaddrControl
|
||||||
}
|
}
|
||||||
|
|
||||||
return lc.Listen(context.Background(), network, addr)
|
return lc.Listen(context.Background(), network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
|
func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) {
|
||||||
var lc net.ListenConfig
|
var lc net.ListenConfig
|
||||||
if reuseport {
|
switch {
|
||||||
|
case reuseaddr && reuseport:
|
||||||
|
case reuseport:
|
||||||
lc.Control = reuseportControl
|
lc.Control = reuseportControl
|
||||||
|
case reuseaddr:
|
||||||
|
lc.Control = reuseaddrControl
|
||||||
}
|
}
|
||||||
|
|
||||||
return lc.ListenPacket(context.Background(), network, addr)
|
return lc.ListenPacket(context.Background(), network, addr)
|
||||||
|
|||||||
10
server.go
10
server.go
@ -226,6 +226,10 @@ type Server struct {
|
|||||||
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
|
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
|
||||||
// It is only supported on certain GOOSes and when using ListenAndServe.
|
// It is only supported on certain GOOSes and when using ListenAndServe.
|
||||||
ReusePort bool
|
ReusePort bool
|
||||||
|
// Whether to set the SO_REUSEADDR socket option, allowing multiple listeners to be bound to a single address.
|
||||||
|
// Crucially this allows binding when an existing server is listening on `0.0.0.0` or `::`.
|
||||||
|
// It is only supported on certain GOOSes and when using ListenAndServe.
|
||||||
|
ReuseAddr bool
|
||||||
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
|
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
|
||||||
// By default DefaultMsgAcceptFunc will be used.
|
// By default DefaultMsgAcceptFunc will be used.
|
||||||
MsgAcceptFunc MsgAcceptFunc
|
MsgAcceptFunc MsgAcceptFunc
|
||||||
@ -304,7 +308,7 @@ func (srv *Server) ListenAndServe() error {
|
|||||||
|
|
||||||
switch srv.Net {
|
switch srv.Net {
|
||||||
case "tcp", "tcp4", "tcp6":
|
case "tcp", "tcp4", "tcp6":
|
||||||
l, err := listenTCP(srv.Net, addr, srv.ReusePort)
|
l, err := listenTCP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -317,7 +321,7 @@ func (srv *Server) ListenAndServe() error {
|
|||||||
return errors.New("dns: neither Certificates nor GetCertificate set in Config")
|
return errors.New("dns: neither Certificates nor GetCertificate set in Config")
|
||||||
}
|
}
|
||||||
network := strings.TrimSuffix(srv.Net, "-tls")
|
network := strings.TrimSuffix(srv.Net, "-tls")
|
||||||
l, err := listenTCP(network, addr, srv.ReusePort)
|
l, err := listenTCP(network, addr, srv.ReusePort, srv.ReuseAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -327,7 +331,7 @@ func (srv *Server) ListenAndServe() error {
|
|||||||
unlock()
|
unlock()
|
||||||
return srv.serveTCP(l)
|
return srv.serveTCP(l)
|
||||||
case "udp", "udp4", "udp6":
|
case "udp", "udp4", "udp6":
|
||||||
l, err := listenUDP(srv.Net, addr, srv.ReusePort)
|
l, err := listenUDP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
171
server_test.go
171
server_test.go
@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -1041,6 +1042,176 @@ func TestServerReuseport(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerReuseaddr(t *testing.T) {
|
||||||
|
startServerFn := func(t *testing.T, network, addr string, expectSuccess bool) (*Server, chan error) {
|
||||||
|
t.Helper()
|
||||||
|
wait := make(chan struct{})
|
||||||
|
srv := &Server{
|
||||||
|
Net: network,
|
||||||
|
Addr: addr,
|
||||||
|
NotifyStartedFunc: func() { close(wait) },
|
||||||
|
ReuseAddr: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
fin := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
fin <- srv.ListenAndServe()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-wait:
|
||||||
|
case err := <-fin:
|
||||||
|
switch {
|
||||||
|
case expectSuccess:
|
||||||
|
t.Fatalf("%s: failed to start server: %v", t.Name(), err)
|
||||||
|
default:
|
||||||
|
fin <- err
|
||||||
|
return nil, fin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return srv, fin
|
||||||
|
}
|
||||||
|
|
||||||
|
externalIPFn := func(t *testing.T) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
|
continue // interface down
|
||||||
|
}
|
||||||
|
if iface.Flags&net.FlagLoopback != 0 {
|
||||||
|
continue // loopback interface
|
||||||
|
}
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
}
|
||||||
|
if ip == nil || ip.IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = ip.To4()
|
||||||
|
if ip == nil {
|
||||||
|
continue // not an ipv4 address
|
||||||
|
}
|
||||||
|
return ip.String(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", errors.New("are you connected to the network?")
|
||||||
|
}
|
||||||
|
|
||||||
|
freePortFn := func(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable resolve tcp addr: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := net.ListenTCP("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable listen tcp: %s", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
return l.Addr().(*net.TCPAddr).Port
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("should-fail-tcp", func(t *testing.T) {
|
||||||
|
// ReuseAddr should fail if you try to bind to exactly the same
|
||||||
|
// combination of source address and port.
|
||||||
|
// This should fail whether or not ReuseAddr is supported on a
|
||||||
|
// particular OS
|
||||||
|
ip, err := externalIPFn(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("no external IPs found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
port := freePortFn(t)
|
||||||
|
srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true)
|
||||||
|
srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), false)
|
||||||
|
switch {
|
||||||
|
case srv2 != nil && srv2.started:
|
||||||
|
t.Fatalf("second ListenAndServe should not have started")
|
||||||
|
default:
|
||||||
|
if err := <-fin2; err == nil {
|
||||||
|
t.Fatalf("second ListenAndServe should have returned a startup error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv1.Shutdown(); err != nil {
|
||||||
|
t.Fatalf("failed to shutdown first server: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-fin1; err != nil {
|
||||||
|
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("should-succeed-tcp", func(t *testing.T) {
|
||||||
|
if !supportsReuseAddr {
|
||||||
|
t.Skip("reuseaddr is not supported")
|
||||||
|
}
|
||||||
|
ip, err := externalIPFn(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("no external IPs found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
port := freePortFn(t)
|
||||||
|
|
||||||
|
// ReuseAddr should succeed if you try to bind to the same port but a different source address
|
||||||
|
srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("localhost:%d", port), true)
|
||||||
|
srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true)
|
||||||
|
|
||||||
|
if err := srv1.Shutdown(); err != nil {
|
||||||
|
t.Fatalf("failed to shutdown first server: %v", err)
|
||||||
|
}
|
||||||
|
if err := srv2.Shutdown(); err != nil {
|
||||||
|
t.Fatalf("failed to shutdown second server: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-fin1; err != nil {
|
||||||
|
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-fin2; err != nil {
|
||||||
|
t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("should-succeed-udp", func(t *testing.T) {
|
||||||
|
if !supportsReuseAddr {
|
||||||
|
t.Skip("reuseaddr is not supported")
|
||||||
|
}
|
||||||
|
ip, err := externalIPFn(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("no external IPs found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
port := freePortFn(t)
|
||||||
|
|
||||||
|
// ReuseAddr should succeed if you try to bind to the same port but a different source address
|
||||||
|
srv1, fin1 := startServerFn(t, "udp", fmt.Sprintf("localhost:%d", port), true)
|
||||||
|
srv2, fin2 := startServerFn(t, "udp", fmt.Sprintf("%s:%d", ip, port), true)
|
||||||
|
|
||||||
|
if err := srv1.Shutdown(); err != nil {
|
||||||
|
t.Fatalf("failed to shutdown first server: %v", err)
|
||||||
|
}
|
||||||
|
if err := srv2.Shutdown(); err != nil {
|
||||||
|
t.Fatalf("failed to shutdown second server: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-fin1; err != nil {
|
||||||
|
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-fin2; err != nil {
|
||||||
|
t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestServerRoundtripTsig(t *testing.T) {
|
func TestServerRoundtripTsig(t *testing.T) {
|
||||||
secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
|
secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user