diff --git a/ssh/tailssh/auditd_linux.go b/ssh/tailssh/auditd_linux.go index bddb901d5..a0295ca37 100644 --- a/ssh/tailssh/auditd_linux.go +++ b/ssh/tailssh/auditd_linux.go @@ -123,7 +123,7 @@ func sendAuditMessage(logf logger.Logf, msgType uint16, message string) { // logSSHLogin logs an SSH login event to auditd with whois information. func logSSHLogin(logf logger.Logf, c *conn) { - if c == nil || c.info == nil || c.localUser == nil { + if c == nil { return } diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index b414ce3fb..b31330177 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -1099,7 +1099,7 @@ func (ss *sshSession) startWithStdPipes() (err error) { return ss.cmd.Start() } -func envForUser(u *userMeta) []string { +func envForUser(u userMeta) []string { return []string{ fmt.Sprintf("SHELL=%s", u.LoginShell()), fmt.Sprintf("USER=%s", u.Username), diff --git a/ssh/tailssh/incubator_plan9.go b/ssh/tailssh/incubator_plan9.go index 69112635f..a9b9a1163 100644 --- a/ssh/tailssh/incubator_plan9.go +++ b/ssh/tailssh/incubator_plan9.go @@ -400,7 +400,7 @@ func (ss *sshSession) startWithStdPipes() (err error) { return ss.cmd.Start() } -func envForUser(u *userMeta) []string { +func envForUser(u userMeta) []string { return []string{ fmt.Sprintf("user=%s", u.Username), fmt.Sprintf("home=%s", u.HomeDir), diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index d8dea7da2..d3060a067 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -193,8 +193,8 @@ func (srv *server) OnPolicyChange() { defer srv.mu.Unlock() for c := range srv.activeConns { ci, lu := c.getInfoAndLocalUser() - if ci == nil || lu == nil { - // c.info or c.localUser are nil when the connection hasn't been + if ci.sshUser == "" || lu.Username == "" { + // c.info or c.localUser are empty when the connection hasn't been // authenticated yet. We will continue here, but the connection will // be checked once it is authenticated. If it no longer conforms // with the SSH access policy at that point, it will be terminated. @@ -250,9 +250,9 @@ type conn struct { // srv.mu should be acquired prior to mu. // It is safe to just acquire mu, but unsafe to // acquire mu and then srv.mu. - mu sync.Mutex // protects the following - info *sshConnInfo // set by setInfo - localUser *userMeta // set by clientAuth + mu sync.Mutex // protects the following + info sshConnInfo // set by setInfo + localUser userMeta // set by clientAuth sessions []*sshSession } @@ -267,19 +267,19 @@ func (c *conn) vlogf(format string, args ...any) { } } -func (c *conn) getInfo() *sshConnInfo { +func (c *conn) getInfo() sshConnInfo { c.mu.Lock() defer c.mu.Unlock() return c.info } -func (c *conn) getLocalUser() *userMeta { +func (c *conn) getLocalUser() userMeta { c.mu.Lock() defer c.mu.Unlock() return c.localUser } -func (c *conn) getInfoAndLocalUser() (*sshConnInfo, *userMeta) { +func (c *conn) getInfoAndLocalUser() (sshConnInfo, userMeta) { c.mu.Lock() defer c.mu.Unlock() return c.info, c.localUser @@ -356,11 +356,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE // do nothing case rejectedUser: ci := c.getInfo() - if ci != nil { - return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil) - } else { - return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH"), nil) - } + return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", ci.sshUser), nil) case rejected, noPolicy: return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result)) default: @@ -610,7 +606,7 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { // connInfo populates the sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. func (c *conn) setInfo(cm gossh.ConnMetadata) error { - ci := &sshConnInfo{ + ci := sshConnInfo{ sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix), src: toIPPort(cm.RemoteAddr()), dst: toIPPort(cm.LocalAddr()), @@ -630,7 +626,7 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error { c.mu.Lock() defer c.mu.Unlock() - if c.info != nil { + if c.info.sshUser != "" { return nil } c.idH = string(cm.SessionID()) @@ -767,7 +763,7 @@ func (c *conn) isStillValid() bool { return false } lu := c.getLocalUser() - return lu != nil && lu.Username == localUser + return lu.Username == localUser } // checkStillValid checks that the conn is still valid per the latest SSHPolicy. @@ -842,11 +838,7 @@ func (ss *sshSession) killProcessOnContextDone() { } } ci := ss.conn.getInfo() - if ci != nil { - ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err) - } else { - ss.logf("terminating SSH session: %v", err) - } + ss.logf("terminating SSH session from %v: %v", ci.src.Addr(), err) // We don't need to Process.Wait here, sshSession.run() does // the waiting regardless of termination reason. @@ -884,7 +876,7 @@ var errSessionDone = errors.New("session is done") // handleSSHAgentForwarding starts a Unix socket listener and in the background // forwards agent connections between the listener and the ssh.Session. // On success, it assigns ss.agentListener. -func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) error { +func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu userMeta) error { if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding { return nil } @@ -1182,10 +1174,6 @@ func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser st return nil, "", nil, errInvalidConn } ci := c.getInfo() - if ci == nil { - c.logf("invalid connection state") - return nil, "", nil, errInvalidConn - } if r == nil { return nil, "", nil, errNilRule } @@ -1237,9 +1225,6 @@ func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool { // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { ci := c.getInfo() - if ci == nil { - return false - } if p.Any { return true } @@ -1386,9 +1371,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { } ci, lu := ss.conn.getInfoAndLocalUser() - if ci == nil || lu == nil { - return nil, errors.New("recording: missing connection metadata") - } ch := sessionrecording.CastHeader{ Version: 2, Width: w.Width, @@ -1440,10 +1422,6 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { // an SSH session is a defined EventType. func (ss *sshSession) notifyControl(ctx context.Context, nodeKey key.NodePublic, notifyType tailcfg.SSHEventType, attempts []*tailcfg.SSHRecordingAttempt, url string) { ci, lu := ss.conn.getInfoAndLocalUser() - if ci == nil || lu == nil { - ss.logf("notifyControl: missing connection metadata") - return - } re := tailcfg.SSHEventNotifyRequest{ EventType: notifyType, ConnectionID: ss.conn.connID, diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 581c8be82..18252de03 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -64,31 +64,20 @@ func TestMatchRule(t *testing.T) { tests := []struct { name string rule *tailcfg.SSHRule - ci *sshConnInfo + ci sshConnInfo wantErr error wantUser string wantAcceptEnv []string }{ - { - name: "invalid-conn", - rule: &tailcfg.SSHRule{ - Action: someAction, - Principals: []*tailcfg.SSHPrincipal{{Any: true}}, - SSHUsers: map[string]string{ - "*": "ubuntu", - }, - }, - wantErr: errInvalidConn, - }, { name: "nil-rule", - ci: &sshConnInfo{}, + ci: sshConnInfo{}, rule: nil, wantErr: errNilRule, }, { name: "nil-action", - ci: &sshConnInfo{}, + ci: sshConnInfo{}, rule: &tailcfg.SSHRule{}, wantErr: errNilAction, }, @@ -98,7 +87,7 @@ func TestMatchRule(t *testing.T) { Action: someAction, RuleExpires: ptr.To(time.Unix(100, 0)), }, - ci: &sshConnInfo{}, + ci: sshConnInfo{}, wantErr: errRuleExpired, }, { @@ -108,7 +97,7 @@ func TestMatchRule(t *testing.T) { SSHUsers: map[string]string{ "*": "ubuntu", }}, - ci: &sshConnInfo{}, + ci: sshConnInfo{}, wantErr: errPrincipalMatch, }, { @@ -117,7 +106,7 @@ func TestMatchRule(t *testing.T) { Action: someAction, Principals: []*tailcfg.SSHPrincipal{{Any: true}}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantErr: errUserMatch, }, { @@ -129,7 +118,7 @@ func TestMatchRule(t *testing.T) { "*": "ubuntu", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "ubuntu", }, { @@ -144,7 +133,7 @@ func TestMatchRule(t *testing.T) { "*": "ubuntu", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "ubuntu", }, { @@ -157,7 +146,7 @@ func TestMatchRule(t *testing.T) { "alice": "thealice", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "thealice", }, { @@ -171,7 +160,7 @@ func TestMatchRule(t *testing.T) { }, AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "thealice", wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, }, @@ -181,7 +170,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{Any: true}}, Action: &tailcfg.SSHAction{Reject: true}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, }, { name: "match-principal-node-ip", @@ -190,7 +179,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")}, + ci: sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")}, wantUser: "ubuntu", }, { @@ -200,7 +189,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()}, + ci: sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()}, wantUser: "ubuntu", }, { @@ -210,7 +199,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}}, + ci: sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}}, wantUser: "ubuntu", }, { @@ -222,7 +211,7 @@ func TestMatchRule(t *testing.T) { "*": "=", }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "alice", }, } @@ -254,7 +243,7 @@ func TestEvalSSHPolicy(t *testing.T) { tests := []struct { name string policy *tailcfg.SSHPolicy - ci *sshConnInfo + ci sshConnInfo wantResult evalResult wantUser string wantAcceptEnv []string @@ -298,7 +287,7 @@ func TestEvalSSHPolicy(t *testing.T) { }, }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "thealice", wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, wantResult: accepted, @@ -308,7 +297,7 @@ func TestEvalSSHPolicy(t *testing.T) { policy: &tailcfg.SSHPolicy{ Rules: []*tailcfg.SSHRule{}, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "", wantAcceptEnv: nil, wantResult: rejected, @@ -349,7 +338,7 @@ func TestEvalSSHPolicy(t *testing.T) { }, }, }, - ci: &sshConnInfo{sshUser: "alice"}, + ci: sshConnInfo{sshUser: "alice"}, wantUser: "", wantAcceptEnv: nil, wantResult: rejectedUser, @@ -1100,7 +1089,7 @@ func TestSSH(t *testing.T) { t.Fatal(err) } sc.localUser = um - sc.info = &sshConnInfo{ + sc.info = sshConnInfo{ sshUser: "test", src: netip.MustParseAddrPort("1.2.3.4:32342"), dst: netip.MustParseAddrPort("1.2.3.5:22"), @@ -1318,76 +1307,71 @@ func TestStdOsUserUserAssumptions(t *testing.T) { } } -func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - srv := &server{ - logf: tstest.WhileTestRunningLogger(t), - lb: &localState{ - sshEnabled: true, - matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), - }, - } - c := &conn{ - srv: srv, - info: &sshConnInfo{sshUser: "alice"}, - } - srv.activeConns = map[*conn]bool{c: true} +func TestOnPolicyChangeDefersValidationOnEmptyLocalUser(t *testing.T) { + tests := []struct { + name string + sshRule *tailcfg.SSHRule + wantCancelOnValidation bool + }{ + { + name: "defer-then-accept-when-allowed", + sshRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), + wantCancelOnValidation: false, + }, + { + name: "defer-then-reject-when-not-allowed", + sshRule: newSSHRule(&tailcfg.SSHAction{Reject: true}), + wantCancelOnValidation: true, + }, + } - srv.OnPolicyChange() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { - synctest.Wait() - }) -} - -func TestRaceWriteAndReadConnInfoAndLocalUser(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - srv := &server{ - logf: tstest.WhileTestRunningLogger(t), - lb: &localState{ - sshEnabled: true, - matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), - }, - } - c := &conn{ - srv: srv, - info: &sshConnInfo{sshUser: "alice"}, - } - srv.activeConns = map[*conn]bool{c: true} - - fakeClientAuth := func() { - c.mu.Lock() - c.info = &sshConnInfo{sshUser: "alice"} - c.mu.Unlock() - - c.mu.Lock() - c.localUser = &userMeta{User: user.User{Username: currentUser}} - c.mu.Unlock() - } - - // Simulate a race between clientAuth() writing and OnPolicyChange reading a connection's info and localUser. - done := make(chan struct{}) - go func() { - for i := 0; i < 100; i++ { - select { - case <-done: - return - default: - fakeClientAuth() + synctest.Test(t, func(t *testing.T) { + srv := &server{ + logf: tstest.WhileTestRunningLogger(t), + lb: &localState{ + sshEnabled: true, + matchingRule: tt.sshRule, + }, } - } - }() - - go func() { - for i := 0; i < 100; i++ { - select { - case <-done: - return - default: - srv.OnPolicyChange() + c := &conn{ + srv: srv, + info: sshConnInfo{sshUser: "alice"}, } - } - }() - }) + srv.activeConns = map[*conn]bool{c: true} + ctx, cancel := context.WithCancelCause(context.Background()) + ss := &sshSession{ctx: ctx, cancelCtx: cancel} + c.sessions = []*sshSession{ss} + + srv.OnPolicyChange() + synctest.Wait() + select { + case <-ctx.Done(): + t.Fatal("expected deferral of cancellation decision while localUser unset but session got canceled") + default: + } + + c.mu.Lock() + c.localUser = userMeta{User: user.User{Username: currentUser}} + c.mu.Unlock() + + srv.OnPolicyChange() + synctest.Wait() + select { + case <-ctx.Done(): + if !tt.wantCancelOnValidation { + t.Fatal("valid session shouldn't have been canceled") + } + default: + if tt.wantCancelOnValidation { + t.Fatal("invalid session should have been canceled but it wasn't") + } + } + }) + }) + } } func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { diff --git a/ssh/tailssh/user.go b/ssh/tailssh/user.go index 7da6bb4eb..c84c93821 100644 --- a/ssh/tailssh/user.go +++ b/ssh/tailssh/user.go @@ -36,15 +36,15 @@ func (u *userMeta) GroupIds() ([]string, error) { return osuser.GetGroupIds(&u.User) } -// userLookup is like os/user.Lookup but it returns a *userMeta wrapper +// userLookup is like os/user.Lookup but it returns a userMeta wrapper // around a *user.User with extra fields. -func userLookup(username string) (*userMeta, error) { +func userLookup(username string) (userMeta, error) { u, s, err := osuser.LookupByUsernameWithShell(username) if err != nil { - return nil, err + return userMeta{}, err } - return &userMeta{User: *u, loginShellCached: s}, nil + return userMeta{User: *u, loginShellCached: s}, nil } func (u *userMeta) LoginShell() string {