ssh/tailssh: store c.info and c.localUser as values

This converts the info and localUser fields on the conn from
pointers to values. I consider this an overall improvement since
both structs are small and it makes access safer in cases when
they've not yet been set.

Updates tailscale/corp#36268

Signed-off-by: Gesa Stupperich <gesa@tailscale.com>
This commit is contained in:
Gesa Stupperich 2026-02-10 11:21:24 +00:00
parent db52827a83
commit 06abf96811
6 changed files with 102 additions and 140 deletions

View File

@ -123,7 +123,7 @@ func sendAuditMessage(logf logger.Logf, msgType uint16, message string) {
// logSSHLogin logs an SSH login event to auditd with whois information.
func logSSHLogin(logf logger.Logf, c *conn) {
if c == nil || c.info == nil || c.localUser == nil {
if c == nil {
return
}

View File

@ -1099,7 +1099,7 @@ func (ss *sshSession) startWithStdPipes() (err error) {
return ss.cmd.Start()
}
func envForUser(u *userMeta) []string {
func envForUser(u userMeta) []string {
return []string{
fmt.Sprintf("SHELL=%s", u.LoginShell()),
fmt.Sprintf("USER=%s", u.Username),

View File

@ -400,7 +400,7 @@ func (ss *sshSession) startWithStdPipes() (err error) {
return ss.cmd.Start()
}
func envForUser(u *userMeta) []string {
func envForUser(u userMeta) []string {
return []string{
fmt.Sprintf("user=%s", u.Username),
fmt.Sprintf("home=%s", u.HomeDir),

View File

@ -193,8 +193,8 @@ func (srv *server) OnPolicyChange() {
defer srv.mu.Unlock()
for c := range srv.activeConns {
ci, lu := c.getInfoAndLocalUser()
if ci == nil || lu == nil {
// c.info or c.localUser are nil when the connection hasn't been
if ci.sshUser == "" || lu.Username == "" {
// c.info or c.localUser are empty when the connection hasn't been
// authenticated yet. We will continue here, but the connection will
// be checked once it is authenticated. If it no longer conforms
// with the SSH access policy at that point, it will be terminated.
@ -250,9 +250,9 @@ type conn struct {
// srv.mu should be acquired prior to mu.
// It is safe to just acquire mu, but unsafe to
// acquire mu and then srv.mu.
mu sync.Mutex // protects the following
info *sshConnInfo // set by setInfo
localUser *userMeta // set by clientAuth
mu sync.Mutex // protects the following
info sshConnInfo // set by setInfo
localUser userMeta // set by clientAuth
sessions []*sshSession
}
@ -267,19 +267,19 @@ func (c *conn) vlogf(format string, args ...any) {
}
}
func (c *conn) getInfo() *sshConnInfo {
func (c *conn) getInfo() sshConnInfo {
c.mu.Lock()
defer c.mu.Unlock()
return c.info
}
func (c *conn) getLocalUser() *userMeta {
func (c *conn) getLocalUser() userMeta {
c.mu.Lock()
defer c.mu.Unlock()
return c.localUser
}
func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) {
func (c *conn) getInfoAndLocalUser() (sshConnInfo, userMeta) {
c.mu.Lock()
defer c.mu.Unlock()
return c.info, c.localUser
@ -356,11 +356,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE
// do nothing
case rejectedUser:
ci := c.getInfo()
if ci != nil {
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
} else {
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil)
}
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
case rejected, noPolicy:
return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result))
default:
@ -610,7 +606,7 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) {
// connInfo populates the sshConnInfo from the provided arguments,
// validating only that they represent a known Tailscale identity.
func (c *conn) setInfo(cm gossh.ConnMetadata) error {
ci := &sshConnInfo{
ci := sshConnInfo{
sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix),
src: toIPPort(cm.RemoteAddr()),
dst: toIPPort(cm.LocalAddr()),
@ -630,7 +626,7 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.info != nil {
if c.info.sshUser != "" {
return nil
}
c.idH = string(cm.SessionID())
@ -767,7 +763,7 @@ func (c *conn) isStillValid() bool {
return false
}
lu := c.getLocalUser()
return lu != nil && lu.Username == localUser
return lu.Username == localUser
}
// checkStillValid checks that the conn is still valid per the latest SSHPolicy.
@ -842,11 +838,7 @@ func (ss *sshSession) killProcessOnContextDone() {
}
}
ci := ss.conn.getInfo()
if ci != nil {
ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
} else {
ss.logf("terminating SSH session: %v", err)
}
ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
// We don't need to Process.Wait here, sshSession.run() does
// the waiting regardless of termination reason.
@ -884,7 +876,7 @@ var errSessionDone = errors.New("session is done")
// handleSSHAgentForwarding starts a Unix socket listener and in the background
// forwards agent connections between the listener and the ssh.Session.
// On success, it assigns ss.agentListener.
func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) error {
func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu userMeta) error {
if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding {
return nil
}
@ -1182,10 +1174,6 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st
return nil, "", nil, errInvalidConn
}
ci := c.getInfo()
if ci == nil {
c.logf("invalid connection state")
return nil, "", nil, errInvalidConn
}
if r == nil {
return nil, "", nil, errNilRule
}
@ -1237,9 +1225,6 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool {
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
ci := c.getInfo()
if ci == nil {
return false
}
if p.Any {
return true
}
@ -1386,9 +1371,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
}
ci, lu := ss.conn.getInfoAndLocalUser()
if ci == nil || lu == nil {
return nil, errors.New("recording: missing connection metadata")
}
ch := sessionrecording.CastHeader{
Version: 2,
Width: w.Width,
@ -1440,10 +1422,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
// an SSH session is a defined EventType.
func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) {
ci, lu := ss.conn.getInfoAndLocalUser()
if ci == nil || lu == nil {
ss.logf("notifyControl: missing connection metadata")
return
}
re := tailcfg.SSHEventNotifyRequest{
EventType: notifyType,
ConnectionID: ss.conn.connID,

View File

@ -64,31 +64,20 @@ func TestMatchRule(t *testing.T) {
tests := []struct {
name string
rule *tailcfg.SSHRule
ci *sshConnInfo
ci sshConnInfo
wantErr error
wantUser string
wantAcceptEnv []string
}{
{
name: "invalid-conn",
rule: &tailcfg.SSHRule{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
},
},
wantErr: errInvalidConn,
},
{
name: "nil-rule",
ci: &sshConnInfo{},
ci: sshConnInfo{},
rule: nil,
wantErr: errNilRule,
},
{
name: "nil-action",
ci: &sshConnInfo{},
ci: sshConnInfo{},
rule: &tailcfg.SSHRule{},
wantErr: errNilAction,
},
@ -98,7 +87,7 @@ func TestMatchRule(t *testing.T) {
Action: someAction,
RuleExpires: ptr.To(time.Unix(100, 0)),
},
ci: &sshConnInfo{},
ci: sshConnInfo{},
wantErr: errRuleExpired,
},
{
@ -108,7 +97,7 @@ func TestMatchRule(t *testing.T) {
SSHUsers: map[string]string{
"*": "ubuntu",
}},
ci: &sshConnInfo{},
ci: sshConnInfo{},
wantErr: errPrincipalMatch,
},
{
@ -117,7 +106,7 @@ func TestMatchRule(t *testing.T) {
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantErr: errUserMatch,
},
{
@ -129,7 +118,7 @@ func TestMatchRule(t *testing.T) {
"*": "ubuntu",
},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "ubuntu",
},
{
@ -144,7 +133,7 @@ func TestMatchRule(t *testing.T) {
"*": "ubuntu",
},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "ubuntu",
},
{
@ -157,7 +146,7 @@ func TestMatchRule(t *testing.T) {
"alice": "thealice",
},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
},
{
@ -171,7 +160,7 @@ func TestMatchRule(t *testing.T) {
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
@ -181,7 +170,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Reject: true},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
},
{
name: "match-principal-node-ip",
@ -190,7 +179,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
ci: sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
wantUser: "ubuntu",
},
{
@ -200,7 +189,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
ci: sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
wantUser: "ubuntu",
},
{
@ -210,7 +199,7 @@ func TestMatchRule(t *testing.T) {
Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}},
SSHUsers: map[string]string{"*": "ubuntu"},
},
ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}},
ci: sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}},
wantUser: "ubuntu",
},
{
@ -222,7 +211,7 @@ func TestMatchRule(t *testing.T) {
"*": "=",
},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "alice",
},
}
@ -254,7 +243,7 @@ func TestEvalSSHPolicy(t *testing.T) {
tests := []struct {
name string
policy *tailcfg.SSHPolicy
ci *sshConnInfo
ci sshConnInfo
wantResult evalResult
wantUser string
wantAcceptEnv []string
@ -298,7 +287,7 @@ func TestEvalSSHPolicy(t *testing.T) {
},
},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
wantResult: accepted,
@ -308,7 +297,7 @@ func TestEvalSSHPolicy(t *testing.T) {
policy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantResult: rejected,
@ -349,7 +338,7 @@ func TestEvalSSHPolicy(t *testing.T) {
},
},
},
ci: &sshConnInfo{sshUser: "alice"},
ci: sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantResult: rejectedUser,
@ -1100,7 +1089,7 @@ func TestSSH(t *testing.T) {
t.Fatal(err)
}
sc.localUser = um
sc.info = &sshConnInfo{
sc.info = sshConnInfo{
sshUser: "test",
src: netip.MustParseAddrPort("1.2.3.4:32342"),
dst: netip.MustParseAddrPort("1.2.3.5:22"),
@ -1318,76 +1307,71 @@ func TestStdOsUserUserAssumptions(t *testing.T) {
}
}
func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
srv := &server{
logf: tstest.WhileTestRunningLogger(t),
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
},
}
c := &conn{
srv: srv,
info: &sshConnInfo{sshUser: "alice"},
}
srv.activeConns = map[*conn]bool{c: true}
func TestOnPolicyChangeDefersValidationOnEmptyLocalUser(t *testing.T) {
tests := []struct {
name string
sshRule *tailcfg.SSHRule
wantCancelOnValidation bool
}{
{
name: "defer-then-accept-when-allowed",
sshRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
wantCancelOnValidation: false,
},
{
name: "defer-then-reject-when-not-allowed",
sshRule: newSSHRule(&tailcfg.SSHAction{Reject: true}),
wantCancelOnValidation: true,
},
}
srv.OnPolicyChange()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
synctest.Wait()
})
}
func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
srv := &server{
logf: tstest.WhileTestRunningLogger(t),
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
},
}
c := &conn{
srv: srv,
info: &sshConnInfo{sshUser: "alice"},
}
srv.activeConns = map[*conn]bool{c: true}
fakeClientAuth := func() {
c.mu.Lock()
c.info = &sshConnInfo{sshUser: "alice"}
c.mu.Unlock()
c.mu.Lock()
c.localUser = &userMeta{User: user.User{Username: currentUser}}
c.mu.Unlock()
}
// Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser.
done := make(chan struct{})
go func() {
for i := 0; i < 100; i++ {
select {
case <-done:
return
default:
fakeClientAuth()
synctest.Test(t, func(t *testing.T) {
srv := &server{
logf: tstest.WhileTestRunningLogger(t),
lb: &localState{
sshEnabled: true,
matchingRule: tt.sshRule,
},
}
}
}()
go func() {
for i := 0; i < 100; i++ {
select {
case <-done:
return
default:
srv.OnPolicyChange()
c := &conn{
srv: srv,
info: sshConnInfo{sshUser: "alice"},
}
}
}()
})
srv.activeConns = map[*conn]bool{c: true}
ctx, cancel := context.WithCancelCause(context.Background())
ss := &sshSession{ctx: ctx, cancelCtx: cancel}
c.sessions = []*sshSession{ss}
srv.OnPolicyChange()
synctest.Wait()
select {
case <-ctx.Done():
t.Fatal("expected deferral of cancellation decision while localUser unset but session got canceled")
default:
}
c.mu.Lock()
c.localUser = userMeta{User: user.User{Username: currentUser}}
c.mu.Unlock()
srv.OnPolicyChange()
synctest.Wait()
select {
case <-ctx.Done():
if !tt.wantCancelOnValidation {
t.Fatal("valid session shouldn't have been canceled")
}
default:
if tt.wantCancelOnValidation {
t.Fatal("invalid session should have been canceled but it wasn't")
}
}
})
})
}
}
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {

View File

@ -36,15 +36,15 @@ func (u *userMeta) GroupIds() ([]string, error) {
return osuser.GetGroupIds(&u.User)
}
// userLookup is like os/user.Lookup but it returns a *userMeta wrapper
// userLookup is like os/user.Lookup but it returns a userMeta wrapper
// around a *user.User with extra fields.
func userLookup(username string) (*userMeta, error) {
func userLookup(username string) (userMeta, error) {
u, s, err := osuser.LookupByUsernameWithShell(username)
if err != nil {
return nil, err
return userMeta{}, err
}
return &userMeta{User: *u, loginShellCached: s}, nil
return userMeta{User: *u, loginShellCached: s}, nil
}
func (u *userMeta) LoginShell() string {