mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 20:26:47 +02:00
ssh/tailssh: guard access to c.info and c.localUser
This moves the info and localUser of a conn under the guard of the conn's mu in order to prevent races between the fields being written in clientAuth and it being read. Given that info and localUser are pointers this doesn't strictly prevent individual fields from being written without the mutex being held. The reason I consider that good enough is that the code effectively treats both fields as immutable once set. I have added defensive nil checks everywhere, however, which might be overly conservative. Updates tailscale/corp#36268 Signed-off-by: Gesa Stupperich <gesa@tailscale.com>
This commit is contained in:
parent
cd66071731
commit
db52827a83
@ -192,11 +192,11 @@ func (srv *server) OnPolicyChange() {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
for c := range srv.activeConns {
|
||||
// move info and localUser to be protected by conn mutex?
|
||||
if c.info == nil || c.localUser == nil {
|
||||
ci, lu := c.getInfoAndLocalUser()
|
||||
if ci == nil || lu == nil {
|
||||
// c.info or c.localUser are nil when the connection hasn't been
|
||||
// authenticated yet. We will continue here, but the connection will
|
||||
// be rechecked once it is authenticated. If it no longer conforms
|
||||
// be checked once it is authenticated. If it no longer conforms
|
||||
// with the SSH access policy at that point, it will be terminated.
|
||||
continue
|
||||
}
|
||||
@ -242,9 +242,7 @@ type conn struct {
|
||||
action0 *tailcfg.SSHAction // set by clientAuth
|
||||
finalAction *tailcfg.SSHAction // set by clientAuth
|
||||
|
||||
info *sshConnInfo // set by setInfo
|
||||
localUser *userMeta // set by clientAuth
|
||||
userGroupIDs []string // set by clientAuth
|
||||
userGroupIDs []string // set by clientAuth
|
||||
acceptEnv []string
|
||||
|
||||
// mu protects the following fields.
|
||||
@ -252,8 +250,10 @@ type conn struct {
|
||||
// srv.mu should be acquired prior to mu.
|
||||
// It is safe to just acquire mu, but unsafe to
|
||||
// acquire mu and then srv.mu.
|
||||
mu sync.Mutex // protects the following
|
||||
sessions []*sshSession
|
||||
mu sync.Mutex // protects the following
|
||||
info *sshConnInfo // set by setInfo
|
||||
localUser *userMeta // set by clientAuth
|
||||
sessions []*sshSession
|
||||
}
|
||||
|
||||
func (c *conn) logf(format string, args ...any) {
|
||||
@ -267,6 +267,24 @@ func (c *conn) vlogf(format string, args ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) getInfo() *sshConnInfo {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.info
|
||||
}
|
||||
|
||||
func (c *conn) getLocalUser() *userMeta {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.localUser
|
||||
}
|
||||
|
||||
func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.info, c.localUser
|
||||
}
|
||||
|
||||
// errDenied is returned by auth callbacks when a connection is denied by the
|
||||
// policy. It writes the message to an auth banner and then returns an empty
|
||||
// gossh.PartialSuccessError in order to stop processing authentication
|
||||
@ -337,7 +355,12 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE
|
||||
case accepted:
|
||||
// do nothing
|
||||
case rejectedUser:
|
||||
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", c.info.sshUser), nil)
|
||||
ci := c.getInfo()
|
||||
if ci != nil {
|
||||
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil)
|
||||
} else {
|
||||
return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil)
|
||||
}
|
||||
case rejected, noPolicy:
|
||||
return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result))
|
||||
default:
|
||||
@ -358,7 +381,9 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE
|
||||
return nil, c.errBanner("failed to look up local user's group IDs", err)
|
||||
}
|
||||
c.userGroupIDs = gids
|
||||
c.mu.Lock()
|
||||
c.localUser = lu
|
||||
c.mu.Unlock()
|
||||
c.acceptEnv = acceptEnv
|
||||
}
|
||||
|
||||
@ -585,9 +610,6 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) {
|
||||
// connInfo populates the sshConnInfo from the provided arguments,
|
||||
// validating only that they represent a known Tailscale identity.
|
||||
func (c *conn) setInfo(cm gossh.ConnMetadata) error {
|
||||
if c.info != nil {
|
||||
return nil
|
||||
}
|
||||
ci := &sshConnInfo{
|
||||
sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix),
|
||||
src: toIPPort(cm.RemoteAddr()),
|
||||
@ -606,6 +628,11 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
|
||||
ci.node = node
|
||||
ci.uprof = uprof
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.info != nil {
|
||||
return nil
|
||||
}
|
||||
c.idH = string(cm.SessionID())
|
||||
c.info = ci
|
||||
c.logf("handling conn: %v", ci.String())
|
||||
@ -653,8 +680,9 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
|
||||
}
|
||||
|
||||
ss := c.newSSHSession(s)
|
||||
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username)
|
||||
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
|
||||
ci, lu := c.getInfoAndLocalUser()
|
||||
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", ci.uprof.LoginName, ci.src.Addr(), lu.Username)
|
||||
ss.logf("access granted to %v as ssh-user %q", ci.uprof.LoginName, lu.Username)
|
||||
|
||||
if f, ok := hookSSHLoginSuccess.GetOk(); ok {
|
||||
f(c.srv.logf, c)
|
||||
@ -665,8 +693,7 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
|
||||
|
||||
func (c *conn) expandDelegateURLLocked(actionURL string) string {
|
||||
nm := c.srv.lb.NetMap()
|
||||
ci := c.info
|
||||
lu := c.localUser
|
||||
ci, lu := c.getInfoAndLocalUser()
|
||||
var dstNodeID string
|
||||
if nm != nil {
|
||||
dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID()))
|
||||
@ -739,7 +766,8 @@ func (c *conn) isStillValid() bool {
|
||||
if !a.Accept && a.HoldAndDelegate == "" {
|
||||
return false
|
||||
}
|
||||
return c.localUser.Username == localUser
|
||||
lu := c.getLocalUser()
|
||||
return lu != nil && lu.Username == localUser
|
||||
}
|
||||
|
||||
// checkStillValid checks that the conn is still valid per the latest SSHPolicy.
|
||||
@ -813,7 +841,12 @@ func (ss *sshSession) killProcessOnContextDone() {
|
||||
io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
|
||||
}
|
||||
}
|
||||
ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.Addr(), err)
|
||||
ci := ss.conn.getInfo()
|
||||
if ci != nil {
|
||||
ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err)
|
||||
} else {
|
||||
ss.logf("terminating SSH session: %v", err)
|
||||
}
|
||||
// We don't need to Process.Wait here, sshSession.run() does
|
||||
// the waiting regardless of termination reason.
|
||||
|
||||
@ -915,7 +948,7 @@ func (ss *sshSession) run() {
|
||||
}
|
||||
defer ss.conn.detachSession(ss)
|
||||
|
||||
lu := ss.conn.localUser
|
||||
lu := ss.conn.getLocalUser()
|
||||
logf := ss.logf
|
||||
|
||||
if ss.conn.finalAction.SessionDuration != 0 {
|
||||
@ -1148,7 +1181,8 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st
|
||||
if c == nil {
|
||||
return nil, "", nil, errInvalidConn
|
||||
}
|
||||
if c.info == nil {
|
||||
ci := c.getInfo()
|
||||
if ci == nil {
|
||||
c.logf("invalid connection state")
|
||||
return nil, "", nil, errInvalidConn
|
||||
}
|
||||
@ -1168,7 +1202,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st
|
||||
// For all but Reject rules, SSHUsers is required.
|
||||
// If SSHUsers is nil or empty, mapLocalUser will return an
|
||||
// empty string anyway.
|
||||
localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
|
||||
localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
|
||||
if localUser == "" {
|
||||
return nil, "", nil, errUserMatch
|
||||
}
|
||||
@ -1202,7 +1236,10 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool {
|
||||
// principalMatchesTailscaleIdentity reports whether one of p's four fields
|
||||
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
|
||||
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
|
||||
ci := c.info
|
||||
ci := c.getInfo()
|
||||
if ci == nil {
|
||||
return false
|
||||
}
|
||||
if p.Any {
|
||||
return true
|
||||
}
|
||||
@ -1348,6 +1385,10 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
||||
}()
|
||||
}
|
||||
|
||||
ci, lu := ss.conn.getInfoAndLocalUser()
|
||||
if ci == nil || lu == nil {
|
||||
return nil, errors.New("recording: missing connection metadata")
|
||||
}
|
||||
ch := sessionrecording.CastHeader{
|
||||
Version: 2,
|
||||
Width: w.Width,
|
||||
@ -1365,17 +1406,17 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
||||
// it. Then we can (1) make the cmd, (2) start the
|
||||
// recording, (3) start the process.
|
||||
},
|
||||
SSHUser: ss.conn.info.sshUser,
|
||||
LocalUser: ss.conn.localUser.Username,
|
||||
SrcNode: strings.TrimSuffix(ss.conn.info.node.Name(), "."),
|
||||
SrcNodeID: ss.conn.info.node.StableID(),
|
||||
SSHUser: ci.sshUser,
|
||||
LocalUser: lu.Username,
|
||||
SrcNode: strings.TrimSuffix(ci.node.Name(), "."),
|
||||
SrcNodeID: ci.node.StableID(),
|
||||
ConnectionID: ss.conn.connID,
|
||||
}
|
||||
if !ss.conn.info.node.IsTagged() {
|
||||
ch.SrcNodeUser = ss.conn.info.uprof.LoginName
|
||||
ch.SrcNodeUserID = ss.conn.info.node.User()
|
||||
if !ci.node.IsTagged() {
|
||||
ch.SrcNodeUser = ci.uprof.LoginName
|
||||
ch.SrcNodeUserID = ci.node.User()
|
||||
} else {
|
||||
ch.SrcNodeTags = ss.conn.info.node.Tags().AsSlice()
|
||||
ch.SrcNodeTags = ci.node.Tags().AsSlice()
|
||||
}
|
||||
j, err := json.Marshal(ch)
|
||||
if err != nil {
|
||||
@ -1398,14 +1439,19 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
||||
// A SSHEventNotifyRequest is sent when an action or state reached during
|
||||
// an SSH session is a defined EventType.
|
||||
func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) {
|
||||
ci, lu := ss.conn.getInfoAndLocalUser()
|
||||
if ci == nil || lu == nil {
|
||||
ss.logf("notifyControl: missing connection metadata")
|
||||
return
|
||||
}
|
||||
re := tailcfg.SSHEventNotifyRequest{
|
||||
EventType: notifyType,
|
||||
ConnectionID: ss.conn.connID,
|
||||
CapVersion: tailcfg.CurrentCapabilityVersion,
|
||||
NodeKey: nodeKey,
|
||||
SrcNode: ss.conn.info.node.ID(),
|
||||
SSHUser: ss.conn.info.sshUser,
|
||||
LocalUser: ss.conn.localUser.Username,
|
||||
SrcNode: ci.node.ID(),
|
||||
SSHUser: ci.sshUser,
|
||||
LocalUser: lu.Username,
|
||||
RecordingAttempts: attempts,
|
||||
}
|
||||
|
||||
|
||||
@ -1339,6 +1339,57 @@ func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
srv := &server{
|
||||
logf: tstest.WhileTestRunningLogger(t),
|
||||
lb: &localState{
|
||||
sshEnabled: true,
|
||||
matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}),
|
||||
},
|
||||
}
|
||||
c := &conn{
|
||||
srv: srv,
|
||||
info: &sshConnInfo{sshUser: "alice"},
|
||||
}
|
||||
srv.activeConns = map[*conn]bool{c: true}
|
||||
|
||||
fakeClientAuth := func() {
|
||||
c.mu.Lock()
|
||||
c.info = &sshConnInfo{sshUser: "alice"}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
c.localUser = &userMeta{User: user.User{Username: currentUser}}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
fakeClientAuth()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
srv.OnPolicyChange()
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
mux := http.NewServeMux()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user