From db52827a83553122dadc36e4607e9b7a84abb2ec Mon Sep 17 00:00:00 2001 From: Gesa Stupperich Date: Tue, 3 Feb 2026 20:58:05 +0000 Subject: [PATCH] ssh/tailssh: guard access to c.info and c.localUser This moves the info and localUser of a conn under the guard of the conn's mu in order to prevent races between the fields being written in clientAuth and it being read. Given that info and localUser are pointers this doesn't strictly prevent individual fields from being written without the mutex being held. The reason I consider that good enough is that the code effectively treats both fields as immutable once set. I have added defensive nil checks everywhere, however, which might be overly conservative. Updates tailscale/corp#36268 Signed-off-by: Gesa Stupperich --- ssh/tailssh/tailssh.go | 112 +++++++++++++++++++++++++----------- ssh/tailssh/tailssh_test.go | 51 ++++++++++++++++ 2 files changed, 130 insertions(+), 33 deletions(-) diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 2c90dc83c..d8dea7da2 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -192,11 +192,11 @@ func (srv *server) OnPolicyChange() { srv.mu.Lock() defer srv.mu.Unlock() for c := range srv.activeConns { - // move info and localUser to be protected by conn mutex? - if c.info == nil || c.localUser == nil { + ci, lu := c.getInfoAndLocalUser() + if ci == nil || lu == nil { // c.info or c.localUser are nil when the connection hasn't been // authenticated yet. We will continue here, but the connection will - // be rechecked once it is authenticated. If it no longer conforms + // be checked once it is authenticated. If it no longer conforms // with the SSH access policy at that point, it will be terminated. continue } @@ -242,9 +242,7 @@ type conn struct { action0 *tailcfg.SSHAction // set by clientAuth finalAction *tailcfg.SSHAction // set by clientAuth - info *sshConnInfo // set by setInfo - localUser *userMeta // set by clientAuth - userGroupIDs []string // set by clientAuth + userGroupIDs []string // set by clientAuth acceptEnv []string // mu protects the following fields. @@ -252,8 +250,10 @@ type conn struct { // srv.mu should be acquired prior to mu. // It is safe to just acquire mu, but unsafe to // acquire mu and then srv.mu. - mu sync.Mutex // protects the following - sessions []*sshSession + mu sync.Mutex // protects the following + info *sshConnInfo // set by setInfo + localUser *userMeta // set by clientAuth + sessions []*sshSession } func (c *conn) logf(format string, args ...any) { @@ -267,6 +267,24 @@ func (c *conn) vlogf(format string, args ...any) { } } +func (c *conn) getInfo() *sshConnInfo { + c.mu.Lock() + defer c.mu.Unlock() + return c.info +} + +func (c *conn) getLocalUser() *userMeta { + c.mu.Lock() + defer c.mu.Unlock() + return c.localUser +} + +func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) { + c.mu.Lock() + defer c.mu.Unlock() + return c.info, c.localUser +} + // errDenied is returned by auth callbacks when a connection is denied by the // policy. It writes the message to an auth banner and then returns an empty // gossh.PartialSuccessError in order to stop processing authentication @@ -337,7 +355,12 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE case accepted: // do nothing case rejectedUser: - return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", c.info.sshUser), nil) + ci := c.getInfo() + if ci != nil { + return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil) + } else { + return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil) + } case rejected, noPolicy: return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result)) default: @@ -358,7 +381,9 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE return nil, c.errBanner("failed to look up local user's group IDs", err) } c.userGroupIDs = gids + c.mu.Lock() c.localUser = lu + c.mu.Unlock() c.acceptEnv = acceptEnv } @@ -585,9 +610,6 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { // connInfo populates the sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. func (c *conn) setInfo(cm gossh.ConnMetadata) error { - if c.info != nil { - return nil - } ci := &sshConnInfo{ sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix), src: toIPPort(cm.RemoteAddr()), @@ -606,6 +628,11 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error { ci.node = node ci.uprof = uprof + c.mu.Lock() + defer c.mu.Unlock() + if c.info != nil { + return nil + } c.idH = string(cm.SessionID()) c.info = ci c.logf("handling conn: %v", ci.String()) @@ -653,8 +680,9 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { } ss := c.newSSHSession(s) - ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username) - ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username) + ci, lu := c.getInfoAndLocalUser() + ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", ci.uprof.LoginName, ci.src.Addr(), lu.Username) + ss.logf("access granted to %v as ssh-user %q", ci.uprof.LoginName, lu.Username) if f, ok := hookSSHLoginSuccess.GetOk(); ok { f(c.srv.logf, c) @@ -665,8 +693,7 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { func (c *conn) expandDelegateURLLocked(actionURL string) string { nm := c.srv.lb.NetMap() - ci := c.info - lu := c.localUser + ci, lu := c.getInfoAndLocalUser() var dstNodeID string if nm != nil { dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID())) @@ -739,7 +766,8 @@ func (c *conn) isStillValid() bool { if !a.Accept && a.HoldAndDelegate == "" { return false } - return c.localUser.Username == localUser + lu := c.getLocalUser() + return lu != nil && lu.Username == localUser } // checkStillValid checks that the conn is still valid per the latest SSHPolicy. @@ -813,7 +841,12 @@ func (ss *sshSession) killProcessOnContextDone() { io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n") } } - ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.Addr(), err) + ci := ss.conn.getInfo() + if ci != nil { + ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err) + } else { + ss.logf("terminating SSH session: %v", err) + } // We don't need to Process.Wait here, sshSession.run() does // the waiting regardless of termination reason. @@ -915,7 +948,7 @@ func (ss *sshSession) run() { } defer ss.conn.detachSession(ss) - lu := ss.conn.localUser + lu := ss.conn.getLocalUser() logf := ss.logf if ss.conn.finalAction.SessionDuration != 0 { @@ -1148,7 +1181,8 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st if c == nil { return nil, "", nil, errInvalidConn } - if c.info == nil { + ci := c.getInfo() + if ci == nil { c.logf("invalid connection state") return nil, "", nil, errInvalidConn } @@ -1168,7 +1202,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st // For all but Reject rules, SSHUsers is required. // If SSHUsers is nil or empty, mapLocalUser will return an // empty string anyway. - localUser = mapLocalUser(r.SSHUsers, c.info.sshUser) + localUser = mapLocalUser(r.SSHUsers, ci.sshUser) if localUser == "" { return nil, "", nil, errUserMatch } @@ -1202,7 +1236,10 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool { // principalMatchesTailscaleIdentity reports whether one of p's four fields // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { - ci := c.info + ci := c.getInfo() + if ci == nil { + return false + } if p.Any { return true } @@ -1348,6 +1385,10 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { }() } + ci, lu := ss.conn.getInfoAndLocalUser() + if ci == nil || lu == nil { + return nil, errors.New("recording: missing connection metadata") + } ch := sessionrecording.CastHeader{ Version: 2, Width: w.Width, @@ -1365,17 +1406,17 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { // it. Then we can (1) make the cmd, (2) start the // recording, (3) start the process. }, - SSHUser: ss.conn.info.sshUser, - LocalUser: ss.conn.localUser.Username, - SrcNode: strings.TrimSuffix(ss.conn.info.node.Name(), "."), - SrcNodeID: ss.conn.info.node.StableID(), + SSHUser: ci.sshUser, + LocalUser: lu.Username, + SrcNode: strings.TrimSuffix(ci.node.Name(), "."), + SrcNodeID: ci.node.StableID(), ConnectionID: ss.conn.connID, } - if !ss.conn.info.node.IsTagged() { - ch.SrcNodeUser = ss.conn.info.uprof.LoginName - ch.SrcNodeUserID = ss.conn.info.node.User() + if !ci.node.IsTagged() { + ch.SrcNodeUser = ci.uprof.LoginName + ch.SrcNodeUserID = ci.node.User() } else { - ch.SrcNodeTags = ss.conn.info.node.Tags().AsSlice() + ch.SrcNodeTags = ci.node.Tags().AsSlice() } j, err := json.Marshal(ch) if err != nil { @@ -1398,14 +1439,19 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { // A SSHEventNotifyRequest is sent when an action or state reached during // an SSH session is a defined EventType. func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) { + ci, lu := ss.conn.getInfoAndLocalUser() + if ci == nil || lu == nil { + ss.logf("notifyControl: missing connection metadata") + return + } re := tailcfg.SSHEventNotifyRequest{ EventType: notifyType, ConnectionID: ss.conn.connID, CapVersion: tailcfg.CurrentCapabilityVersion, NodeKey: nodeKey, - SrcNode: ss.conn.info.node.ID(), - SSHUser: ss.conn.info.sshUser, - LocalUser: ss.conn.localUser.Username, + SrcNode: ci.node.ID(), + SSHUser: ci.sshUser, + LocalUser: lu.Username, RecordingAttempts: attempts, } diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index b9e591d80..581c8be82 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -1339,6 +1339,57 @@ func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) { }) } +func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + srv := &server{ + logf: tstest.WhileTestRunningLogger(t), + lb: &localState{ + sshEnabled: true, + matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), + }, + } + c := &conn{ + srv: srv, + info: &sshConnInfo{sshUser: "alice"}, + } + srv.activeConns = map[*conn]bool{c: true} + + fakeClientAuth := func() { + c.mu.Lock() + c.info = &sshConnInfo{sshUser: "alice"} + c.mu.Unlock() + + c.mu.Lock() + c.localUser = &userMeta{User: user.User{Username: currentUser}} + c.mu.Unlock() + } + + // Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser. + done := make(chan struct{}) + go func() { + for i := 0; i < 100; i++ { + select { + case <-done: + return + default: + fakeClientAuth() + } + } + }() + + go func() { + for i := 0; i < 100; i++ { + select { + case <-done: + return + default: + srv.OnPolicyChange() + } + } + }() + }) +} + func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { t.Helper() mux := http.NewServeMux()