diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 4ea1cd897..a7e3e506b 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -282,7 +282,7 @@ func (i *jsIPN) run(jsCallbacks js.Value) { MachineKey: p.Machine().String(), NodeKey: p.Key().String(), }, - Online: p.Online(), + Online: p.Online().Clone(), TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), } }), diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 14a488861..ac094c53b 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -37,9 +37,14 @@ type Map struct { StructWithPtrKey map[StructWithPtrs]int `json:"-"` } +type StructWithNoView struct { + Value int +} + type StructWithPtrs struct { - Value *StructWithoutPtrs - Int *int + Value *StructWithoutPtrs + Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs `codegen:"noclone"` } diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 9131f5040..106a9b684 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -28,6 +28,9 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { if dst.Int != nil { dst.Int = ptr.To(*src.Int) } + if dst.NoView != nil { + dst.NoView = ptr.To(*src.NoView) + } return dst } @@ -35,6 +38,7 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { var _StructWithPtrsCloneNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index 9c74c9426..41c1338ff 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -61,20 +61,11 @@ func (v *StructWithPtrsView) UnmarshalJSON(b []byte) error { return nil } -func (v StructWithPtrsView) Value() *StructWithoutPtrs { - if v.ж.Value == nil { - return nil - } - x := *v.ж.Value - return &x -} +func (v StructWithPtrsView) Value() StructWithoutPtrsView { return v.ж.Value.View() } +func (v StructWithPtrsView) Int() views.ValuePointer[int] { return views.ValuePointerOf(v.ж.Int) } -func (v StructWithPtrsView) Int() *int { - if v.ж.Int == nil { - return nil - } - x := *v.ж.Int - return &x +func (v StructWithPtrsView) NoView() views.ValuePointer[StructWithNoView] { + return views.ValuePointerOf(v.ж.NoView) } func (v StructWithPtrsView) NoCloneValue() *StructWithoutPtrs { return v.ж.NoCloneValue } @@ -85,6 +76,7 @@ func (v StructWithPtrsView) Equal(v2 StructWithPtrsView) bool { return v.ж.Equa var _StructWithPtrsViewNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) @@ -424,12 +416,8 @@ func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error { } func (v GenericIntStructView[T]) Value() T { return v.ж.Value } -func (v GenericIntStructView[T]) Pointer() *T { - if v.ж.Pointer == nil { - return nil - } - x := *v.ж.Pointer - return &x +func (v GenericIntStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.ж.Pointer) } func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } @@ -500,12 +488,8 @@ func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error { } func (v GenericNoPtrsStructView[T]) Value() T { return v.ж.Value } -func (v GenericNoPtrsStructView[T]) Pointer() *T { - if v.ж.Pointer == nil { - return nil - } - x := *v.ж.Pointer - return &x +func (v GenericNoPtrsStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.ж.Pointer) } func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } @@ -722,19 +706,14 @@ func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error { return nil } -func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.ж.WithPtr.View() } +func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsAliasView { return v.ж.WithPtr.View() } func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.ж.WithoutPtr } func (v StructWithTypeAliasFieldsView) WithPtrByPtr() StructWithPtrsAliasView { return v.ж.WithPtrByPtr.View() } -func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() *StructWithoutPtrsAlias { - if v.ж.WithoutPtrByPtr == nil { - return nil - } - x := *v.ж.WithoutPtrByPtr - return &x +func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() StructWithoutPtrsAliasView { + return v.ж.WithoutPtrByPtr.View() } - func (v StructWithTypeAliasFieldsView) SliceWithPtrs() views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] { return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](v.ж.SliceWithPtrs) } diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 0c5868f3a..e265defe0 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -79,13 +79,7 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { {{end}} {{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.ж.{{.FieldName}}) } {{end}} -{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { - if v.ж.{{.FieldName}} == nil { - return nil - } - x := *v.ж.{{.FieldName}} - return &x -} +{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ValuePointer[{{.FieldType}}] { return views.ValuePointerOf(v.ж.{{.FieldName}}) } {{end}} {{define "mapField"}} @@ -126,7 +120,7 @@ func requiresCloning(t types.Type) (shallow, deep bool, base types.Type) { return p, p, t } -func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thisPkg *types.Package) { +func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *types.Package) { t, ok := typ.Underlying().(*types.Struct) if !ok || codegen.IsViewType(t) { return @@ -354,10 +348,32 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi } else { writeTemplate("unsupportedField") } - } else { - args.FieldType = it.QualifiedName(ptr) - writeTemplate("valuePointerField") + continue } + + // If a view type is already defined for the base type, use it as the field's view type. + if viewType := viewTypeForValueType(base); viewType != nil { + args.FieldType = it.QualifiedName(base) + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewField") + continue + } + + // Otherwise, if the unaliased base type is a named type whose view type will be generated by this viewer invocation, + // append the "View" suffix to the unaliased base type name and use it as the field's view type. + if base, ok := types.Unalias(base).(*types.Named); ok && slices.Contains(typeNames, it.QualifiedName(base)) { + baseTypeName := it.QualifiedName(base) + args.FieldType = baseTypeName + args.FieldViewName = appendNameSuffix(args.FieldType, "View") + writeTemplate("viewField") + continue + } + + // Otherwise, if the base type does not require deep cloning, has no existing view type, + // and will not have a generated view type, use views.ValuePointer[T] as the field's view type. + // Its Get/GetOk methods return stack-allocated shallow copies of the field's value. + args.FieldType = it.QualifiedName(base) + writeTemplate("valuePointerField") continue case *types.Interface: // If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field. @@ -405,6 +421,33 @@ func appendNameSuffix(name, suffix string) string { return name + suffix } +func typeNameOf(typ types.Type) (name *types.TypeName, ok bool) { + switch t := typ.(type) { + case *types.Alias: + return t.Obj(), true + case *types.Named: + return t.Obj(), true + default: + return nil, false + } +} + +func lookupViewType(typ types.Type) types.Type { + for { + if typeName, ok := typeNameOf(typ); ok && typeName.Pkg() != nil { + if viewTypeObj := typeName.Pkg().Scope().Lookup(typeName.Name() + "View"); viewTypeObj != nil { + return viewTypeObj.Type() + } + } + switch alias := typ.(type) { + case *types.Alias: + typ = alias.Rhs() + default: + return nil + } + } +} + func viewTypeForValueType(typ types.Type) types.Type { if ptr, ok := typ.(*types.Pointer); ok { return viewTypeForValueType(ptr.Elem()) @@ -417,7 +460,12 @@ func viewTypeForValueType(typ types.Type) types.Type { if !ok || sig.Results().Len() != 1 { return nil } - return sig.Results().At(0).Type() + viewType := sig.Results().At(0).Type() + // Check if the typ's package defines an alias for the view type, and use it if so. + if viewTypeAlias, ok := lookupViewType(typ).(*types.Alias); ok && types.AssignableTo(viewType, viewTypeAlias) { + viewType = viewTypeAlias + } + return viewType } func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 97d49f90d..30c1da672 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -689,13 +689,11 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return nil, false } case "Online": - wasOnline := was.Online() - if n.Online != nil && wasOnline != nil && *n.Online != *wasOnline { + if wasOnline, ok := was.Online().GetOk(); ok && n.Online != nil && *n.Online != wasOnline { pc().Online = ptr.To(*n.Online) } case "LastSeen": - wasSeen := was.LastSeen() - if n.LastSeen != nil && wasSeen != nil && !wasSeen.Equal(*n.LastSeen) { + if wasSeen, ok := was.LastSeen().GetOk(); ok && n.LastSeen != nil && !wasSeen.Equal(*n.LastSeen) { pc().LastSeen = ptr.To(*n.LastSeen) } case "MachineAuthorized": @@ -720,18 +718,18 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "SelfNodeV4MasqAddrForThisPeer": va, vb := was.SelfNodeV4MasqAddrForThisPeer(), n.SelfNodeV4MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "SelfNodeV6MasqAddrForThisPeer": va, vb := was.SelfNodeV6MasqAddrForThisPeer(), n.SelfNodeV6MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "ExitNodeDNSResolvers": diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index fe3622ba4..8ae813ff2 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -347,8 +347,7 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem // TODO(oxtoacart): for some reason, this correctly // catches when a node goes from offline to online, // but not the other way around... - online := peer.Online() - if online == nil || !*online { + if !peer.Online().Get() { return false } diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index af1aa337b..a2b10fe32 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -283,11 +283,11 @@ func formatNodes(nodes []tailcfg.NodeView) string { } fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) - if n.Online() != nil { - fmt.Fprintf(&sb, ", online=%v", *n.Online()) + if online, ok := n.Online().GetOk(); ok { + fmt.Fprintf(&sb, ", online=%v", online) } - if n.LastSeen() != nil { - fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) + if lastSeen, ok := n.LastSeen().GetOk(); ok { + fmt.Fprintf(&sb, ", lastSeen=%v", lastSeen.Unix()) } if n.Key() != (key.NodePublic{}) { fmt.Fprintf(&sb, ", key=%v", n.Key().String()) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 3a2a22c58..4ebcd5d6d 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1117,13 +1117,9 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { } if !prefs.ExitNodeID().IsZero() { if exitPeer, ok := b.netMap.PeerWithStableID(prefs.ExitNodeID()); ok { - online := false - if v := exitPeer.Online(); v != nil { - online = *v - } s.ExitNodeStatus = &ipnstate.ExitNodeStatus{ ID: prefs.ExitNodeID(), - Online: online, + Online: exitPeer.Online().Get(), TailscaleIPs: exitPeer.Addresses().AsSlice(), } } @@ -1194,10 +1190,6 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { } exitNodeID := b.pm.CurrentPrefs().ExitNodeID() for _, p := range b.peers { - var lastSeen time.Time - if p.LastSeen() != nil { - lastSeen = *p.LastSeen() - } tailscaleIPs := make([]netip.Addr, 0, p.Addresses().Len()) for i := range p.Addresses().Len() { addr := p.Addresses().At(i) @@ -1205,7 +1197,6 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { tailscaleIPs = append(tailscaleIPs, addr.Addr()) } } - online := p.Online() ps := &ipnstate.PeerStatus{ InNetworkMap: true, UserID: p.User(), @@ -1214,12 +1205,12 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { HostName: p.Hostinfo().Hostname(), DNSName: p.Name(), OS: p.Hostinfo().OS(), - LastSeen: lastSeen, - Online: online != nil && *online, + LastSeen: p.LastSeen().Get(), + Online: p.Online().Get(), ShareeNode: p.Hostinfo().ShareeNode(), ExitNode: p.StableID() != "" && p.StableID() == exitNodeID, SSH_HostKeys: p.Hostinfo().SSH_HostKeys().AsSlice(), - Location: p.Hostinfo().Location(), + Location: p.Hostinfo().Location().AsStruct(), Capabilities: p.Capabilities().AsSlice(), } if cm := p.CapMap(); cm.Len() > 0 { @@ -7369,8 +7360,8 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug if len(candidates) == 1 { peer := candidates[0] if hi := peer.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } res.ID = peer.StableID() @@ -7414,10 +7405,10 @@ type nodeDistance struct { continue } loc := hi.Location() - if loc == nil { + if !loc.Valid() { continue } - distance := longLatDistance(preferredDERP.Latitude, preferredDERP.Longitude, loc.Latitude, loc.Longitude) + distance := longLatDistance(preferredDERP.Latitude, preferredDERP.Longitude, loc.Latitude(), loc.Longitude()) if distance < minDistance { minDistance = distance } @@ -7438,8 +7429,8 @@ type nodeDistance struct { res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7468,8 +7459,8 @@ type nodeDistance struct { res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7485,13 +7476,13 @@ func pickWeighted(candidates []tailcfg.NodeView) []tailcfg.NodeView { continue } loc := hi.Location() - if loc == nil || loc.Priority < maxWeight { + if !loc.Valid() || loc.Priority() < maxWeight { continue } - if maxWeight != loc.Priority { + if maxWeight != loc.Priority() { best = best[:0] } - maxWeight = loc.Priority + maxWeight = loc.Priority() best = append(best, c) } return best diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index aa18c3588..7aa677640 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -233,11 +233,11 @@ func (h *peerAPIHandler) logf(format string, a ...any) { // isAddressValid reports whether addr is a valid destination address for this // node originating from the peer. func (h *peerAPIHandler) isAddressValid(addr netip.Addr) bool { - if v := h.peerNode.SelfNodeV4MasqAddrForThisPeer(); v != nil { - return *v == addr + if v, ok := h.peerNode.SelfNodeV4MasqAddrForThisPeer().GetOk(); ok { + return v == addr } - if v := h.peerNode.SelfNodeV6MasqAddrForThisPeer(); v != nil { - return *v == addr + if v, ok := h.peerNode.SelfNodeV6MasqAddrForThisPeer().GetOk(); ok { + return v == addr } pfx := netip.PrefixFrom(addr, addr.BitLen()) return views.SliceContains(h.selfNode.Addresses(), pfx) diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index 774a18258..53df3dcef 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -145,21 +145,11 @@ func (v NodeView) Created() time.Time { return v.ж.Create func (v NodeView) Cap() CapabilityVersion { return v.ж.Cap } func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v NodeView) PrimaryRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.PrimaryRoutes) } -func (v NodeView) LastSeen() *time.Time { - if v.ж.LastSeen == nil { - return nil - } - x := *v.ж.LastSeen - return &x +func (v NodeView) LastSeen() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.LastSeen) } -func (v NodeView) Online() *bool { - if v.ж.Online == nil { - return nil - } - x := *v.ж.Online - return &x -} +func (v NodeView) Online() views.ValuePointer[bool] { return views.ValuePointerOf(v.ж.Online) } func (v NodeView) MachineAuthorized() bool { return v.ж.MachineAuthorized } func (v NodeView) Capabilities() views.Slice[NodeCapability] { return views.SliceOf(v.ж.Capabilities) } @@ -172,20 +162,12 @@ func (v NodeView) ComputedName() string { return v.ж.ComputedName } func (v NodeView) ComputedNameWithHost() string { return v.ж.ComputedNameWithHost } func (v NodeView) DataPlaneAuditLogID() string { return v.ж.DataPlaneAuditLogID } func (v NodeView) Expired() bool { return v.ж.Expired } -func (v NodeView) SelfNodeV4MasqAddrForThisPeer() *netip.Addr { - if v.ж.SelfNodeV4MasqAddrForThisPeer == nil { - return nil - } - x := *v.ж.SelfNodeV4MasqAddrForThisPeer - return &x +func (v NodeView) SelfNodeV4MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.ж.SelfNodeV4MasqAddrForThisPeer) } -func (v NodeView) SelfNodeV6MasqAddrForThisPeer() *netip.Addr { - if v.ж.SelfNodeV6MasqAddrForThisPeer == nil { - return nil - } - x := *v.ж.SelfNodeV6MasqAddrForThisPeer - return &x +func (v NodeView) SelfNodeV6MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.ж.SelfNodeV6MasqAddrForThisPeer) } func (v NodeView) IsWireGuardOnly() bool { return v.ж.IsWireGuardOnly } @@ -315,15 +297,8 @@ func (v HostinfoView) Userspace() opt.Bool { return v.ж.User func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.UserspaceRouter } func (v HostinfoView) AppConnector() opt.Bool { return v.ж.AppConnector } func (v HostinfoView) ServicesHash() string { return v.ж.ServicesHash } -func (v HostinfoView) Location() *Location { - if v.ж.Location == nil { - return nil - } - x := *v.ж.Location - return &x -} - -func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } +func (v HostinfoView) Location() LocationView { return v.ж.Location.View() } +func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HostinfoViewNeedsRegeneration = Hostinfo(struct { @@ -699,12 +674,8 @@ func (v *RegisterResponseAuthView) UnmarshalJSON(b []byte) error { return nil } -func (v RegisterResponseAuthView) Oauth2Token() *Oauth2Token { - if v.ж.Oauth2Token == nil { - return nil - } - x := *v.ж.Oauth2Token - return &x +func (v RegisterResponseAuthView) Oauth2Token() views.ValuePointer[Oauth2Token] { + return views.ValuePointerOf(v.ж.Oauth2Token) } func (v RegisterResponseAuthView) AuthKey() string { return v.ж.AuthKey } @@ -774,12 +745,8 @@ func (v RegisterRequestView) NodeKeySignature() views.ByteSlice[tkatype.Marshale return views.ByteSliceOf(v.ж.NodeKeySignature) } func (v RegisterRequestView) SignatureType() SignatureType { return v.ж.SignatureType } -func (v RegisterRequestView) Timestamp() *time.Time { - if v.ж.Timestamp == nil { - return nil - } - x := *v.ж.Timestamp - return &x +func (v RegisterRequestView) Timestamp() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.Timestamp) } func (v RegisterRequestView) DeviceCert() views.ByteSlice[[]byte] { @@ -1110,12 +1077,8 @@ func (v *SSHRuleView) UnmarshalJSON(b []byte) error { return nil } -func (v SSHRuleView) RuleExpires() *time.Time { - if v.ж.RuleExpires == nil { - return nil - } - x := *v.ж.RuleExpires - return &x +func (v SSHRuleView) RuleExpires() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.RuleExpires) } func (v SSHRuleView) Principals() views.SliceView[*SSHPrincipal, SSHPrincipalView] { @@ -1189,12 +1152,8 @@ func (v SSHActionView) HoldAndDelegate() string { return v.ж.Hol func (v SSHActionView) AllowLocalPortForwarding() bool { return v.ж.AllowLocalPortForwarding } func (v SSHActionView) AllowRemotePortForwarding() bool { return v.ж.AllowRemotePortForwarding } func (v SSHActionView) Recorders() views.Slice[netip.AddrPort] { return views.SliceOf(v.ж.Recorders) } -func (v SSHActionView) OnRecordingFailure() *SSHRecorderFailureAction { - if v.ж.OnRecordingFailure == nil { - return nil - } - x := *v.ж.OnRecordingFailure - return &x +func (v SSHActionView) OnRecordingFailure() views.ValuePointer[SSHRecorderFailureAction] { + return views.ValuePointerOf(v.ж.OnRecordingFailure) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. diff --git a/types/prefs/prefs_view_test.go b/types/prefs/prefs_view_test.go index d76eebb43..ef9f09603 100644 --- a/types/prefs/prefs_view_test.go +++ b/types/prefs/prefs_view_test.go @@ -162,15 +162,8 @@ func (v *TestBundleView) UnmarshalJSON(b []byte) error { return nil } -func (v TestBundleView) Name() string { return v.ж.Name } -func (v TestBundleView) Nested() *TestValueStruct { - if v.ж.Nested == nil { - return nil - } - x := *v.ж.Nested - return &x -} - +func (v TestBundleView) Name() string { return v.ж.Name } +func (v TestBundleView) Nested() TestValueStructView { return v.ж.Nested.View() } func (v TestBundleView) Equal(v2 TestBundleView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. diff --git a/types/views/views.go b/types/views/views.go index 40d8811f5..d8acf27ce 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -16,6 +16,7 @@ "slices" "go4.org/mem" + "tailscale.com/types/ptr" ) func unmarshalSliceFromJSON[T any](b []byte, x *[]T) error { @@ -690,6 +691,85 @@ func (m MapFn[K, T, V]) All() iter.Seq2[K, V] { } } +// ValuePointer provides a read-only view of a pointer to a value type, +// such as a primitive type or an immutable struct. Its Value and ValueOk +// methods return a stack-allocated shallow copy of the underlying value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +type ValuePointer[T any] struct { + // ж is the underlying value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *T +} + +// Valid reports whether the underlying pointer is non-nil. +func (p ValuePointer[T]) Valid() bool { + return p.ж != nil +} + +// Get returns a shallow copy of the value if the underlying pointer is non-nil. +// Otherwise, it returns a zero value. +func (p ValuePointer[T]) Get() T { + v, _ := p.GetOk() + return v +} + +// GetOk returns a shallow copy of the underlying value and true if the underlying +// pointer is non-nil. Otherwise, it returns a zero value and false. +func (p ValuePointer[T]) GetOk() (value T, ok bool) { + if p.ж == nil { + return value, false // value holds a zero value + } + return *p.ж, true +} + +// GetOr returns a shallow copy of the underlying value if it is non-nil. +// Otherwise, it returns the provided default value. +func (p ValuePointer[T]) GetOr(def T) T { + if p.ж == nil { + return def + } + return *p.ж +} + +// Clone returns a shallow copy of the underlying value. +func (p ValuePointer[T]) Clone() *T { + if p.ж == nil { + return nil + } + return ptr.To(*p.ж) +} + +// String implements [fmt.Stringer]. +func (p ValuePointer[T]) String() string { + if p.ж == nil { + return "nil" + } + return fmt.Sprint(p.ж) +} + +// ValuePointerOf returns an immutable view of a pointer to an immutable value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +func ValuePointerOf[T any](v *T) ValuePointer[T] { + return ValuePointer[T]{v} +} + +// MarshalJSON implements [json.Marshaler]. +func (p ValuePointer[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(p.ж) +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (p *ValuePointer[T]) UnmarshalJSON(b []byte) error { + if p.ж != nil { + return errors.New("already initialized") + } + return json.Unmarshal(b, &p.ж) +} + // ContainsPointers reports whether T contains any pointers, // either explicitly or implicitly. // It has special handling for some types that contain pointers diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index 308c3ede2..f8e9198a5 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -239,15 +239,15 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { if n.IsWireGuardOnly() { online = "wg" } else { - if v := n.Online(); v != nil { - if *v { + if v, ok := n.Online().GetOk(); ok { + if v { online = "yes" } else { online = "no" } } - if n.LastSeen() != nil && online != "yes" { - online += fmt.Sprintf(", lastseen=%v", durFmt(*n.LastSeen())) + if lastSeen, ok := n.LastSeen().GetOk(); ok && online != "yes" { + online += fmt.Sprintf(", lastseen=%v", durFmt(lastSeen)) } } e.logf("open-conn-track: timeout opening %v to node %v; online=%v, lastRecv=%v", diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index e7d5edf15..97304aa41 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -106,8 +106,8 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, cpeer := &cfg.Peers[len(cfg.Peers)-1] didExitNodeWarn := false - cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer() - cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer() + cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer().Clone() + cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer().Clone() cpeer.IsJailed = peer.IsJailed() for _, allowedIP := range peer.AllowedIPs().All() { if allowedIP.Bits() == 0 && peer.StableID() != exitNode {