From da9965d51ca89a393bbf4c52818a579320dca93c Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Wed, 8 Jan 2025 17:21:44 -0600 Subject: [PATCH] cmd/viewer,types/views,various: avoid allocations in pointer field getters whenever possible In this PR, we add a generic views.ValuePointer type that can be used as a view for pointers to basic types and struct types that do not require deep cloning and do not have corresponding view types. Its Get/GetOk methods return stack-allocated shallow copies of the underlying value. We then update the cmd/viewer codegen to produce getters that return either concrete views when available or ValuePointer views when not, for pointer fields in generated view types. This allows us to avoid unnecessary allocations compared to returning pointers to newly allocated shallow copies. Updates #14570 Signed-off-by: Nick Khyl --- cmd/tsconnect/wasm/wasm_js.go | 2 +- cmd/viewer/tests/tests.go | 9 +++- cmd/viewer/tests/tests_clone.go | 4 ++ cmd/viewer/tests/tests_view.go | 45 +++++-------------- cmd/viewer/viewer.go | 72 ++++++++++++++++++++++++----- control/controlclient/map.go | 14 +++--- ipn/ipnlocal/drive.go | 3 +- ipn/ipnlocal/expiry_test.go | 8 ++-- ipn/ipnlocal/local.go | 39 +++++++--------- ipn/ipnlocal/peerapi.go | 8 ++-- tailcfg/tailcfg_view.go | 75 +++++++------------------------ types/prefs/prefs_view_test.go | 11 +---- types/views/views.go | 80 +++++++++++++++++++++++++++++++++ wgengine/pendopen.go | 8 ++-- wgengine/wgcfg/nmcfg/nmcfg.go | 4 +- 15 files changed, 219 insertions(+), 163 deletions(-) diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 4ea1cd897..a7e3e506b 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -282,7 +282,7 @@ func (i *jsIPN) run(jsCallbacks js.Value) { MachineKey: p.Machine().String(), NodeKey: p.Key().String(), }, - Online: p.Online(), + Online: p.Online().Clone(), TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), } }), diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 14a488861..ac094c53b 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -37,9 +37,14 @@ type Map struct { StructWithPtrKey map[StructWithPtrs]int `json:"-"` } +type StructWithNoView struct { + Value int +} + type StructWithPtrs struct { - Value *StructWithoutPtrs - Int *int + Value *StructWithoutPtrs + Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs `codegen:"noclone"` } diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 9131f5040..106a9b684 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -28,6 +28,9 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { if dst.Int != nil { dst.Int = ptr.To(*src.Int) } + if dst.NoView != nil { + dst.NoView = ptr.To(*src.NoView) + } return dst } @@ -35,6 +38,7 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { var _StructWithPtrsCloneNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index 9c74c9426..41c1338ff 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -61,20 +61,11 @@ func (v *StructWithPtrsView) UnmarshalJSON(b []byte) error { return nil } -func (v StructWithPtrsView) Value() *StructWithoutPtrs { - if v.ж.Value == nil { - return nil - } - x := *v.ж.Value - return &x -} +func (v StructWithPtrsView) Value() StructWithoutPtrsView { return v.ж.Value.View() } +func (v StructWithPtrsView) Int() views.ValuePointer[int] { return views.ValuePointerOf(v.ж.Int) } -func (v StructWithPtrsView) Int() *int { - if v.ж.Int == nil { - return nil - } - x := *v.ж.Int - return &x +func (v StructWithPtrsView) NoView() views.ValuePointer[StructWithNoView] { + return views.ValuePointerOf(v.ж.NoView) } func (v StructWithPtrsView) NoCloneValue() *StructWithoutPtrs { return v.ж.NoCloneValue } @@ -85,6 +76,7 @@ func (v StructWithPtrsView) Equal(v2 StructWithPtrsView) bool { return v.ж.Equa var _StructWithPtrsViewNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) @@ -424,12 +416,8 @@ func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error { } func (v GenericIntStructView[T]) Value() T { return v.ж.Value } -func (v GenericIntStructView[T]) Pointer() *T { - if v.ж.Pointer == nil { - return nil - } - x := *v.ж.Pointer - return &x +func (v GenericIntStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.ж.Pointer) } func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } @@ -500,12 +488,8 @@ func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error { } func (v GenericNoPtrsStructView[T]) Value() T { return v.ж.Value } -func (v GenericNoPtrsStructView[T]) Pointer() *T { - if v.ж.Pointer == nil { - return nil - } - x := *v.ж.Pointer - return &x +func (v GenericNoPtrsStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.ж.Pointer) } func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } @@ -722,19 +706,14 @@ func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error { return nil } -func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.ж.WithPtr.View() } +func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsAliasView { return v.ж.WithPtr.View() } func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.ж.WithoutPtr } func (v StructWithTypeAliasFieldsView) WithPtrByPtr() StructWithPtrsAliasView { return v.ж.WithPtrByPtr.View() } -func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() *StructWithoutPtrsAlias { - if v.ж.WithoutPtrByPtr == nil { - return nil - } - x := *v.ж.WithoutPtrByPtr - return &x +func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() StructWithoutPtrsAliasView { + return v.ж.WithoutPtrByPtr.View() } - func (v StructWithTypeAliasFieldsView) SliceWithPtrs() views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] { return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](v.ж.SliceWithPtrs) } diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 0c5868f3a..e265defe0 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -79,13 +79,7 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { {{end}} {{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.ж.{{.FieldName}}) } {{end}} -{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { - if v.ж.{{.FieldName}} == nil { - return nil - } - x := *v.ж.{{.FieldName}} - return &x -} +{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ValuePointer[{{.FieldType}}] { return views.ValuePointerOf(v.ж.{{.FieldName}}) } {{end}} {{define "mapField"}} @@ -126,7 +120,7 @@ func requiresCloning(t types.Type) (shallow, deep bool, base types.Type) { return p, p, t } -func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thisPkg *types.Package) { +func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *types.Package) { t, ok := typ.Underlying().(*types.Struct) if !ok || codegen.IsViewType(t) { return @@ -354,10 +348,32 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi } else { writeTemplate("unsupportedField") } - } else { - args.FieldType = it.QualifiedName(ptr) - writeTemplate("valuePointerField") + continue } + + // If a view type is already defined for the base type, use it as the field's view type. + if viewType := viewTypeForValueType(base); viewType != nil { + args.FieldType = it.QualifiedName(base) + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewField") + continue + } + + // Otherwise, if the unaliased base type is a named type whose view type will be generated by this viewer invocation, + // append the "View" suffix to the unaliased base type name and use it as the field's view type. + if base, ok := types.Unalias(base).(*types.Named); ok && slices.Contains(typeNames, it.QualifiedName(base)) { + baseTypeName := it.QualifiedName(base) + args.FieldType = baseTypeName + args.FieldViewName = appendNameSuffix(args.FieldType, "View") + writeTemplate("viewField") + continue + } + + // Otherwise, if the base type does not require deep cloning, has no existing view type, + // and will not have a generated view type, use views.ValuePointer[T] as the field's view type. + // Its Get/GetOk methods return stack-allocated shallow copies of the field's value. + args.FieldType = it.QualifiedName(base) + writeTemplate("valuePointerField") continue case *types.Interface: // If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field. @@ -405,6 +421,33 @@ func appendNameSuffix(name, suffix string) string { return name + suffix } +func typeNameOf(typ types.Type) (name *types.TypeName, ok bool) { + switch t := typ.(type) { + case *types.Alias: + return t.Obj(), true + case *types.Named: + return t.Obj(), true + default: + return nil, false + } +} + +func lookupViewType(typ types.Type) types.Type { + for { + if typeName, ok := typeNameOf(typ); ok && typeName.Pkg() != nil { + if viewTypeObj := typeName.Pkg().Scope().Lookup(typeName.Name() + "View"); viewTypeObj != nil { + return viewTypeObj.Type() + } + } + switch alias := typ.(type) { + case *types.Alias: + typ = alias.Rhs() + default: + return nil + } + } +} + func viewTypeForValueType(typ types.Type) types.Type { if ptr, ok := typ.(*types.Pointer); ok { return viewTypeForValueType(ptr.Elem()) @@ -417,7 +460,12 @@ func viewTypeForValueType(typ types.Type) types.Type { if !ok || sig.Results().Len() != 1 { return nil } - return sig.Results().At(0).Type() + viewType := sig.Results().At(0).Type() + // Check if the typ's package defines an alias for the view type, and use it if so. + if viewTypeAlias, ok := lookupViewType(typ).(*types.Alias); ok && types.AssignableTo(viewType, viewTypeAlias) { + viewType = viewTypeAlias + } + return viewType } func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 97d49f90d..30c1da672 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -689,13 +689,11 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return nil, false } case "Online": - wasOnline := was.Online() - if n.Online != nil && wasOnline != nil && *n.Online != *wasOnline { + if wasOnline, ok := was.Online().GetOk(); ok && n.Online != nil && *n.Online != wasOnline { pc().Online = ptr.To(*n.Online) } case "LastSeen": - wasSeen := was.LastSeen() - if n.LastSeen != nil && wasSeen != nil && !wasSeen.Equal(*n.LastSeen) { + if wasSeen, ok := was.LastSeen().GetOk(); ok && n.LastSeen != nil && !wasSeen.Equal(*n.LastSeen) { pc().LastSeen = ptr.To(*n.LastSeen) } case "MachineAuthorized": @@ -720,18 +718,18 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "SelfNodeV4MasqAddrForThisPeer": va, vb := was.SelfNodeV4MasqAddrForThisPeer(), n.SelfNodeV4MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "SelfNodeV6MasqAddrForThisPeer": va, vb := was.SelfNodeV6MasqAddrForThisPeer(), n.SelfNodeV6MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "ExitNodeDNSResolvers": diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index fe3622ba4..8ae813ff2 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -347,8 +347,7 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem // TODO(oxtoacart): for some reason, this correctly // catches when a node goes from offline to online, // but not the other way around... - online := peer.Online() - if online == nil || !*online { + if !peer.Online().Get() { return false } diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index af1aa337b..a2b10fe32 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -283,11 +283,11 @@ func formatNodes(nodes []tailcfg.NodeView) string { } fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) - if n.Online() != nil { - fmt.Fprintf(&sb, ", online=%v", *n.Online()) + if online, ok := n.Online().GetOk(); ok { + fmt.Fprintf(&sb, ", online=%v", online) } - if n.LastSeen() != nil { - fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) + if lastSeen, ok := n.LastSeen().GetOk(); ok { + fmt.Fprintf(&sb, ", lastSeen=%v", lastSeen.Unix()) } if n.Key() != (key.NodePublic{}) { fmt.Fprintf(&sb, ", key=%v", n.Key().String()) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 3a2a22c58..4ebcd5d6d 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1117,13 +1117,9 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { } if !prefs.ExitNodeID().IsZero() { if exitPeer, ok := b.netMap.PeerWithStableID(prefs.ExitNodeID()); ok { - online := false - if v := exitPeer.Online(); v != nil { - online = *v - } s.ExitNodeStatus = &ipnstate.ExitNodeStatus{ ID: prefs.ExitNodeID(), - Online: online, + Online: exitPeer.Online().Get(), TailscaleIPs: exitPeer.Addresses().AsSlice(), } } @@ -1194,10 +1190,6 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { } exitNodeID := b.pm.CurrentPrefs().ExitNodeID() for _, p := range b.peers { - var lastSeen time.Time - if p.LastSeen() != nil { - lastSeen = *p.LastSeen() - } tailscaleIPs := make([]netip.Addr, 0, p.Addresses().Len()) for i := range p.Addresses().Len() { addr := p.Addresses().At(i) @@ -1205,7 +1197,6 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { tailscaleIPs = append(tailscaleIPs, addr.Addr()) } } - online := p.Online() ps := &ipnstate.PeerStatus{ InNetworkMap: true, UserID: p.User(), @@ -1214,12 +1205,12 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { HostName: p.Hostinfo().Hostname(), DNSName: p.Name(), OS: p.Hostinfo().OS(), - LastSeen: lastSeen, - Online: online != nil && *online, + LastSeen: p.LastSeen().Get(), + Online: p.Online().Get(), ShareeNode: p.Hostinfo().ShareeNode(), ExitNode: p.StableID() != "" && p.StableID() == exitNodeID, SSH_HostKeys: p.Hostinfo().SSH_HostKeys().AsSlice(), - Location: p.Hostinfo().Location(), + Location: p.Hostinfo().Location().AsStruct(), Capabilities: p.Capabilities().AsSlice(), } if cm := p.CapMap(); cm.Len() > 0 { @@ -7369,8 +7360,8 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug if len(candidates) == 1 { peer := candidates[0] if hi := peer.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } res.ID = peer.StableID() @@ -7414,10 +7405,10 @@ type nodeDistance struct { continue } loc := hi.Location() - if loc == nil { + if !loc.Valid() { continue } - distance := longLatDistance(preferredDERP.Latitude, preferredDERP.Longitude, loc.Latitude, loc.Longitude) + distance := longLatDistance(preferredDERP.Latitude, preferredDERP.Longitude, loc.Latitude(), loc.Longitude()) if distance < minDistance { minDistance = distance } @@ -7438,8 +7429,8 @@ type nodeDistance struct { res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7468,8 +7459,8 @@ type nodeDistance struct { res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7485,13 +7476,13 @@ func pickWeighted(candidates []tailcfg.NodeView) []tailcfg.NodeView { continue } loc := hi.Location() - if loc == nil || loc.Priority < maxWeight { + if !loc.Valid() || loc.Priority() < maxWeight { continue } - if maxWeight != loc.Priority { + if maxWeight != loc.Priority() { best = best[:0] } - maxWeight = loc.Priority + maxWeight = loc.Priority() best = append(best, c) } return best diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index aa18c3588..7aa677640 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -233,11 +233,11 @@ func (h *peerAPIHandler) logf(format string, a ...any) { // isAddressValid reports whether addr is a valid destination address for this // node originating from the peer. func (h *peerAPIHandler) isAddressValid(addr netip.Addr) bool { - if v := h.peerNode.SelfNodeV4MasqAddrForThisPeer(); v != nil { - return *v == addr + if v, ok := h.peerNode.SelfNodeV4MasqAddrForThisPeer().GetOk(); ok { + return v == addr } - if v := h.peerNode.SelfNodeV6MasqAddrForThisPeer(); v != nil { - return *v == addr + if v, ok := h.peerNode.SelfNodeV6MasqAddrForThisPeer().GetOk(); ok { + return v == addr } pfx := netip.PrefixFrom(addr, addr.BitLen()) return views.SliceContains(h.selfNode.Addresses(), pfx) diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index 774a18258..53df3dcef 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -145,21 +145,11 @@ func (v NodeView) Created() time.Time { return v.ж.Create func (v NodeView) Cap() CapabilityVersion { return v.ж.Cap } func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v NodeView) PrimaryRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.PrimaryRoutes) } -func (v NodeView) LastSeen() *time.Time { - if v.ж.LastSeen == nil { - return nil - } - x := *v.ж.LastSeen - return &x +func (v NodeView) LastSeen() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.LastSeen) } -func (v NodeView) Online() *bool { - if v.ж.Online == nil { - return nil - } - x := *v.ж.Online - return &x -} +func (v NodeView) Online() views.ValuePointer[bool] { return views.ValuePointerOf(v.ж.Online) } func (v NodeView) MachineAuthorized() bool { return v.ж.MachineAuthorized } func (v NodeView) Capabilities() views.Slice[NodeCapability] { return views.SliceOf(v.ж.Capabilities) } @@ -172,20 +162,12 @@ func (v NodeView) ComputedName() string { return v.ж.ComputedName } func (v NodeView) ComputedNameWithHost() string { return v.ж.ComputedNameWithHost } func (v NodeView) DataPlaneAuditLogID() string { return v.ж.DataPlaneAuditLogID } func (v NodeView) Expired() bool { return v.ж.Expired } -func (v NodeView) SelfNodeV4MasqAddrForThisPeer() *netip.Addr { - if v.ж.SelfNodeV4MasqAddrForThisPeer == nil { - return nil - } - x := *v.ж.SelfNodeV4MasqAddrForThisPeer - return &x +func (v NodeView) SelfNodeV4MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.ж.SelfNodeV4MasqAddrForThisPeer) } -func (v NodeView) SelfNodeV6MasqAddrForThisPeer() *netip.Addr { - if v.ж.SelfNodeV6MasqAddrForThisPeer == nil { - return nil - } - x := *v.ж.SelfNodeV6MasqAddrForThisPeer - return &x +func (v NodeView) SelfNodeV6MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.ж.SelfNodeV6MasqAddrForThisPeer) } func (v NodeView) IsWireGuardOnly() bool { return v.ж.IsWireGuardOnly } @@ -315,15 +297,8 @@ func (v HostinfoView) Userspace() opt.Bool { return v.ж.User func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.UserspaceRouter } func (v HostinfoView) AppConnector() opt.Bool { return v.ж.AppConnector } func (v HostinfoView) ServicesHash() string { return v.ж.ServicesHash } -func (v HostinfoView) Location() *Location { - if v.ж.Location == nil { - return nil - } - x := *v.ж.Location - return &x -} - -func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } +func (v HostinfoView) Location() LocationView { return v.ж.Location.View() } +func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HostinfoViewNeedsRegeneration = Hostinfo(struct { @@ -699,12 +674,8 @@ func (v *RegisterResponseAuthView) UnmarshalJSON(b []byte) error { return nil } -func (v RegisterResponseAuthView) Oauth2Token() *Oauth2Token { - if v.ж.Oauth2Token == nil { - return nil - } - x := *v.ж.Oauth2Token - return &x +func (v RegisterResponseAuthView) Oauth2Token() views.ValuePointer[Oauth2Token] { + return views.ValuePointerOf(v.ж.Oauth2Token) } func (v RegisterResponseAuthView) AuthKey() string { return v.ж.AuthKey } @@ -774,12 +745,8 @@ func (v RegisterRequestView) NodeKeySignature() views.ByteSlice[tkatype.Marshale return views.ByteSliceOf(v.ж.NodeKeySignature) } func (v RegisterRequestView) SignatureType() SignatureType { return v.ж.SignatureType } -func (v RegisterRequestView) Timestamp() *time.Time { - if v.ж.Timestamp == nil { - return nil - } - x := *v.ж.Timestamp - return &x +func (v RegisterRequestView) Timestamp() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.Timestamp) } func (v RegisterRequestView) DeviceCert() views.ByteSlice[[]byte] { @@ -1110,12 +1077,8 @@ func (v *SSHRuleView) UnmarshalJSON(b []byte) error { return nil } -func (v SSHRuleView) RuleExpires() *time.Time { - if v.ж.RuleExpires == nil { - return nil - } - x := *v.ж.RuleExpires - return &x +func (v SSHRuleView) RuleExpires() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.RuleExpires) } func (v SSHRuleView) Principals() views.SliceView[*SSHPrincipal, SSHPrincipalView] { @@ -1189,12 +1152,8 @@ func (v SSHActionView) HoldAndDelegate() string { return v.ж.Hol func (v SSHActionView) AllowLocalPortForwarding() bool { return v.ж.AllowLocalPortForwarding } func (v SSHActionView) AllowRemotePortForwarding() bool { return v.ж.AllowRemotePortForwarding } func (v SSHActionView) Recorders() views.Slice[netip.AddrPort] { return views.SliceOf(v.ж.Recorders) } -func (v SSHActionView) OnRecordingFailure() *SSHRecorderFailureAction { - if v.ж.OnRecordingFailure == nil { - return nil - } - x := *v.ж.OnRecordingFailure - return &x +func (v SSHActionView) OnRecordingFailure() views.ValuePointer[SSHRecorderFailureAction] { + return views.ValuePointerOf(v.ж.OnRecordingFailure) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. diff --git a/types/prefs/prefs_view_test.go b/types/prefs/prefs_view_test.go index d76eebb43..ef9f09603 100644 --- a/types/prefs/prefs_view_test.go +++ b/types/prefs/prefs_view_test.go @@ -162,15 +162,8 @@ func (v *TestBundleView) UnmarshalJSON(b []byte) error { return nil } -func (v TestBundleView) Name() string { return v.ж.Name } -func (v TestBundleView) Nested() *TestValueStruct { - if v.ж.Nested == nil { - return nil - } - x := *v.ж.Nested - return &x -} - +func (v TestBundleView) Name() string { return v.ж.Name } +func (v TestBundleView) Nested() TestValueStructView { return v.ж.Nested.View() } func (v TestBundleView) Equal(v2 TestBundleView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. diff --git a/types/views/views.go b/types/views/views.go index 40d8811f5..d8acf27ce 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -16,6 +16,7 @@ "slices" "go4.org/mem" + "tailscale.com/types/ptr" ) func unmarshalSliceFromJSON[T any](b []byte, x *[]T) error { @@ -690,6 +691,85 @@ func (m MapFn[K, T, V]) All() iter.Seq2[K, V] { } } +// ValuePointer provides a read-only view of a pointer to a value type, +// such as a primitive type or an immutable struct. Its Value and ValueOk +// methods return a stack-allocated shallow copy of the underlying value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +type ValuePointer[T any] struct { + // ж is the underlying value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *T +} + +// Valid reports whether the underlying pointer is non-nil. +func (p ValuePointer[T]) Valid() bool { + return p.ж != nil +} + +// Get returns a shallow copy of the value if the underlying pointer is non-nil. +// Otherwise, it returns a zero value. +func (p ValuePointer[T]) Get() T { + v, _ := p.GetOk() + return v +} + +// GetOk returns a shallow copy of the underlying value and true if the underlying +// pointer is non-nil. Otherwise, it returns a zero value and false. +func (p ValuePointer[T]) GetOk() (value T, ok bool) { + if p.ж == nil { + return value, false // value holds a zero value + } + return *p.ж, true +} + +// GetOr returns a shallow copy of the underlying value if it is non-nil. +// Otherwise, it returns the provided default value. +func (p ValuePointer[T]) GetOr(def T) T { + if p.ж == nil { + return def + } + return *p.ж +} + +// Clone returns a shallow copy of the underlying value. +func (p ValuePointer[T]) Clone() *T { + if p.ж == nil { + return nil + } + return ptr.To(*p.ж) +} + +// String implements [fmt.Stringer]. +func (p ValuePointer[T]) String() string { + if p.ж == nil { + return "nil" + } + return fmt.Sprint(p.ж) +} + +// ValuePointerOf returns an immutable view of a pointer to an immutable value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +func ValuePointerOf[T any](v *T) ValuePointer[T] { + return ValuePointer[T]{v} +} + +// MarshalJSON implements [json.Marshaler]. +func (p ValuePointer[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(p.ж) +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (p *ValuePointer[T]) UnmarshalJSON(b []byte) error { + if p.ж != nil { + return errors.New("already initialized") + } + return json.Unmarshal(b, &p.ж) +} + // ContainsPointers reports whether T contains any pointers, // either explicitly or implicitly. // It has special handling for some types that contain pointers diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index 308c3ede2..f8e9198a5 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -239,15 +239,15 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { if n.IsWireGuardOnly() { online = "wg" } else { - if v := n.Online(); v != nil { - if *v { + if v, ok := n.Online().GetOk(); ok { + if v { online = "yes" } else { online = "no" } } - if n.LastSeen() != nil && online != "yes" { - online += fmt.Sprintf(", lastseen=%v", durFmt(*n.LastSeen())) + if lastSeen, ok := n.LastSeen().GetOk(); ok && online != "yes" { + online += fmt.Sprintf(", lastseen=%v", durFmt(lastSeen)) } } e.logf("open-conn-track: timeout opening %v to node %v; online=%v, lastRecv=%v", diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index e7d5edf15..97304aa41 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -106,8 +106,8 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, cpeer := &cfg.Peers[len(cfg.Peers)-1] didExitNodeWarn := false - cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer() - cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer() + cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer().Clone() + cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer().Clone() cpeer.IsJailed = peer.IsJailed() for _, allowedIP := range peer.AllowedIPs().All() { if allowedIP.Bits() == 0 && peer.StableID() != exitNode {