From 05b7b04527690281b14a5c1934b7d652c2ce9f5a Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Thu, 19 Feb 2026 08:20:07 -0600 Subject: [PATCH] net/rioconn: implement addrPortFromUDPAddr In this commit, we add a function that takes a network name (e.g., udp, udp4, udp6) and an optional net.UDPAddr, and returns a netip.AddrPort along with a flag indicating whether the socket should be configured as dual-stack (IPv6 + IPv4). Updates tailscale/corp#8610 Signed-off-by: Nick Khyl --- net/rioconn/addr.go | 49 ++++++++++++++ net/rioconn/addr_test.go | 140 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+) diff --git a/net/rioconn/addr.go b/net/rioconn/addr.go index 544bc568c..7c6d48a29 100644 --- a/net/rioconn/addr.go +++ b/net/rioconn/addr.go @@ -117,6 +117,15 @@ func (rsa rawSockaddr) ToAddrPort() (netip.AddrPort, error) { } } +// ToUDPAddr returns a [net.UDPAddr] representation of the receiver. +func (rsa rawSockaddr) ToUDPAddr() (*net.UDPAddr, error) { + ap, err := rsa.ToAddrPort() + if err != nil { + return nil, err + } + return net.UDPAddrFromAddrPort(ap), nil +} + func addrPortFromSocket(socket windows.Handle) (netip.AddrPort, error) { sa, err := windows.Getsockname(socket) if err != nil { @@ -150,6 +159,46 @@ func addrPortFromSockaddr(sa windows.Sockaddr) (netip.AddrPort, error) { } } +func addrPortFromUDPAddr(network string, addr *net.UDPAddr) (_ netip.AddrPort, dualStack bool, err error) { + if addr == nil { + // A nil address is equivalent to an unspecified address. + addr = &net.UDPAddr{} + } + + var ap netip.AddrPort + switch { + case addr.IP != nil: + // [net.IP] values are typically (always?) 16 bytes long, even for IPv4. + // As a result, [netip.AddrFromSlice] (and [net.UDPAddr.AddrPort], etc.) + // return IPv6-mapped IPv4 addresses. We need to unmap them back to IPv4 here + // if the network is not "udp6". + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok { + return netip.AddrPort{}, false, fmt.Errorf("invalid IP address: %v", addr.IP) + } + switch network { + case "udp", "udp4": + ip = ip.Unmap() + case "udp6": + // Keep as-is, even if it's an IPv4-mapped IPv6 address. + default: + return netip.AddrPort{}, false, net.UnknownNetworkError(network) + } + ip = ip.WithZone(addr.Zone) + ap = netip.AddrPortFrom(ip, uint16(addr.Port)) + case network == "udp": + ap = netip.AddrPortFrom(netip.IPv6Unspecified(), uint16(addr.Port)) + dualStack = true // dual-stack, unspecified address + case network == "udp4": + ap = netip.AddrPortFrom(netip.IPv4Unspecified(), uint16(addr.Port)) + case network == "udp6": + ap = netip.AddrPortFrom(netip.IPv6Unspecified(), uint16(addr.Port)) + default: + return netip.AddrPort{}, false, net.UnknownNetworkError(network) + } + return ap, dualStack, nil +} + func sockaddrFromAddrPort(addr netip.AddrPort) (sa windows.Sockaddr, family int32, err error) { rsa, err := rawSockaddrFromAddrPort(addr) if err != nil { diff --git a/net/rioconn/addr_test.go b/net/rioconn/addr_test.go index 28640ab7b..3713155cb 100644 --- a/net/rioconn/addr_test.go +++ b/net/rioconn/addr_test.go @@ -294,6 +294,13 @@ func TestRawSockaddrToAddrPort(t *testing.T) { if ap != tt.want { t.Errorf("rawSockaddr.ToAddrPort(): got %v; want %v", ap, tt.want) } + gotUDPAddr, err := sa.ToUDPAddr() + if err != nil { + t.Fatalf("rawSockaddr.ToUDPAddr() error: %v", err) + } + if gotUDPAddr.AddrPort() != tt.want { + t.Errorf("rawSockaddr.ToUDPAddr(): got %v; want %v", gotUDPAddr, net.UDPAddrFromAddrPort(tt.want)) + } }) } } @@ -546,6 +553,139 @@ func TestAddrPortFromSockaddr(t *testing.T) { } } +func TestAddrPortFromUDPAddr(t *testing.T) { + t.Parallel() + iface := firstInterface(t) + tests := []struct { + name string + network string + udpAddr *net.UDPAddr + wantAddr netip.AddrPort + wantDualStack bool + wantErr bool + }{ + { + name: "nil/udp", + network: "udp", + udpAddr: nil, + wantDualStack: true, + wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 0), + }, + { + name: "nil/udp4", + network: "udp4", + udpAddr: nil, + wantDualStack: false, + wantAddr: netip.AddrPortFrom(netip.IPv4Unspecified(), 0), + }, + { + name: "nil/udp6", + network: "udp6", + udpAddr: nil, + wantDualStack: false, + wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 0), + }, + { + name: "unspecified/udp", + network: "udp", + udpAddr: &net.UDPAddr{Port: 1234}, + wantDualStack: true, + wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 1234), + }, + { + name: "unspecified/udp4", + network: "udp4", + udpAddr: &net.UDPAddr{Port: 1234}, + wantDualStack: false, + wantAddr: netip.AddrPortFrom(netip.IPv4Unspecified(), 1234), + }, + { + name: "unspecified/udp6", + network: "udp6", + udpAddr: &net.UDPAddr{Port: 1234}, + wantDualStack: false, + wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 1234), + }, + { + name: "IPv4/udp", + network: "udp", + udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 1234}, + wantDualStack: false, + wantAddr: netip.MustParseAddrPort("192.0.2.1:1234"), + }, + { + name: "IPv6/udp", + network: "udp", + udpAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + wantDualStack: false, + wantAddr: netip.MustParseAddrPort("[2001:db8::1]:1234"), + }, + { + name: "IPv6-with-zone/udp", + network: "udp", + udpAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234, Zone: iface.Name}, + wantDualStack: false, + wantAddr: netip.AddrPortFrom(netip.MustParseAddr("2001:db8::1").WithZone(iface.Name), 1234), + }, + { + name: "IPv4-mapped-IPv6/udp", + network: "udp", + udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1).To16(), Port: 1234}, + wantDualStack: false, + wantAddr: netip.MustParseAddrPort("192.0.2.1:1234"), + }, + { + name: "IPv4-mapped-IPv6/udp6", + network: "udp6", + udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1).To16(), Port: 1234}, + wantDualStack: false, + wantAddr: netip.MustParseAddrPort("[::ffff:192.0.2.1]:1234"), + }, + { + name: "nil/invalid-network", + network: "tcp", + udpAddr: nil, + wantErr: true, + }, + { + name: "IPv4/invalid-network", + network: "tcp", + udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 1234}, + wantErr: true, + }, + { + name: "IPv6/invalid-network", + network: "tcp", + udpAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + wantErr: true, + }, + { + name: "IP/invalid-address", + network: "udp", + udpAddr: &net.UDPAddr{IP: []byte{1, 2, 3, 4, 5}, Port: 1234}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + gotAddr, dualStack, err := addrPortFromUDPAddr(tt.network, tt.udpAddr) + if (err != nil) != tt.wantErr { + t.Fatalf("addrPortFromUDPAddr: error: got %v; want %v", err, tt.wantErr) + } + if err != nil { + return + } + if gotAddr != tt.wantAddr { + t.Errorf("addrPortFromUDPAddr: got addr %v; want %v", gotAddr, tt.wantAddr) + } + if dualStack != tt.wantDualStack { + t.Errorf("addrPortFromUDPAddr: dualStack: got %v; want %v", dualStack, tt.wantDualStack) + } + }) + } +} + func TestNetAddrFromAddrPort(t *testing.T) { t.Parallel() tests := []struct {