wgengine/router/osrouter: fix data race in magicsock port update callback

As found by @cmol in #17423.

Updates #17423

Change-Id: I1492501f74ca7b57a8c5278ea6cb87a56a4086b9
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2025-10-03 13:31:49 -07:00 committed by Brad Fitzpatrick
parent 447cbdd1d0
commit 141eb64d3f

View File

@ -86,8 +86,8 @@ type linuxRouter struct {
cmd commandRunner
nfr linuxfw.NetfilterRunner
magicsockPortV4 uint16
magicsockPortV6 uint16
magicsockPortV4 atomic.Uint32 // actually a uint16
magicsockPortV6 atomic.Uint32 // actually a uint16
}
func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus) (router.Router, error) {
@ -546,7 +546,7 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error {
}
}
var magicsockPort *uint16
var magicsockPort *atomic.Uint32
switch network {
case "udp4":
magicsockPort = &r.magicsockPortV4
@ -566,27 +566,29 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error {
// set the port, we'll make the firewall rule when netfilter turns back on
if r.netfilterMode == netfilterOff {
*magicsockPort = port
magicsockPort.Store(uint32(port))
return nil
}
if *magicsockPort == port {
cur := magicsockPort.Load()
if cur == uint32(port) {
return nil
}
if *magicsockPort != 0 {
if err := r.nfr.DelMagicsockPortRule(*magicsockPort, network); err != nil {
if cur != 0 {
if err := r.nfr.DelMagicsockPortRule(uint16(cur), network); err != nil {
return fmt.Errorf("del magicsock port rule: %w", err)
}
}
if port != 0 {
if err := r.nfr.AddMagicsockPortRule(*magicsockPort, network); err != nil {
if err := r.nfr.AddMagicsockPortRule(uint16(port), network); err != nil {
return fmt.Errorf("add magicsock port rule: %w", err)
}
}
*magicsockPort = port
magicsockPort.Store(uint32(port))
return nil
}
@ -658,13 +660,13 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
if err := r.nfr.AddBase(r.tunname); err != nil {
return err
}
if r.magicsockPortV4 != 0 {
if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV4, "udp4"); err != nil {
if mport := uint16(r.magicsockPortV4.Load()); mport != 0 {
if err := r.nfr.AddMagicsockPortRule(mport, "udp4"); err != nil {
return fmt.Errorf("could not add magicsock port rule v4: %w", err)
}
}
if r.magicsockPortV6 != 0 && r.getV6FilteringAvailable() {
if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV6, "udp6"); err != nil {
if mport := uint16(r.magicsockPortV6.Load()); mport != 0 && r.getV6FilteringAvailable() {
if err := r.nfr.AddMagicsockPortRule(mport, "udp6"); err != nil {
return fmt.Errorf("could not add magicsock port rule v6: %w", err)
}
}
@ -698,13 +700,13 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
if err := r.nfr.AddBase(r.tunname); err != nil {
return err
}
if r.magicsockPortV4 != 0 {
if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV4, "udp4"); err != nil {
if mport := uint16(r.magicsockPortV4.Load()); mport != 0 {
if err := r.nfr.AddMagicsockPortRule(mport, "udp4"); err != nil {
return fmt.Errorf("could not add magicsock port rule v4: %w", err)
}
}
if r.magicsockPortV6 != 0 && r.getV6FilteringAvailable() {
if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV6, "udp6"); err != nil {
if mport := uint16(r.magicsockPortV6.Load()); mport != 0 && r.getV6FilteringAvailable() {
if err := r.nfr.AddMagicsockPortRule(mport, "udp6"); err != nil {
return fmt.Errorf("could not add magicsock port rule v6: %w", err)
}
}