diff --git a/wgengine/router/osrouter/router_linux.go b/wgengine/router/osrouter/router_linux.go index 1f825b917..cf1a9f027 100644 --- a/wgengine/router/osrouter/router_linux.go +++ b/wgengine/router/osrouter/router_linux.go @@ -86,8 +86,8 @@ type linuxRouter struct { cmd commandRunner nfr linuxfw.NetfilterRunner - magicsockPortV4 uint16 - magicsockPortV6 uint16 + magicsockPortV4 atomic.Uint32 // actually a uint16 + magicsockPortV6 atomic.Uint32 // actually a uint16 } func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus) (router.Router, error) { @@ -546,7 +546,7 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error { } } - var magicsockPort *uint16 + var magicsockPort *atomic.Uint32 switch network { case "udp4": magicsockPort = &r.magicsockPortV4 @@ -566,27 +566,29 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error { // set the port, we'll make the firewall rule when netfilter turns back on if r.netfilterMode == netfilterOff { - *magicsockPort = port + magicsockPort.Store(uint32(port)) return nil } - if *magicsockPort == port { + cur := magicsockPort.Load() + + if cur == uint32(port) { return nil } - if *magicsockPort != 0 { - if err := r.nfr.DelMagicsockPortRule(*magicsockPort, network); err != nil { + if cur != 0 { + if err := r.nfr.DelMagicsockPortRule(uint16(cur), network); err != nil { return fmt.Errorf("del magicsock port rule: %w", err) } } if port != 0 { - if err := r.nfr.AddMagicsockPortRule(*magicsockPort, network); err != nil { + if err := r.nfr.AddMagicsockPortRule(uint16(port), network); err != nil { return fmt.Errorf("add magicsock port rule: %w", err) } } - *magicsockPort = port + magicsockPort.Store(uint32(port)) return nil } @@ -658,13 +660,13 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { if err := r.nfr.AddBase(r.tunname); err != nil { return err } - if r.magicsockPortV4 != 0 { - if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV4, "udp4"); err != nil { + if mport := uint16(r.magicsockPortV4.Load()); mport != 0 { + if err := r.nfr.AddMagicsockPortRule(mport, "udp4"); err != nil { return fmt.Errorf("could not add magicsock port rule v4: %w", err) } } - if r.magicsockPortV6 != 0 && r.getV6FilteringAvailable() { - if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV6, "udp6"); err != nil { + if mport := uint16(r.magicsockPortV6.Load()); mport != 0 && r.getV6FilteringAvailable() { + if err := r.nfr.AddMagicsockPortRule(mport, "udp6"); err != nil { return fmt.Errorf("could not add magicsock port rule v6: %w", err) } } @@ -698,13 +700,13 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { if err := r.nfr.AddBase(r.tunname); err != nil { return err } - if r.magicsockPortV4 != 0 { - if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV4, "udp4"); err != nil { + if mport := uint16(r.magicsockPortV4.Load()); mport != 0 { + if err := r.nfr.AddMagicsockPortRule(mport, "udp4"); err != nil { return fmt.Errorf("could not add magicsock port rule v4: %w", err) } } - if r.magicsockPortV6 != 0 && r.getV6FilteringAvailable() { - if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV6, "udp6"); err != nil { + if mport := uint16(r.magicsockPortV6.Load()); mport != 0 && r.getV6FilteringAvailable() { + if err := r.nfr.AddMagicsockPortRule(mport, "udp6"); err != nil { return fmt.Errorf("could not add magicsock port rule v6: %w", err) } }