diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 9d5a7d2a8..2c90dc83c 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -192,9 +192,12 @@ func (srv *server) OnPolicyChange() { srv.mu.Lock() defer srv.mu.Unlock() for c := range srv.activeConns { - if c.info == nil { - // c.info is nil when the connection hasn't been authenticated yet. - // In that case, the connection will be terminated when it is. + // move info and localUser to be protected by conn mutex? + if c.info == nil || c.localUser == nil { + // c.info or c.localUser are nil when the connection hasn't been + // authenticated yet. We will continue here, but the connection will + // be rechecked once it is authenticated. If it no longer conforms + // with the SSH access policy at that point, it will be terminated. continue } go c.checkStillValid() diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index f91cbafe7..b9e591d80 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -31,6 +31,7 @@ import ( "sync" "sync/atomic" "testing" + "testing/synctest" "time" gossh "golang.org/x/crypto/ssh" @@ -1317,6 +1318,27 @@ func TestStdOsUserUserAssumptions(t *testing.T) { } } +func TestOnPolicyChangeHandlesNilLocalUser(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + srv := &server{ + logf: tstest.WhileTestRunningLogger(t), + lb: &localState{ + sshEnabled: true, + matchingRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), + }, + } + c := &conn{ + srv: srv, + info: &sshConnInfo{sshUser: "alice"}, + } + srv.activeConns = map[*conn]bool{c: true} + + srv.OnPolicyChange() + + synctest.Wait() + }) +} + func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { t.Helper() mux := http.NewServeMux()