From 79483a1e5e86ddfe9c59d760809b0fea830ddc84 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 1 Apr 2022 12:57:12 -0700 Subject: [PATCH] tailcfg, ssh/tailssh: optionally support SSH public keys in wire policy Updates #3802 Change-Id: I756dc2d579a16757537142283d791f1d0319f4f0 Signed-off-by: Brad Fitzpatrick --- ssh/tailssh/tailssh.go | 174 ++++++++++++++++++++++++++++++------ ssh/tailssh/tailssh_test.go | 5 +- tailcfg/tailcfg.go | 16 ++-- 3 files changed, 164 insertions(+), 31 deletions(-) diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 3fcd87d09..64e10e132 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -9,8 +9,10 @@ package tailssh import ( + "bytes" "context" "crypto/rand" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -77,9 +79,15 @@ func (srv *server) newSSHServer() (*ssh.Server, error) { Version: "SSH-2.0-Tailscale", LocalPortForwardingCallback: srv.mayForwardLocalPortTo, NoClientAuthCallback: func(m gossh.ConnMetadata) (*gossh.Permissions, error) { - srv.logf("SSH connection from %v for %q; client ver %q", m.RemoteAddr(), m.User(), m.ClientVersion()) + if srv.askForCert(m.User(), m.LocalAddr(), m.RemoteAddr()) { + return nil, errors.New("cert required") // any non-nil error will do + } return nil, nil }, + PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { + srv.logf("SSH public key %T %#v", key, key) + return true // rejected later, after accepting connections + }, } for k, v := range ssh.DefaultRequestHandlers { ss.RequestHandlers[k] = v @@ -124,6 +132,31 @@ func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string return ss.action.AllowLocalPortForwarding } +// askForCert reports whether the SSH server, during the auth negotiation phase, +// should requires that the client send an SSH cert. +func (srv *server) askForCert(sshUser string, localAddr, remoteAddr net.Addr) bool { + pol, ok := srv.sshPolicy() + if !ok { + return false + } + a, ci, _, err := srv.evaluatePolicy(sshUser, localAddr, remoteAddr, nil) + if err == nil && (a.Accept || a.HoldAndDelegate != "") { + // Policy doesn't require a cert. + return false + } + + // Is there any rule that looks like it'd require a cert for + // this sshUser? + for _, r := range pol.Rules { + for _, p := range r.Principals { + if principalMatchesTailscaleIdentity(p, ci) && len(p.Certs) > 0 { + return true + } + } + } + return false +} + // sshPolicy returns the SSHPolicy for current node. // If there is no SSHPolicy in the netmap, it returns a debugPolicy // if one is defined. @@ -170,7 +203,7 @@ func asTailscaleIPPort(a net.Addr) (netaddr.IPPort, error) { // evaluatePolicy returns the SSHAction, sshConnInfo and localUser // after evaluating the sshUser and remoteAddr against the SSHPolicy. // The remoteAddr and localAddr params must be Tailscale IPs. -func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) { +func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr, pubKey ssh.PublicKey) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) { logf := srv.logf lb := srv.lb logf("Handling SSH from %v for user %v", remoteAddr, sshUser) @@ -194,12 +227,14 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr } ci := &sshConnInfo{ - now: time.Now(), - sshUser: sshUser, - src: srcIPP, - dst: dstIPP, - node: node, - uprof: &uprof, + now: time.Now(), + fetchPublicKeysURL: srv.fetchPublicKeysURL, + sshUser: sshUser, + src: srcIPP, + dst: dstIPP, + node: node, + uprof: &uprof, + pubKey: pubKey, } a, localUser, ok := evalSSHPolicy(pol, ci) if !ok { @@ -208,12 +243,36 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr return a, ci, localUser, nil } +func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { + if !strings.HasPrefix(url, "https://") { + return nil, errors.New("invalid URL scheme") + } + // TODO(bradfitz): add caching + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, errors.New(res.Status) + } + all, err := io.ReadAll(io.LimitReader(res.Body, 1<<10)) + return strings.Split(string(all), "\n"), err +} + // handleSSH is invoked when a new SSH connection attempt is made. func (srv *server) handleSSH(s ssh.Session) { logf := srv.logf sshUser := s.User() - action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr()) + action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr(), s.PublicKey()) if err != nil { logf(err.Error()) s.Exit(1) @@ -609,6 +668,10 @@ type sshConnInfo struct { // now is the time to consider the present moment for the // purposes of rule evaluation. now time.Time + // fetchPublicKeysURL, if non-nil, is a func to fetch the public + // keys of a URL. The strings are in the the typical public + // key "type base64-string [comment]" format seen at e.g. https://github.com/USER.keys + fetchPublicKeysURL func(url string) ([]string, error) // sshUser is the requested local SSH username ("root", "alice", etc). sshUser string @@ -624,6 +687,11 @@ type sshConnInfo struct { // uprof is node's UserProfile. uprof *tailcfg.UserProfile + + // pubKey is the public key presented by the client, or nil + // if they haven't yet sent one (as in the early "none" phase + // of authentication negotiation). + pubKey ssh.PublicKey } func evalSSHPolicy(pol *tailcfg.SSHPolicy, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, ok bool) { @@ -654,15 +722,15 @@ func matchRule(r *tailcfg.SSHRule, ci *sshConnInfo) (a *tailcfg.SSHAction, local if r.RuleExpires != nil && ci.now.After(*r.RuleExpires) { return nil, "", errRuleExpired } - if !matchesPrincipal(r.Principals, ci) { - return nil, "", errPrincipalMatch - } if !r.Action.Reject || r.SSHUsers != nil { localUser = mapLocalUser(r.SSHUsers, ci.sshUser) if localUser == "" { return nil, "", errUserMatch } } + if !anyPrincipalMatches(r.Principals, ci) { + return nil, "", errPrincipalMatch + } return r.Action, localUser, nil } @@ -677,29 +745,85 @@ func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser return v } -func matchesPrincipal(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool { +func anyPrincipalMatches(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool { for _, p := range ps { if p == nil { continue } - if p.Any { - return true - } - if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID { - return true - } - if p.NodeIP != "" { - if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() { - return true - } - } - if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin { + if principalMatches(p, ci) { return true } } return false } +func principalMatches(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool { + return principalMatchesTailscaleIdentity(p, ci) && + principalMatchesCert(p, ci) +} + +// principalMatchesTailscaleIdentity reports whether one of p's four fields +// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). +// This function does not consider Certs. +func principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool { + if p.Any { + return true + } + if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID { + return true + } + if p.NodeIP != "" { + if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.src.IP() { + return true + } + } + if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin { + return true + } + return false +} + +func principalMatchesCert(p *tailcfg.SSHPrincipal, ci *sshConnInfo) bool { + if len(p.Certs) == 0 { + return true + } + if ci.pubKey == nil { + return false + } + certs := p.Certs + if len(certs) == 1 && strings.HasPrefix(certs[0], "https://") { + if ci.fetchPublicKeysURL == nil { + // TODO: log? + return false + } + var err error + certs, err = ci.fetchPublicKeysURL(certs[0]) + if err != nil { + // TODO: log? + return false + } + } + for _, cert := range certs { + if pubKeyMatchesAuthorizedKey(ci.pubKey, cert) { + return true + } + } + return false +} + +func pubKeyMatchesAuthorizedKey(pubKey ssh.PublicKey, wantKey string) bool { + wantKeyType, rest, ok := strings.Cut(wantKey, " ") + if !ok { + return false + } + if pubKey.Type() != wantKeyType { + return false + } + wantKeyB64, _, _ := strings.Cut(rest, " ") + wantKeyData, _ := base64.StdEncoding.DecodeString(wantKeyB64) + return len(wantKeyData) > 0 && bytes.Equal(pubKey.Marshal(), wantKeyData) +} + func randBytes(n int) []byte { b := make([]byte, n) if _, err := rand.Read(b); err != nil { diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index afae4984a..23e2540a3 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -63,7 +63,10 @@ func TestMatchRule(t *testing.T) { name: "no-principal", rule: &tailcfg.SSHRule{ Action: someAction, - }, + SSHUsers: map[string]string{ + "*": "ubuntu", + }}, + ci: &sshConnInfo{}, wantErr: errPrincipalMatch, }, { diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index b39f04135..df433c114 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -1590,16 +1590,22 @@ type SSHRule struct { } // SSHPrincipal is either a particular node or a user on any node. -// Any matching field causes a match. type SSHPrincipal struct { + // Matching any one of the following four field causes a match. + // It must also match Certs, if non-empty. + Node StableNodeID `json:"node,omitempty"` NodeIP string `json:"nodeIP,omitempty"` UserLogin string `json:"userLogin,omitempty"` // email-ish: foo@example.com, bar@github - - // Any, if true, matches any user. - Any bool `json:"any,omitempty"` - + Any bool `json:"any,omitempty"` // if true, match any connection // TODO(bradfitz): add StableUserID, once that exists + + // Certs, if non-empty, means that this SSHPrincipal only + // matches if one of these certs is presented by the user. + // + // As a special case, if len(Certs) == 1 and Certs[0] starts + // with "https://", then it's fetched (like https://github.com/username.keys). + Certs []string `json:"certs,omitempty"` } // SSHAction is how to handle an incoming connection.