diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 2c90dc83c..d8dea7da2 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -192,11 +192,11 @@ func (srv *server) OnPolicyChange() { srv.mu.Lock() defer srv.mu.Unlock() for c := range srv.activeConns { - // move info and localUser to be protected by conn mutex? - if c.info == nil || c.localUser == nil { + ci, lu := c.getInfoAndLocalUser() + if ci == nil || lu == nil { // c.info or c.localUser are nil when the connection hasn't been // authenticated yet. We will continue here, but the connection will - // be rechecked once it is authenticated. If it no longer conforms + // be checked once it is authenticated. If it no longer conforms // with the SSH access policy at that point, it will be terminated. continue } @@ -242,9 +242,7 @@ type conn struct { action0 *tailcfg.SSHAction // set by clientAuth finalAction *tailcfg.SSHAction // set by clientAuth - info *sshConnInfo // set by setInfo - localUser *userMeta // set by clientAuth - userGroupIDs []string // set by clientAuth + userGroupIDs []string // set by clientAuth acceptEnv []string // mu protects the following fields. @@ -252,8 +250,10 @@ type conn struct { // srv.mu should be acquired prior to mu. // It is safe to just acquire mu, but unsafe to // acquire mu and then srv.mu. - mu sync.Mutex // protects the following - sessions []*sshSession + mu sync.Mutex // protects the following + info *sshConnInfo // set by setInfo + localUser *userMeta // set by clientAuth + sessions []*sshSession } func (c *conn) logf(format string, args ...any) { @@ -267,6 +267,24 @@ func (c *conn) vlogf(format string, args ...any) { } } +func (c *conn) getInfo() *sshConnInfo { + c.mu.Lock() + defer c.mu.Unlock() + return c.info +} + +func (c *conn) getLocalUser() *userMeta { + c.mu.Lock() + defer c.mu.Unlock() + return c.localUser +} + +func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) { + c.mu.Lock() + defer c.mu.Unlock() + return c.info, c.localUser +} + // errDenied is returned by auth callbacks when a connection is denied by the // policy. It writes the message to an auth banner and then returns an empty // gossh.PartialSuccessError in order to stop processing authentication @@ -337,7 +355,12 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE case accepted: // do nothing case rejectedUser: - return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", c.info.sshUser), nil) + ci := c.getInfo() + if ci != nil { + return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil) + } else { + return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil) + } case rejected, noPolicy: return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result)) default: @@ -358,7 +381,9 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE return nil, c.errBanner("failed to look up local user's group IDs", err) } c.userGroupIDs = gids + c.mu.Lock() c.localUser = lu + c.mu.Unlock() c.acceptEnv = acceptEnv } @@ -585,9 +610,6 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { // connInfo populates the sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. func (c *conn) setInfo(cm gossh.ConnMetadata) error { - if c.info != nil { - return nil - } ci := &sshConnInfo{ sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix), src: toIPPort(cm.RemoteAddr()), @@ -606,6 +628,11 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error { ci.node = node ci.uprof = uprof + c.mu.Lock() + defer c.mu.Unlock() + if c.info != nil { + return nil + } c.idH = string(cm.SessionID()) c.info = ci c.logf("handling conn: %v", ci.String()) @@ -653,8 +680,9 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { } ss := c.newSSHSession(s) - ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username) - ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username) + ci, lu := c.getInfoAndLocalUser() + ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", ci.uprof.LoginName, ci.src.Addr(), lu.Username) + ss.logf("access granted to %v as ssh-user %q", ci.uprof.LoginName, lu.Username) if f, ok := hookSSHLoginSuccess.GetOk(); ok { f(c.srv.logf, c) @@ -665,8 +693,7 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { func (c *conn) expandDelegateURLLocked(actionURL string) string { nm := c.srv.lb.NetMap() - ci := c.info - lu := c.localUser + ci, lu := c.getInfoAndLocalUser() var dstNodeID string if nm != nil { dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID())) @@ -739,7 +766,8 @@ func (c *conn) isStillValid() bool { if !a.Accept && a.HoldAndDelegate == "" { return false } - return c.localUser.Username == localUser + lu := c.getLocalUser() + return lu != nil && lu.Username == localUser } // checkStillValid checks that the conn is still valid per the latest SSHPolicy. @@ -813,7 +841,12 @@ func (ss *sshSession) killProcessOnContextDone() { io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n") } } - ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.Addr(), err) + ci := ss.conn.getInfo() + if ci != nil { + ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err) + } else { + ss.logf("terminating SSH session: %v", err) + } // We don't need to Process.Wait here, sshSession.run() does // the waiting regardless of termination reason. @@ -915,7 +948,7 @@ func (ss *sshSession) run() { } defer ss.conn.detachSession(ss) - lu := ss.conn.localUser + lu := ss.conn.getLocalUser() logf := ss.logf if ss.conn.finalAction.SessionDuration != 0 { @@ -1148,7 +1181,8 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st if c == nil { return nil, "", nil, errInvalidConn } - if c.info == nil { + ci := c.getInfo() + if ci == nil { c.logf("invalid connection state") return nil, "", nil, errInvalidConn } @@ -1168,7 +1202,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st // For all but Reject rules, SSHUsers is required. // If SSHUsers is nil or empty, mapLocalUser will return an // empty string anyway. - localUser = mapLocalUser(r.SSHUsers, c.info.sshUser) + localUser = mapLocalUser(r.SSHUsers, ci.sshUser) if localUser == "" { return nil, "", nil, errUserMatch } @@ -1202,7 +1236,10 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool { // principalMatchesTailscaleIdentity reports whether one of p's four fields // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { - ci := c.info + ci := c.getInfo() + if ci == nil { + return false + } if p.Any { return true } @@ -1348,6 +1385,10 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { }() } + ci, lu := ss.conn.getInfoAndLocalUser() + if ci == nil || lu == nil { + return nil, errors.New("recording: missing connection metadata") + } ch := sessionrecording.CastHeader{ Version: 2, Width: w.Width, @@ -1365,17 +1406,17 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { // it. Then we can (1) make the cmd, (2) start the // recording, (3) start the process. }, - SSHUser: ss.conn.info.sshUser, - LocalUser: ss.conn.localUser.Username, - SrcNode: strings.TrimSuffix(ss.conn.info.node.Name(), "."), - SrcNodeID: ss.conn.info.node.StableID(), + SSHUser: ci.sshUser, + LocalUser: lu.Username, + SrcNode: strings.TrimSuffix(ci.node.Name(), "."), + SrcNodeID: ci.node.StableID(), ConnectionID: ss.conn.connID, } - if !ss.conn.info.node.IsTagged() { - ch.SrcNodeUser = ss.conn.info.uprof.LoginName - ch.SrcNodeUserID = ss.conn.info.node.User() + if !ci.node.IsTagged() { + ch.SrcNodeUser = ci.uprof.LoginName + ch.SrcNodeUserID = ci.node.User() } else { - ch.SrcNodeTags = ss.conn.info.node.Tags().AsSlice() + ch.SrcNodeTags = ci.node.Tags().AsSlice() } j, err := json.Marshal(ch) if err != nil { @@ -1398,14 +1439,19 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { // A SSHEventNotifyRequest is sent when an action or state reached during // an SSH session is a defined EventType. func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) { + ci, lu := ss.conn.getInfoAndLocalUser() + if ci == nil || lu == nil { + ss.logf("notifyControl: missing connection metadata") + return + } re := tailcfg.SSHEventNotifyRequest{ EventType: notifyType, ConnectionID: ss.conn.connID, CapVersion: tailcfg.CurrentCapabilityVersion, NodeKey: nodeKey, - SrcNode: ss.conn.info.node.ID(), - SSHUser: ss.conn.info.sshUser, - LocalUser: ss.conn.localUser.Username, + SrcNode: ci.node.ID(), + SSHUser: ci.sshUser, + LocalUser: lu.Username, RecordingAttempts: attempts, } diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index b9e591d80..581c8be82 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -1339,6 +1339,57 @@ func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) { }) } +func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + srv := &server{ + logf: tstest.WhileTestRunningLogger(t), + lb: &localState{ + sshEnabled: true, + matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), + }, + } + c := &conn{ + srv: srv, + info: &sshConnInfo{sshUser: "alice"}, + } + srv.activeConns = map[*conn]bool{c: true} + + fakeClientAuth := func() { + c.mu.Lock() + c.info = &sshConnInfo{sshUser: "alice"} + c.mu.Unlock() + + c.mu.Lock() + c.localUser = &userMeta{User: user.User{Username: currentUser}} + c.mu.Unlock() + } + + // Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser. + done := make(chan struct{}) + go func() { + for i := 0; i < 100; i++ { + select { + case <-done: + return + default: + fakeClientAuth() + } + } + }() + + go func() { + for i := 0; i < 100; i++ { + select { + case <-done: + return + default: + srv.OnPolicyChange() + } + } + }() + }) +} + func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { t.Helper() mux := http.NewServeMux()