diff --git a/cmd/systray/systray.go b/cmd/systray/systray.go index 26316feeb..8a4ee08fd 100644 --- a/cmd/systray/systray.go +++ b/cmd/systray/systray.go @@ -7,7 +7,6 @@ package main import ( - "cmp" "context" "errors" "fmt" @@ -30,12 +29,15 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/util/stringsx" ) var ( localClient tailscale.LocalClient chState chan ipn.State // tailscale state changes + chRebuild chan struct{} // triggers a menu rebuild + appIcon *os.File // newMenuDelay is the amount of time to sleep after creating a new menu, @@ -111,6 +113,7 @@ func onReady() { io.Copy(appIcon, connected.renderWithBorder(3)) chState = make(chan ipn.State, 1) + chRebuild = make(chan struct{}, 1) menu := new(Menu) menu.rebuild(fetchState(ctx)) @@ -146,6 +149,10 @@ func fetchState(ctx context.Context) state { // You cannot iterate over the items in a menu, nor can you remove some items like separators. // So for now we rebuild the whole thing, and can optimize this later if needed. func (menu *Menu) rebuild(state state) { + if state.status == nil { + return + } + menu.mu.Lock() defer menu.mu.Unlock() @@ -181,25 +188,20 @@ func (menu *Menu) rebuild(state state) { item = accounts.AddSubMenuItem(title, "") } setRemoteIcon(item, profile.UserProfile.ProfilePicURL) - go func(profile ipn.LoginProfile) { - for { - select { - case <-ctx.Done(): - return - case <-item.ClickedCh: - select { - case <-ctx.Done(): - return - case menu.accountsCh <- profile.ID: - } - } + onClick(ctx, item, func(ctx context.Context) { + select { + case <-ctx.Done(): + case menu.accountsCh <- profile.ID: } - }(profile) + }) } - if state.status != nil && state.status.Self != nil { + if state.status != nil && state.status.Self != nil && len(state.status.Self.TailscaleIPs) > 0 { title := fmt.Sprintf("This Device: %s (%s)", state.status.Self.HostName, state.status.Self.TailscaleIPs[0]) menu.self = systray.AddMenuItem(title, "") + } else { + menu.self = systray.AddMenuItem("This Device: not connected", "") + menu.self.Disable() } systray.AddSeparator() @@ -266,6 +268,8 @@ func (menu *Menu) eventLoop(ctx context.Context) { select { case <-ctx.Done(): return + case <-chRebuild: + menu.rebuild(fetchState(ctx)) case state := <-chState: switch state { case ipn.Running: @@ -277,10 +281,11 @@ func (menu *Menu) eventLoop(ctx context.Context) { menu.disconnect.Show() menu.disconnect.Enable() case ipn.NoState, ipn.Stopped: + setAppIcon(disconnected) + menu.rebuild(fetchState(ctx)) menu.connect.SetTitle("Connect") menu.connect.Enable() menu.disconnect.Hide() - setAppIcon(disconnected) case ipn.Starting: setAppIcon(loading) } @@ -337,7 +342,6 @@ func (menu *Menu) eventLoop(ctx context.Context) { log.Printf("failed setting exit node: %v", err) } } - menu.rebuild(fetchState(ctx)) case <-menu.quit.ClickedCh: systray.Quit() @@ -345,6 +349,20 @@ func (menu *Menu) eventLoop(ctx context.Context) { } } +// onClick registers a click handler for a menu item. +func onClick(ctx context.Context, item *systray.MenuItem, fn func(ctx context.Context)) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-item.ClickedCh: + fn(ctx) + } + } + }() +} + // watchIPNBus subscribes to the tailscale event bus and sends state updates to chState. // This method does not return. func watchIPNBus(ctx context.Context) { @@ -383,6 +401,9 @@ func watchIPNBusInner(ctx context.Context) error { chState <- *n.State log.Printf("new state: %v", n.State) } + if n.Prefs != nil { + chRebuild <- struct{}{} + } } } } @@ -425,25 +446,17 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { time.Sleep(newMenuDelay) // register a click handler for a menu item to set nodeID as the exit node. - onClick := func(item *systray.MenuItem, nodeID tailcfg.StableNodeID) { - go func() { - for { - select { - case <-ctx.Done(): - return - case <-item.ClickedCh: - select { - case <-ctx.Done(): - return - case menu.exitNodeCh <- nodeID: - } - } + setExitNodeOnClick := func(item *systray.MenuItem, nodeID tailcfg.StableNodeID) { + onClick(ctx, item, func(ctx context.Context) { + select { + case <-ctx.Done(): + case menu.exitNodeCh <- nodeID: } - }() + }) } noExitNodeMenu := menu.exitNodes.AddSubMenuItemCheckbox("None", "", status.ExitNodeStatus == nil) - onClick(noExitNodeMenu, "") + setExitNodeOnClick(noExitNodeMenu, "") // Show recommended exit node if available. if status.Self.CapMap.Contains(tailcfg.NodeAttrSuggestExitNodeUI) { @@ -458,7 +471,7 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { } menu.exitNodes.AddSeparator() rm := menu.exitNodes.AddSubMenuItemCheckbox(title, "", false) - onClick(rm, sugg.ID) + setExitNodeOnClick(rm, sugg.ID) if status.ExitNodeStatus != nil && sugg.ID == status.ExitNodeStatus.ID { rm.Check() } @@ -490,7 +503,7 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { if status.ExitNodeStatus != nil && ps.ID == status.ExitNodeStatus.ID { sm.Check() } - onClick(sm, ps.ID) + setExitNodeOnClick(sm, ps.ID) } } @@ -510,7 +523,7 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { // single-city country, no submenu if len(country.cities) == 1 || hideMullvadCities { - onClick(countryMenu, country.best.ID) + setExitNodeOnClick(countryMenu, country.best.ID) if status.ExitNodeStatus != nil { for _, city := range country.cities { for _, ps := range city.peers { @@ -527,12 +540,12 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { // multi-city country, build submenu with "best available" option and cities. time.Sleep(newMenuDelay) bm := countryMenu.AddSubMenuItemCheckbox("Best Available", "", false) - onClick(bm, country.best.ID) + setExitNodeOnClick(bm, country.best.ID) countryMenu.AddSeparator() for _, city := range country.sortedCities() { cityMenu := countryMenu.AddSubMenuItemCheckbox(city.name, "", false) - onClick(cityMenu, city.best.ID) + setExitNodeOnClick(cityMenu, city.best.ID) if status.ExitNodeStatus != nil { for _, ps := range city.peers { if status.ExitNodeStatus.ID == ps.ID { @@ -558,7 +571,7 @@ type mullvadPeers struct { func (mp mullvadPeers) sortedCountries() []*mvCountry { countries := slices.Collect(maps.Values(mp.countries)) slices.SortFunc(countries, func(a, b *mvCountry) int { - return cmp.Compare(a.name, b.name) + return stringsx.CompareFold(a.name, b.name) }) return countries } @@ -574,7 +587,7 @@ type mvCountry struct { func (mc *mvCountry) sortedCities() []*mvCity { cities := slices.Collect(maps.Values(mc.cities)) slices.SortFunc(cities, func(a, b *mvCity) int { - return cmp.Compare(a.name, b.name) + return stringsx.CompareFold(a.name, b.name) }) return cities }