net/rioconn: implement addrPortFromUDPAddr

In this commit, we add a function that takes a network name (e.g., udp, udp4, udp6)
and an optional net.UDPAddr, and returns a netip.AddrPort along with a flag indicating
whether the socket should be configured as dual-stack (IPv6 + IPv4).

Updates tailscale/corp#8610

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2026-02-19 08:20:07 -06:00
parent 3b5180e12e
commit 05b7b04527
No known key found for this signature in database
2 changed files with 189 additions and 0 deletions

View File

@ -117,6 +117,15 @@ func (rsa rawSockaddr) ToAddrPort() (netip.AddrPort, error) {
}
}
// ToUDPAddr returns a [net.UDPAddr] representation of the receiver.
func (rsa rawSockaddr) ToUDPAddr() (*net.UDPAddr, error) {
ap, err := rsa.ToAddrPort()
if err != nil {
return nil, err
}
return net.UDPAddrFromAddrPort(ap), nil
}
func addrPortFromSocket(socket windows.Handle) (netip.AddrPort, error) {
sa, err := windows.Getsockname(socket)
if err != nil {
@ -150,6 +159,46 @@ func addrPortFromSockaddr(sa windows.Sockaddr) (netip.AddrPort, error) {
}
}
func addrPortFromUDPAddr(network string, addr *net.UDPAddr) (_ netip.AddrPort, dualStack bool, err error) {
if addr == nil {
// A nil address is equivalent to an unspecified address.
addr = &net.UDPAddr{}
}
var ap netip.AddrPort
switch {
case addr.IP != nil:
// [net.IP] values are typically (always?) 16 bytes long, even for IPv4.
// As a result, [netip.AddrFromSlice] (and [net.UDPAddr.AddrPort], etc.)
// return IPv6-mapped IPv4 addresses. We need to unmap them back to IPv4 here
// if the network is not "udp6".
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return netip.AddrPort{}, false, fmt.Errorf("invalid IP address: %v", addr.IP)
}
switch network {
case "udp", "udp4":
ip = ip.Unmap()
case "udp6":
// Keep as-is, even if it's an IPv4-mapped IPv6 address.
default:
return netip.AddrPort{}, false, net.UnknownNetworkError(network)
}
ip = ip.WithZone(addr.Zone)
ap = netip.AddrPortFrom(ip, uint16(addr.Port))
case network == "udp":
ap = netip.AddrPortFrom(netip.IPv6Unspecified(), uint16(addr.Port))
dualStack = true // dual-stack, unspecified address
case network == "udp4":
ap = netip.AddrPortFrom(netip.IPv4Unspecified(), uint16(addr.Port))
case network == "udp6":
ap = netip.AddrPortFrom(netip.IPv6Unspecified(), uint16(addr.Port))
default:
return netip.AddrPort{}, false, net.UnknownNetworkError(network)
}
return ap, dualStack, nil
}
func sockaddrFromAddrPort(addr netip.AddrPort) (sa windows.Sockaddr, family int32, err error) {
rsa, err := rawSockaddrFromAddrPort(addr)
if err != nil {

View File

@ -294,6 +294,13 @@ func TestRawSockaddrToAddrPort(t *testing.T) {
if ap != tt.want {
t.Errorf("rawSockaddr.ToAddrPort(): got %v; want %v", ap, tt.want)
}
gotUDPAddr, err := sa.ToUDPAddr()
if err != nil {
t.Fatalf("rawSockaddr.ToUDPAddr() error: %v", err)
}
if gotUDPAddr.AddrPort() != tt.want {
t.Errorf("rawSockaddr.ToUDPAddr(): got %v; want %v", gotUDPAddr, net.UDPAddrFromAddrPort(tt.want))
}
})
}
}
@ -546,6 +553,139 @@ func TestAddrPortFromSockaddr(t *testing.T) {
}
}
func TestAddrPortFromUDPAddr(t *testing.T) {
t.Parallel()
iface := firstInterface(t)
tests := []struct {
name string
network string
udpAddr *net.UDPAddr
wantAddr netip.AddrPort
wantDualStack bool
wantErr bool
}{
{
name: "nil/udp",
network: "udp",
udpAddr: nil,
wantDualStack: true,
wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 0),
},
{
name: "nil/udp4",
network: "udp4",
udpAddr: nil,
wantDualStack: false,
wantAddr: netip.AddrPortFrom(netip.IPv4Unspecified(), 0),
},
{
name: "nil/udp6",
network: "udp6",
udpAddr: nil,
wantDualStack: false,
wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 0),
},
{
name: "unspecified/udp",
network: "udp",
udpAddr: &net.UDPAddr{Port: 1234},
wantDualStack: true,
wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 1234),
},
{
name: "unspecified/udp4",
network: "udp4",
udpAddr: &net.UDPAddr{Port: 1234},
wantDualStack: false,
wantAddr: netip.AddrPortFrom(netip.IPv4Unspecified(), 1234),
},
{
name: "unspecified/udp6",
network: "udp6",
udpAddr: &net.UDPAddr{Port: 1234},
wantDualStack: false,
wantAddr: netip.AddrPortFrom(netip.IPv6Unspecified(), 1234),
},
{
name: "IPv4/udp",
network: "udp",
udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 1234},
wantDualStack: false,
wantAddr: netip.MustParseAddrPort("192.0.2.1:1234"),
},
{
name: "IPv6/udp",
network: "udp",
udpAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
wantDualStack: false,
wantAddr: netip.MustParseAddrPort("[2001:db8::1]:1234"),
},
{
name: "IPv6-with-zone/udp",
network: "udp",
udpAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234, Zone: iface.Name},
wantDualStack: false,
wantAddr: netip.AddrPortFrom(netip.MustParseAddr("2001:db8::1").WithZone(iface.Name), 1234),
},
{
name: "IPv4-mapped-IPv6/udp",
network: "udp",
udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1).To16(), Port: 1234},
wantDualStack: false,
wantAddr: netip.MustParseAddrPort("192.0.2.1:1234"),
},
{
name: "IPv4-mapped-IPv6/udp6",
network: "udp6",
udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1).To16(), Port: 1234},
wantDualStack: false,
wantAddr: netip.MustParseAddrPort("[::ffff:192.0.2.1]:1234"),
},
{
name: "nil/invalid-network",
network: "tcp",
udpAddr: nil,
wantErr: true,
},
{
name: "IPv4/invalid-network",
network: "tcp",
udpAddr: &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 1234},
wantErr: true,
},
{
name: "IPv6/invalid-network",
network: "tcp",
udpAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
wantErr: true,
},
{
name: "IP/invalid-address",
network: "udp",
udpAddr: &net.UDPAddr{IP: []byte{1, 2, 3, 4, 5}, Port: 1234},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotAddr, dualStack, err := addrPortFromUDPAddr(tt.network, tt.udpAddr)
if (err != nil) != tt.wantErr {
t.Fatalf("addrPortFromUDPAddr: error: got %v; want %v", err, tt.wantErr)
}
if err != nil {
return
}
if gotAddr != tt.wantAddr {
t.Errorf("addrPortFromUDPAddr: got addr %v; want %v", gotAddr, tt.wantAddr)
}
if dualStack != tt.wantDualStack {
t.Errorf("addrPortFromUDPAddr: dualStack: got %v; want %v", dualStack, tt.wantDualStack)
}
})
}
}
func TestNetAddrFromAddrPort(t *testing.T) {
t.Parallel()
tests := []struct {