Cleanup and avoid unnecessary advertisement parsing in leader check

This commit is contained in:
Jeff Mitchell 2016-08-19 14:49:11 -04:00
parent 4bcf591dfa
commit f9c44a4458
3 changed files with 62 additions and 58 deletions

View File

@ -90,16 +90,26 @@ func (c *Core) Cluster() (*Cluster, error) {
// It also ensures the cert is in our local cluster cert pool. // It also ensures the cert is in our local cluster cert pool.
func (c *Core) loadClusterTLS(adv activeAdvertisement) error { func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
switch { switch {
case adv.ClusterAddr == "":
// Clustering disabled on the server, don't try to look for params
return nil
case adv.ClusterKeyParams == nil:
c.logger.Printf("[ERR] core/loadClusterTLS: no key params found")
return fmt.Errorf("no local cluster key params found")
case adv.ClusterKeyParams.X == nil, adv.ClusterKeyParams.Y == nil, adv.ClusterKeyParams.D == nil: case adv.ClusterKeyParams.X == nil, adv.ClusterKeyParams.Y == nil, adv.ClusterKeyParams.D == nil:
c.logger.Printf("[ERR] core/loadClusterPrivateKey: failed to parse local cluster key due to missing params") c.logger.Printf("[ERR] core/loadClusterTLS: failed to parse local cluster key due to missing params")
return fmt.Errorf("failed to parse local cluster key") return fmt.Errorf("failed to parse local cluster key")
case adv.ClusterKeyParams.Type == corePrivateKeyTypeP521: case adv.ClusterKeyParams.Type != corePrivateKeyTypeP521:
// Nothing, this is what we want c.logger.Printf("[ERR] core/loadClusterTLS: unknown local cluster key type %v", adv.ClusterKeyParams.Type)
default:
c.logger.Printf("[ERR] core/loadClusterPrivateKey: unknown local cluster key type %v", adv.ClusterKeyParams.Type)
return fmt.Errorf("failed to find valid local cluster key type") return fmt.Errorf("failed to find valid local cluster key type")
case adv.ClusterCert == nil || len(adv.ClusterCert) == 0:
c.logger.Printf("[ERR] core/loadClusterTLS: no local cluster cert found")
return fmt.Errorf("no local cluster cert found")
} }
// Prevent data races with the TLS parameters // Prevent data races with the TLS parameters
@ -119,7 +129,7 @@ func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
cert, err := x509.ParseCertificate(c.localClusterCert) cert, err := x509.ParseCertificate(c.localClusterCert)
if err != nil { if err != nil {
c.logger.Printf("[ERR] core/loadClusterPrivateKey: failed parsing local cluster certificate: %v", err) c.logger.Printf("[ERR] core/loadClusterTLS: failed parsing local cluster certificate: %v", err)
return fmt.Errorf("error parsing local cluster certificate: %v", err) return fmt.Errorf("error parsing local cluster certificate: %v", err)
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/sha256"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
@ -114,10 +113,10 @@ func (e *ErrInvalidKey) Error() string {
} }
type activeAdvertisement struct { type activeAdvertisement struct {
RedirectAddr string `json:"redirect_addr"` RedirectAddr string `json:"redirect_addr"`
ClusterAddr string `json:"cluster_addr"` ClusterAddr string `json:"cluster_addr,omitempty"`
ClusterCert []byte `json:"cluster_cert"` ClusterCert []byte `json:"cluster_cert,omitempty"`
ClusterKeyParams clusterKeyParams `json:"cluster_key_params"` ClusterKeyParams *clusterKeyParams `json:"cluster_key_params,omitempty"`
} }
// Core is used as the central manager of Vault activity. It is the primary point of // Core is used as the central manager of Vault activity. It is the primary point of
@ -269,12 +268,11 @@ type Core struct {
// Write lock used to ensure that we don't have multiple connections adjust // Write lock used to ensure that we don't have multiple connections adjust
// this value at the same time // this value at the same time
requestForwardingConnectionLock sync.RWMutex requestForwardingConnectionLock sync.RWMutex
// Most recent hashed value of the advertise/cluster info. Used to avoid // Most recent leader UUID. Used to avoid repeatedly JSON parsing the same
// repeatedly JSON parsing the same values. // values.
clusterActiveAdvertisementHash []byte clusterLeaderUUID string
// Cache of most recently known active advertisement information, used to // Most recent leader redirect addr
// return values when the hash matches clusterLeaderRedirectAddr string
clusterActiveAdvertisement activeAdvertisement
// The grpc Server that handles server RPC calls // The grpc Server that handles server RPC calls
rpcServer *grpc.Server rpcServer *grpc.Server
// The function for canceling the client connection // The function for canceling the client connection
@ -626,7 +624,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
} }
// Read the value // Read the value
held, value, err := lock.Value() held, leaderUUID, err := lock.Value()
if err != nil { if err != nil {
return false, "", err return false, "", err
} }
@ -634,8 +632,13 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
return false, "", nil return false, "", nil
} }
// Value is the UUID of the leader, fetch the key // If the leader hasn't changed, return the cached value; nothing changes
key := coreLeaderPrefix + value // mid-leadership, and the barrier caches anyways
if leaderUUID == c.clusterLeaderUUID && c.clusterLeaderRedirectAddr != "" {
return false, c.clusterLeaderRedirectAddr, nil
}
key := coreLeaderPrefix + leaderUUID
entry, err := c.barrier.Get(key) entry, err := c.barrier.Get(key)
if err != nil { if err != nil {
return false, "", err return false, "", err
@ -644,26 +647,14 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
return false, "", nil return false, "", nil
} }
entrySHA256 := sha256.Sum256(entry.Value)
// Avoid JSON parsing and function calling if nothing has changed
if c.clusterActiveAdvertisementHash != nil {
if bytes.Compare(entrySHA256[:], c.clusterActiveAdvertisementHash) == 0 {
return false, c.clusterActiveAdvertisement.RedirectAddr, nil
}
}
var advAddr string
var oldAdv bool var oldAdv bool
var adv activeAdvertisement var adv activeAdvertisement
err = jsonutil.DecodeJSON(entry.Value, &adv) err = jsonutil.DecodeJSON(entry.Value, &adv)
if err != nil { if err != nil {
// Fall back to pre-struct handling // Fall back to pre-struct handling
advAddr = string(entry.Value) adv.RedirectAddr = string(entry.Value)
oldAdv = true oldAdv = true
} else {
advAddr = adv.RedirectAddr
} }
if !oldAdv { if !oldAdv {
@ -681,10 +672,12 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
} }
} }
c.clusterActiveAdvertisement = adv // Don't set these until everything has been parsed successfully or we'll
c.clusterActiveAdvertisementHash = entrySHA256[:] // never try again
c.clusterLeaderRedirectAddr = adv.RedirectAddr
c.clusterLeaderUUID = leaderUUID
return false, advAddr, nil return false, c.clusterLeaderRedirectAddr, nil
} }
// SecretProgress returns the number of keys provided so far // SecretProgress returns the number of keys provided so far
@ -1409,7 +1402,7 @@ func (c *Core) advertiseLeader(uuid string, leaderLostCh <-chan struct{}) error
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey) return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey)
} }
keyParams := clusterKeyParams{ keyParams := &clusterKeyParams{
Type: corePrivateKeyTypeP521, Type: corePrivateKeyTypeP521,
X: key.X, X: key.X,
Y: key.Y, Y: key.Y,

View File

@ -25,16 +25,7 @@ const (
// Starts the listeners and servers necessary to handle forwarded requests // Starts the listeners and servers necessary to handle forwarded requests
func (c *Core) startForwarding() error { func (c *Core) startForwarding() error {
// Clean up in case we have transitioned from a client to a server // Clean up in case we have transitioned from a client to a server
c.requestForwardingConnection = nil c.clearForwardingClients()
c.rpcForwardingClient = nil
if c.rpcClientConnCancelFunc != nil {
c.rpcClientConnCancelFunc()
c.rpcClientConnCancelFunc = nil
}
if c.rpcClientConn != nil {
c.rpcClientConn.Close()
c.rpcClientConn = nil
}
// Get our base handler (for our RPC server) and our wrapped handler (for // Get our base handler (for our RPC server) and our wrapped handler (for
// straight HTTP/2 forwarding) // straight HTTP/2 forwarding)
@ -191,16 +182,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
// Disabled, potentially, so clean up anything that might be around. // Disabled, potentially, so clean up anything that might be around.
if clusterAddr == "" { if clusterAddr == "" {
c.requestForwardingConnection = nil c.clearForwardingClients()
c.rpcForwardingClient = nil
if c.rpcClientConnCancelFunc != nil {
c.rpcClientConnCancelFunc()
c.rpcClientConnCancelFunc = nil
}
if c.rpcClientConn != nil {
c.rpcClientConn.Close()
c.rpcClientConn = nil
}
return nil return nil
} }
@ -244,6 +226,25 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
return nil return nil
} }
func (c *Core) clearForwardingClients() {
if c.requestForwardingConnection != nil {
c.requestForwardingConnection.transport.CloseIdleConnections()
c.requestForwardingConnection = nil
}
c.rpcForwardingClient = nil
if c.rpcClientConnCancelFunc != nil {
c.rpcClientConnCancelFunc()
c.rpcClientConnCancelFunc = nil
}
if c.rpcClientConn != nil {
c.rpcClientConn.Close()
c.rpcClientConn = nil
}
}
// ForwardRequest forwards a given request to the active node and returns the // ForwardRequest forwards a given request to the active node and returns the
// response. // response.
func (c *Core) ForwardRequest(req *http.Request) (int, []byte, error) { func (c *Core) ForwardRequest(req *http.Request) (int, []byte, error) {