diff --git a/vault/cluster.go b/vault/cluster.go index 5e6e9974c3..7b330f91db 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -90,16 +90,26 @@ func (c *Core) Cluster() (*Cluster, error) { // It also ensures the cert is in our local cluster cert pool. func (c *Core) loadClusterTLS(adv activeAdvertisement) error { switch { + case adv.ClusterAddr == "": + // Clustering disabled on the server, don't try to look for params + return nil + + case adv.ClusterKeyParams == nil: + c.logger.Printf("[ERR] core/loadClusterTLS: no key params found") + return fmt.Errorf("no local cluster key params found") + case adv.ClusterKeyParams.X == nil, adv.ClusterKeyParams.Y == nil, adv.ClusterKeyParams.D == nil: - c.logger.Printf("[ERR] core/loadClusterPrivateKey: failed to parse local cluster key due to missing params") + c.logger.Printf("[ERR] core/loadClusterTLS: failed to parse local cluster key due to missing params") return fmt.Errorf("failed to parse local cluster key") - case adv.ClusterKeyParams.Type == corePrivateKeyTypeP521: - // Nothing, this is what we want - - default: - c.logger.Printf("[ERR] core/loadClusterPrivateKey: unknown local cluster key type %v", adv.ClusterKeyParams.Type) + case adv.ClusterKeyParams.Type != corePrivateKeyTypeP521: + c.logger.Printf("[ERR] core/loadClusterTLS: unknown local cluster key type %v", adv.ClusterKeyParams.Type) return fmt.Errorf("failed to find valid local cluster key type") + + case adv.ClusterCert == nil || len(adv.ClusterCert) == 0: + c.logger.Printf("[ERR] core/loadClusterTLS: no local cluster cert found") + return fmt.Errorf("no local cluster cert found") + } // Prevent data races with the TLS parameters @@ -119,7 +129,7 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error { cert, err := x509.ParseCertificate(c.localClusterCert) if err != nil { - c.logger.Printf("[ERR] core/loadClusterPrivateKey: failed parsing local cluster certificate: %v", err) + c.logger.Printf("[ERR] core/loadClusterTLS: failed parsing local cluster certificate: %v", err) return fmt.Errorf("error parsing local cluster certificate: %v", err) } diff --git a/vault/core.go b/vault/core.go index 09b0408e5d..2b0fb8f130 100644 --- a/vault/core.go +++ b/vault/core.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto" "crypto/ecdsa" - "crypto/sha256" "crypto/x509" "errors" "fmt" @@ -114,10 +113,10 @@ func (e *ErrInvalidKey) Error() string { } type activeAdvertisement struct { - RedirectAddr string `json:"redirect_addr"` - ClusterAddr string `json:"cluster_addr"` - ClusterCert []byte `json:"cluster_cert"` - ClusterKeyParams clusterKeyParams `json:"cluster_key_params"` + RedirectAddr string `json:"redirect_addr"` + ClusterAddr string `json:"cluster_addr,omitempty"` + ClusterCert []byte `json:"cluster_cert,omitempty"` + ClusterKeyParams *clusterKeyParams `json:"cluster_key_params,omitempty"` } // Core is used as the central manager of Vault activity. It is the primary point of @@ -269,12 +268,11 @@ type Core struct { // Write lock used to ensure that we don't have multiple connections adjust // this value at the same time requestForwardingConnectionLock sync.RWMutex - // Most recent hashed value of the advertise/cluster info. Used to avoid - // repeatedly JSON parsing the same values. - clusterActiveAdvertisementHash []byte - // Cache of most recently known active advertisement information, used to - // return values when the hash matches - clusterActiveAdvertisement activeAdvertisement + // Most recent leader UUID. Used to avoid repeatedly JSON parsing the same + // values. + clusterLeaderUUID string + // Most recent leader redirect addr + clusterLeaderRedirectAddr string // The grpc Server that handles server RPC calls rpcServer *grpc.Server // The function for canceling the client connection @@ -626,7 +624,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) { } // Read the value - held, value, err := lock.Value() + held, leaderUUID, err := lock.Value() if err != nil { return false, "", err } @@ -634,8 +632,13 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) { return false, "", nil } - // Value is the UUID of the leader, fetch the key - key := coreLeaderPrefix + value + // If the leader hasn't changed, return the cached value; nothing changes + // mid-leadership, and the barrier caches anyways + if leaderUUID == c.clusterLeaderUUID && c.clusterLeaderRedirectAddr != "" { + return false, c.clusterLeaderRedirectAddr, nil + } + + key := coreLeaderPrefix + leaderUUID entry, err := c.barrier.Get(key) if err != nil { return false, "", err @@ -644,26 +647,14 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) { return false, "", nil } - entrySHA256 := sha256.Sum256(entry.Value) - - // Avoid JSON parsing and function calling if nothing has changed - if c.clusterActiveAdvertisementHash != nil { - if bytes.Compare(entrySHA256[:], c.clusterActiveAdvertisementHash) == 0 { - return false, c.clusterActiveAdvertisement.RedirectAddr, nil - } - } - - var advAddr string var oldAdv bool var adv activeAdvertisement err = jsonutil.DecodeJSON(entry.Value, &adv) if err != nil { // Fall back to pre-struct handling - advAddr = string(entry.Value) + adv.RedirectAddr = string(entry.Value) oldAdv = true - } else { - advAddr = adv.RedirectAddr } if !oldAdv { @@ -681,10 +672,12 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) { } } - c.clusterActiveAdvertisement = adv - c.clusterActiveAdvertisementHash = entrySHA256[:] + // Don't set these until everything has been parsed successfully or we'll + // never try again + c.clusterLeaderRedirectAddr = adv.RedirectAddr + c.clusterLeaderUUID = leaderUUID - return false, advAddr, nil + return false, c.clusterLeaderRedirectAddr, nil } // SecretProgress returns the number of keys provided so far @@ -1409,7 +1402,7 @@ func (c *Core) advertiseLeader(uuid string, leaderLostCh <-chan struct{}) error return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey) } - keyParams := clusterKeyParams{ + keyParams := &clusterKeyParams{ Type: corePrivateKeyTypeP521, X: key.X, Y: key.Y, diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index 79a0923072..7086c166d9 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -25,16 +25,7 @@ const ( // Starts the listeners and servers necessary to handle forwarded requests func (c *Core) startForwarding() error { // Clean up in case we have transitioned from a client to a server - c.requestForwardingConnection = nil - c.rpcForwardingClient = nil - if c.rpcClientConnCancelFunc != nil { - c.rpcClientConnCancelFunc() - c.rpcClientConnCancelFunc = nil - } - if c.rpcClientConn != nil { - c.rpcClientConn.Close() - c.rpcClientConn = nil - } + c.clearForwardingClients() // Get our base handler (for our RPC server) and our wrapped handler (for // straight HTTP/2 forwarding) @@ -191,16 +182,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error { // Disabled, potentially, so clean up anything that might be around. if clusterAddr == "" { - c.requestForwardingConnection = nil - c.rpcForwardingClient = nil - if c.rpcClientConnCancelFunc != nil { - c.rpcClientConnCancelFunc() - c.rpcClientConnCancelFunc = nil - } - if c.rpcClientConn != nil { - c.rpcClientConn.Close() - c.rpcClientConn = nil - } + c.clearForwardingClients() return nil } @@ -244,6 +226,25 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error { return nil } +func (c *Core) clearForwardingClients() { + if c.requestForwardingConnection != nil { + c.requestForwardingConnection.transport.CloseIdleConnections() + c.requestForwardingConnection = nil + } + + c.rpcForwardingClient = nil + + if c.rpcClientConnCancelFunc != nil { + c.rpcClientConnCancelFunc() + c.rpcClientConnCancelFunc = nil + } + + if c.rpcClientConn != nil { + c.rpcClientConn.Close() + c.rpcClientConn = nil + } +} + // ForwardRequest forwards a given request to the active node and returns the // response. func (c *Core) ForwardRequest(req *http.Request) (int, []byte, error) {