fact(utils): factor out single subnet ip logic

Removes repeated logic of calculating IP address subnets for single
subnet hosts and consolidates it in one place.
This commit is contained in:
Aaron U'Ren 2025-06-20 16:28:05 -05:00 committed by Aaron U'Ren
parent b070531ec5
commit 3c895955f7
3 changed files with 87 additions and 21 deletions

View File

@ -23,9 +23,6 @@ import (
) )
const ( const (
ipv4NetMaskBits = 32
ipv6NetMaskBits = 128
// TODO: it's bad to rely on eth0 here. While this is inside the container's namespace and is determined by the // TODO: it's bad to rely on eth0 here. While this is inside the container's namespace and is determined by the
// container runtime and so far we've been able to count on this being reliably set to eth0, it is possible that // container runtime and so far we've been able to count on this being reliably set to eth0, it is possible that
// this may shift sometime in the future with a different runtime. It would be better to find a reliable way to // this may shift sometime in the future with a different runtime. It would be better to find a reliable way to
@ -65,7 +62,6 @@ type netlinkCalls interface {
} }
func (ln *linuxNetworking) ipAddrDel(iface netlink.Link, ip string, nodeIP string) error { func (ln *linuxNetworking) ipAddrDel(iface netlink.Link, ip string, nodeIP string) error {
var netMask net.IPMask
parsedIP := net.ParseIP(ip) parsedIP := net.ParseIP(ip)
parsedNodeIP := net.ParseIP(nodeIP) parsedNodeIP := net.ParseIP(nodeIP)
if parsedIP.To4() != nil { if parsedIP.To4() != nil {
@ -73,8 +69,6 @@ func (ln *linuxNetworking) ipAddrDel(iface netlink.Link, ip string, nodeIP strin
if parsedNodeIP.To4() == nil { if parsedNodeIP.To4() == nil {
return fmt.Errorf("nodeIP %s does not match family for VIP IP: %s, unable to proceed", ip, nodeIP) return fmt.Errorf("nodeIP %s does not match family for VIP IP: %s, unable to proceed", ip, nodeIP)
} }
netMask = net.CIDRMask(ipv4NetMaskBits, ipv4NetMaskBits)
} else { } else {
// If the IP family of the NodeIP and the VIP IP don't match, we can't proceed // If the IP family of the NodeIP and the VIP IP don't match, we can't proceed
if parsedNodeIP.To4() != nil { if parsedNodeIP.To4() != nil {
@ -85,11 +79,9 @@ func (ln *linuxNetworking) ipAddrDel(iface netlink.Link, ip string, nodeIP strin
klog.V(2).Infof("Ignoring link-local IP address: %s", ip) klog.V(2).Infof("Ignoring link-local IP address: %s", ip)
return nil return nil
} }
netMask = net.CIDRMask(ipv6NetMaskBits, ipv6NetMaskBits)
} }
naddr := &netlink.Addr{IPNet: &net.IPNet{IP: parsedIP, Mask: netMask}, Scope: syscall.RT_SCOPE_LINK} naddr := &netlink.Addr{IPNet: utils.GetSingleIPNet(parsedIP), Scope: syscall.RT_SCOPE_LINK}
err := netlink.AddrDel(iface, naddr) err := netlink.AddrDel(iface, naddr)
if err != nil { if err != nil {
if err.Error() != IfaceHasNoAddr { if err.Error() != IfaceHasNoAddr {
@ -106,7 +98,7 @@ func (ln *linuxNetworking) ipAddrDel(iface netlink.Link, ip string, nodeIP strin
// #nosec G204 // #nosec G204
nRoute := &netlink.Route{ nRoute := &netlink.Route{
Type: unix.RTN_LOCAL, Type: unix.RTN_LOCAL,
Dst: &net.IPNet{IP: parsedIP, Mask: netMask}, Dst: utils.GetSingleIPNet(parsedIP),
LinkIndex: iface.Attrs().Index, LinkIndex: iface.Attrs().Index,
Table: syscall.RT_TABLE_LOCAL, Table: syscall.RT_TABLE_LOCAL,
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
@ -132,7 +124,6 @@ func (ln *linuxNetworking) ipAddrDel(iface netlink.Link, ip string, nodeIP strin
// to kube-dummy-if. Also when DSR is used, used to assign VIP to dummy interface // to kube-dummy-if. Also when DSR is used, used to assign VIP to dummy interface
// inside the container. // inside the container.
func (ln *linuxNetworking) ipAddrAdd(iface netlink.Link, ip string, nodeIP string, addRoute bool) error { func (ln *linuxNetworking) ipAddrAdd(iface netlink.Link, ip string, nodeIP string, addRoute bool) error {
var netMask net.IPMask
var isIPv6 bool var isIPv6 bool
parsedIP := net.ParseIP(ip) parsedIP := net.ParseIP(ip)
parsedNodeIP := net.ParseIP(nodeIP) parsedNodeIP := net.ParseIP(nodeIP)
@ -141,20 +132,16 @@ func (ln *linuxNetworking) ipAddrAdd(iface netlink.Link, ip string, nodeIP strin
if addRoute && parsedNodeIP.To4() == nil { if addRoute && parsedNodeIP.To4() == nil {
return fmt.Errorf("nodeIP %s does not match family for VIP IP: %s, unable to proceed", ip, nodeIP) return fmt.Errorf("nodeIP %s does not match family for VIP IP: %s, unable to proceed", ip, nodeIP)
} }
netMask = net.CIDRMask(ipv4NetMaskBits, ipv4NetMaskBits)
isIPv6 = false isIPv6 = false
} else { } else {
// If we're supposed to add a route and the IP family of the NodeIP and the VIP IP don't match, we can't proceed // If we're supposed to add a route and the IP family of the NodeIP and the VIP IP don't match, we can't proceed
if addRoute && parsedNodeIP.To4() != nil { if addRoute && parsedNodeIP.To4() != nil {
return fmt.Errorf("nodeIP %s does not match family for VIP IP: %s, unable to proceed", ip, nodeIP) return fmt.Errorf("nodeIP %s does not match family for VIP IP: %s, unable to proceed", ip, nodeIP)
} }
netMask = net.CIDRMask(ipv6NetMaskBits, ipv6NetMaskBits)
isIPv6 = true isIPv6 = true
} }
ipPrefix := &net.IPNet{IP: parsedIP, Mask: netMask} ipPrefix := utils.GetSingleIPNet(parsedIP)
naddr := &netlink.Addr{IPNet: ipPrefix, Scope: syscall.RT_SCOPE_LINK} naddr := &netlink.Addr{IPNet: ipPrefix, Scope: syscall.RT_SCOPE_LINK}
err := netlink.AddrAdd(iface, naddr) err := netlink.AddrAdd(iface, naddr)
if err != nil && err.Error() != IfaceHasAddr { if err != nil && err.Error() != IfaceHasAddr {
@ -196,7 +183,7 @@ func (ln *linuxNetworking) ipAddrAdd(iface netlink.Link, ip string, nodeIP strin
// create the source route below. See: https://github.com/cloudnativelabs/kube-router/issues/1698 // create the source route below. See: https://github.com/cloudnativelabs/kube-router/issues/1698
if isIPv6 { if isIPv6 {
nRoute := &netlink.Route{ nRoute := &netlink.Route{
Dst: &net.IPNet{IP: parsedIP, Mask: netMask}, Dst: utils.GetSingleIPNet(parsedIP),
Table: unix.RT_TABLE_UNSPEC, Table: unix.RT_TABLE_UNSPEC,
} }
routes, err := netlink.RouteListFiltered(netlink.FAMILY_V6, nRoute, routes, err := netlink.RouteListFiltered(netlink.FAMILY_V6, nRoute,
@ -327,10 +314,10 @@ func (ln *linuxNetworking) ipvsAddService(svcs []*ipvs.Service, vip net.IP, prot
var ipMask uint32 var ipMask uint32
if vip.To4() != nil { if vip.To4() != nil {
ipvsFamily = syscall.AF_INET ipvsFamily = syscall.AF_INET
ipMask = uint32(ipv4NetMaskBits) ipMask = utils.GetIPv4NetMaxMaskBits()
} else { } else {
ipvsFamily = syscall.AF_INET6 ipvsFamily = syscall.AF_INET6
ipMask = uint32(ipv6NetMaskBits) ipMask = utils.GetIPv6NetMaxMaskBits()
} }
svc := ipvs.Service{ svc := ipvs.Service{
Address: vip, Address: vip,
@ -370,9 +357,9 @@ func (ln *linuxNetworking) ipvsAddFWMarkService(svcs []*ipvs.Service, fwMark uin
var netmaskForFamily uint32 var netmaskForFamily uint32
switch family { switch family {
case syscall.AF_INET: case syscall.AF_INET:
netmaskForFamily = ipv4NetMaskBits netmaskForFamily = utils.GetIPv4NetMaxMaskBits()
case syscall.AF_INET6: case syscall.AF_INET6:
netmaskForFamily = ipv6NetMaskBits netmaskForFamily = utils.GetIPv6NetMaxMaskBits()
} }
for _, svc := range svcs { for _, svc := range svcs {
if fwMark == svc.FWMark { if fwMark == svc.FWMark {

View File

@ -9,8 +9,37 @@ import (
const ( const (
IPv4DefaultRoute = "0.0.0.0/0" IPv4DefaultRoute = "0.0.0.0/0"
IPv6DefaultRoute = "::/0" IPv6DefaultRoute = "::/0"
ipv4NetMaskBits = 32
ipv6NetMaskBits = 128
) )
// GetSingleIPNet returns an IPNet object that represents a subnet containing a single IP address for a given IP address
// with proper handling for IPv4 and IPv6 addresses.
func GetSingleIPNet(ip net.IP) *net.IPNet {
if ip.To4() != nil {
return &net.IPNet{
IP: ip,
Mask: net.CIDRMask(ipv4NetMaskBits, ipv4NetMaskBits),
}
} else {
return &net.IPNet{
IP: ip,
Mask: net.CIDRMask(ipv6NetMaskBits, ipv6NetMaskBits),
}
}
}
// GetIPv4NetMaxMaskBits returns the maximum mask bits for an IPv4 address
func GetIPv4NetMaxMaskBits() uint32 {
return ipv4NetMaskBits
}
// GetIPv6NetMaxMaskBits returns the maximum mask bits for an IPv6 address
func GetIPv6NetMaxMaskBits() uint32 {
return ipv6NetMaskBits
}
// ContainsIPv4Address checks a given string array to see if it contains a valid IPv4 address within it // ContainsIPv4Address checks a given string array to see if it contains a valid IPv4 address within it
func ContainsIPv4Address(addrs []string) bool { func ContainsIPv4Address(addrs []string) bool {
for _, addr := range addrs { for _, addr := range addrs {

View File

@ -292,3 +292,53 @@ func TestIP_IsPrivate(t *testing.T) {
}) })
} }
} }
func TestGetSingleIPNet(t *testing.T) {
tests := []struct {
name string
ip net.IP
expected *net.IPNet
}{
{
name: "IPv4 address",
ip: net.IPv4(192, 168, 1, 1),
expected: &net.IPNet{
IP: net.IPv4(192, 168, 1, 1),
Mask: net.CIDRMask(32, 32),
},
},
{
name: "IPv4-mapped IPv6 address",
ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1},
expected: &net.IPNet{
IP: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1},
Mask: net.CIDRMask(32, 32),
},
},
{
name: "IPv6 address",
ip: net.IPv6loopback,
expected: &net.IPNet{
IP: net.IPv6loopback,
Mask: net.CIDRMask(128, 128),
},
},
{
name: "Another IPv6 address",
ip: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
expected: &net.IPNet{
IP: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
Mask: net.CIDRMask(128, 128),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetSingleIPNet(tt.ip)
require.NotNil(t, result)
assert.True(t, tt.expected.IP.Equal(result.IP))
assert.Equal(t, tt.expected.Mask, result.Mask)
})
}
}