ssh/tailssh: guard access to c.info and c.localUser

This moves the info and localUser fields of a conn under the guard of
the conn's mu in order to prevent races between the fields being
written in clientAuth and their being read. Given that info and
localUser are pointers, this doesn't strictly prevent individual fields
of the pointed-to structs from being written without the mutex being
held. The reason I consider that good enough is that the code
effectively treats both values as immutable once set. I have added
defensive nil checks everywhere, however, which might be overly
conservative.

Updates tailscale/corp#36268

Signed-off-by: Gesa Stupperich <gesa@tailscale.com>
This commit is contained in:
Gesa Stupperich 2026-02-03 20:58:05 +00:00
parent cd66071731
commit db52827a83
2 changed files with 130 additions and 33 deletions

View File

@ -192,11 +192,11 @@ func (srv *server) OnPolicyChange() {
srv.mu.Lock()
defer srv.mu.Unlock()
for c := range srv.activeConns {
// move info and localUser to be protected by conn mutex?
if c.info == nil || c.localUser == nil {
ci, lu := c.getInfoAndLocalUser()
if ci == nil || lu == nil {
// c.info or c.localUser are nil when the connection hasn't been
// authenticated yet. We will continue here, but the connection will
// be rechecked once it is authenticated. If it no longer conforms
// be checked once it is authenticated. If it no longer conforms
// with the SSH access policy at that point, it will be terminated.
continue
}
@ -242,9 +242,7 @@ type conn struct {
action0 *tailcfg.SSHAction // set by clientAuth
finalAction *tailcfg.SSHAction // set by clientAuth
info *sshConnInfo // set by setInfo
localUser *userMeta // set by clientAuth
userGroupIDs []string // set by clientAuth
userGroupIDs []string // set by clientAuth
acceptEnv []string
// mu protects the following fields.
@ -252,8 +250,10 @@ type conn struct {
// srv.mu should be acquired prior to mu.
// It is safe to just acquire mu, but unsafe to
// acquire mu and then srv.mu.
mu sync.Mutex // protects the following
sessions []*sshSession
mu sync.Mutex // protects the following
info *sshConnInfo // set by setInfo
localUser *userMeta // set by clientAuth
sessions []*sshSession
}
func (c *conn) logf(format string, args ...any) {
@ -267,6 +267,24 @@ func (c *conn) vlogf(format string, args ...any) {
}
}
// getInfo returns the conn's info pointer, read while holding c.mu.
// The result is nil if c.info has not been set. The caller must not
// already hold c.mu, because this method acquires it.
func (c *conn) getInfo() *sshConnInfo {
	c.mu.Lock()
	ci := c.info
	c.mu.Unlock()
	return ci
}
// getLocalUser returns the conn's localUser pointer, read while holding
// c.mu. The result is nil if c.localUser has not been set. The caller
// must not already hold c.mu, because this method acquires it.
func (c *conn) getLocalUser() *userMeta {
	c.mu.Lock()
	lu := c.localUser
	c.mu.Unlock()
	return lu
}
// getInfoAndLocalUser returns the conn's info and localUser pointers,
// both read under a single acquisition of c.mu so the pair is observed
// together. Either result is nil if the corresponding field has not
// been set. The caller must not already hold c.mu, because this method
// acquires it.
func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) {
	c.mu.Lock()
	ci, lu := c.info, c.localUser
	c.mu.Unlock()
	return ci, lu
}
// errDenied is returned by auth callbacks when a connection is denied by the
// policy. It writes the message to an auth banner and then returns an empty
// gossh.PartialSuccessError in order to stop processing authentication
@ -337,7 +355,12 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE
case accepted:
// do nothing
case rejectedUser:
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", c.info.sshUser), nil)
ci := c.getInfo()
if ci != nil {
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
} else {
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil)
}
case rejected, noPolicy:
return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result))
default:
@ -358,7 +381,9 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE
return nil, c.errBanner("failed to look up local user's group IDs", err)
}
c.userGroupIDs = gids
c.mu.Lock()
c.localUser = lu
c.mu.Unlock()
c.acceptEnv = acceptEnv
}
@ -585,9 +610,6 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) {
// setInfo populates the sshConnInfo from the provided arguments,
// validating only that they represent a known Tailscale identity.
func (c *conn) setInfo(cm gossh.ConnMetadata) error {
if c.info != nil {
return nil
}
ci := &sshConnInfo{
sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix),
src: toIPPort(cm.RemoteAddr()),
@ -606,6 +628,11 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
ci.node = node
ci.uprof = uprof
c.mu.Lock()
defer c.mu.Unlock()
if c.info != nil {
return nil
}
c.idH = string(cm.SessionID())
c.info = ci
c.logf("handling conn: %v", ci.String())
@ -653,8 +680,9 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
}
ss := c.newSSHSession(s)
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username)
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
ci, lu := c.getInfoAndLocalUser()
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", ci.uprof.LoginName, ci.src.Addr(), lu.Username)
ss.logf("access granted to %v as ssh-user %q", ci.uprof.LoginName, lu.Username)
if f, ok := hookSSHLoginSuccess.GetOk(); ok {
f(c.srv.logf, c)
@ -665,8 +693,7 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
func (c *conn) expandDelegateURLLocked(actionURL string) string {
nm := c.srv.lb.NetMap()
ci := c.info
lu := c.localUser
ci, lu := c.getInfoAndLocalUser()
var dstNodeID string
if nm != nil {
dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID()))
@ -739,7 +766,8 @@ func (c *conn) isStillValid() bool {
if !a.Accept && a.HoldAndDelegate == "" {
return false
}
return c.localUser.Username == localUser
lu := c.getLocalUser()
return lu != nil && lu.Username == localUser
}
// checkStillValid checks that the conn is still valid per the latest SSHPolicy.
@ -813,7 +841,12 @@ func (ss *sshSession) killProcessOnContextDone() {
io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
}
}
ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.Addr(), err)
ci := ss.conn.getInfo()
if ci != nil {
ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
} else {
ss.logf("terminating SSH session: %v", err)
}
// We don't need to Process.Wait here, sshSession.run() does
// the waiting regardless of termination reason.
@ -915,7 +948,7 @@ func (ss *sshSession) run() {
}
defer ss.conn.detachSession(ss)
lu := ss.conn.localUser
lu := ss.conn.getLocalUser()
logf := ss.logf
if ss.conn.finalAction.SessionDuration != 0 {
@ -1148,7 +1181,8 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st
if c == nil {
return nil, "", nil, errInvalidConn
}
if c.info == nil {
ci := c.getInfo()
if ci == nil {
c.logf("invalid connection state")
return nil, "", nil, errInvalidConn
}
@ -1168,7 +1202,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st
// For all but Reject rules, SSHUsers is required.
// If SSHUsers is nil or empty, mapLocalUser will return an
// empty string anyway.
localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
if localUser == "" {
return nil, "", nil, errUserMatch
}
@ -1202,7 +1236,10 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool {
// principalMatchesTailscaleIdentity reports whether one of p's four fields
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
ci := c.info
ci := c.getInfo()
if ci == nil {
return false
}
if p.Any {
return true
}
@ -1348,6 +1385,10 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
}()
}
ci, lu := ss.conn.getInfoAndLocalUser()
if ci == nil || lu == nil {
return nil, errors.New("recording: missing connection metadata")
}
ch := sessionrecording.CastHeader{
Version: 2,
Width: w.Width,
@ -1365,17 +1406,17 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
// it. Then we can (1) make the cmd, (2) start the
// recording, (3) start the process.
},
SSHUser: ss.conn.info.sshUser,
LocalUser: ss.conn.localUser.Username,
SrcNode: strings.TrimSuffix(ss.conn.info.node.Name(), "."),
SrcNodeID: ss.conn.info.node.StableID(),
SSHUser: ci.sshUser,
LocalUser: lu.Username,
SrcNode: strings.TrimSuffix(ci.node.Name(), "."),
SrcNodeID: ci.node.StableID(),
ConnectionID: ss.conn.connID,
}
if !ss.conn.info.node.IsTagged() {
ch.SrcNodeUser = ss.conn.info.uprof.LoginName
ch.SrcNodeUserID = ss.conn.info.node.User()
if !ci.node.IsTagged() {
ch.SrcNodeUser = ci.uprof.LoginName
ch.SrcNodeUserID = ci.node.User()
} else {
ch.SrcNodeTags = ss.conn.info.node.Tags().AsSlice()
ch.SrcNodeTags = ci.node.Tags().AsSlice()
}
j, err := json.Marshal(ch)
if err != nil {
@ -1398,14 +1439,19 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
// A SSHEventNotifyRequest is sent when an action or state reached during
// an SSH session is a defined EventType.
func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) {
ci, lu := ss.conn.getInfoAndLocalUser()
if ci == nil || lu == nil {
ss.logf("notifyControl: missing connection metadata")
return
}
re := tailcfg.SSHEventNotifyRequest{
EventType: notifyType,
ConnectionID: ss.conn.connID,
CapVersion: tailcfg.CurrentCapabilityVersion,
NodeKey: nodeKey,
SrcNode: ss.conn.info.node.ID(),
SSHUser: ss.conn.info.sshUser,
LocalUser: ss.conn.localUser.Username,
SrcNode: ci.node.ID(),
SSHUser: ci.sshUser,
LocalUser: lu.Username,
RecordingAttempts: attempts,
}

View File

@ -1339,6 +1339,57 @@ func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) {
})
}
func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
srv := &server{
logf: tstest.WhileTestRunningLogger(t),
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
},
}
c := &conn{
srv: srv,
info: &sshConnInfo{sshUser: "alice"},
}
srv.activeConns = map[*conn]bool{c: true}
fakeClientAuth := func() {
c.mu.Lock()
c.info = &sshConnInfo{sshUser: "alice"}
c.mu.Unlock()
c.mu.Lock()
c.localUser = &userMeta{User: user.User{Username: currentUser}}
c.mu.Unlock()
}
// Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser.
done := make(chan struct{})
go func() {
for i := 0; i < 100; i++ {
select {
case <-done:
return
default:
fakeClientAuth()
}
}
}()
go func() {
for i := 0; i < 100; i++ {
select {
case <-done:
return
default:
srv.OnPolicyChange()
}
}
}()
})
}
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {
t.Helper()
mux := http.NewServeMux()