From a5563e4aec6b1bb448a53abaae4f3f4afcdb20ea Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 25 May 2018 14:38:06 -0400 Subject: [PATCH] Redo API client locking (#4551) * Redo API client locking This assigns local values when in critical paths, allowing a single API client to much more quickly and safely pipeline requests. Additionally, in order to take that paradigm all the way it changes how timeouts are set. It now uses a context value set on the request instead of configuring the timeout in the http client per request, which was also potentially quite racy. Trivially tested with VAULT_CLIENT_TIMEOUT=2 vault write pki/root/generate/internal key_type=rsa key_bits=8192 --- api/client.go | 83 +++++++++++++++++++++++++++++----------------- api/client_test.go | 15 +-------- 2 files changed, 53 insertions(+), 45 deletions(-) diff --git a/api/client.go b/api/client.go index 8f5a298682..ce10fff141 100644 --- a/api/client.go +++ b/api/client.go @@ -388,11 +388,12 @@ func (c *Client) SetAddress(addr string) error { c.modifyLock.Lock() defer c.modifyLock.Unlock() - var err error - if c.addr, err = url.Parse(addr); err != nil { + parsedAddr, err := url.Parse(addr) + if err != nil { return errwrap.Wrapf("failed to set address: {{err}}", err) } + c.addr = parsedAddr return nil } @@ -411,7 +412,8 @@ func (c *Client) SetLimiter(rateLimit float64, burst int) { c.modifyLock.RLock() c.config.modifyLock.Lock() defer c.config.modifyLock.Unlock() - defer c.modifyLock.RUnlock() + c.modifyLock.RUnlock() + c.config.Limiter = rate.NewLimiter(rate.Limit(rateLimit), burst) } @@ -544,14 +546,20 @@ func (c *Client) SetPolicyOverride(override bool) { // doesn't need to be called externally. func (c *Client) NewRequest(method, requestPath string) *Request { c.modifyLock.RLock() - defer c.modifyLock.RUnlock() + addr := c.addr + token := c.token + mfaCreds := c.mfaCreds + wrappingLookupFunc := c.wrappingLookupFunc + headers := c.headers + policyOverride := c.policyOverride + c.modifyLock.RUnlock() // if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV // record and take the highest match; this is not designed for high-availability, just discovery - var host string = c.addr.Host - if c.addr.Port() == "" { + var host string = addr.Host + if addr.Port() == "" { // Internet Draft specifies that the SRV record is ignored if a port is given - _, addrs, err := net.LookupSRV("http", "tcp", c.addr.Hostname()) + _, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname()) if err == nil && len(addrs) > 0 { host = fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port) } @@ -560,12 +568,12 @@ func (c *Client) NewRequest(method, requestPath string) *Request { req := &Request{ Method: method, URL: &url.URL{ - User: c.addr.User, - Scheme: c.addr.Scheme, + User: addr.User, + Scheme: addr.Scheme, Host: host, - Path: path.Join(c.addr.Path, requestPath), + Path: path.Join(addr.Path, requestPath), }, - ClientToken: c.token, + ClientToken: token, Params: make(map[string][]string), } @@ -579,21 +587,19 @@ func (c *Client) NewRequest(method, requestPath string) *Request { lookupPath = requestPath } - req.MFAHeaderVals = c.mfaCreds + req.MFAHeaderVals = mfaCreds - if c.wrappingLookupFunc != nil { - req.WrapTTL = c.wrappingLookupFunc(method, lookupPath) + if wrappingLookupFunc != nil { + req.WrapTTL = wrappingLookupFunc(method, lookupPath) } else { req.WrapTTL = DefaultWrappingLookupFunc(method, lookupPath) } - if c.config.Timeout != 0 { - c.config.HttpClient.Timeout = c.config.Timeout - } - if c.headers != nil { - req.Headers = c.headers + + if headers != nil { + req.Headers = headers } - req.PolicyOverride = c.policyOverride + req.PolicyOverride = policyOverride return req } @@ -602,18 +608,23 @@ func (c *Client) NewRequest(method, requestPath string) *Request { // a Vault server not configured with this client. This is an advanced operation // that generally won't need to be called externally. func (c *Client) RawRequest(r *Request) (*Response, error) { - c.modifyLock.RLock() - c.config.modifyLock.RLock() - defer c.config.modifyLock.RUnlock() - - if c.config.Limiter != nil { - c.config.Limiter.Wait(context.Background()) - } - token := c.token + + c.config.modifyLock.RLock() + limiter := c.config.Limiter + maxRetries := c.config.MaxRetries + backoff := c.config.Backoff + httpClient := c.config.HttpClient + timeout := c.config.Timeout + c.config.modifyLock.RUnlock() + c.modifyLock.RUnlock() + if limiter != nil { + limiter.Wait(context.Background()) + } + // Sanity check the token before potentially erroring from the API idx := strings.IndexFunc(token, func(c rune) bool { return !unicode.IsPrint(c) @@ -632,16 +643,23 @@ START: return nil, fmt.Errorf("nil request created") } - backoff := c.config.Backoff + // Set the timeout, if any + var cancelFunc context.CancelFunc + if timeout != 0 { + var ctx context.Context + ctx, cancelFunc = context.WithTimeout(context.Background(), timeout) + req.Request = req.Request.WithContext(ctx) + } + if backoff == nil { backoff = retryablehttp.LinearJitterBackoff } client := &retryablehttp.Client{ - HTTPClient: c.config.HttpClient, + HTTPClient: httpClient, RetryWaitMin: 1000 * time.Millisecond, RetryWaitMax: 1500 * time.Millisecond, - RetryMax: c.config.MaxRetries, + RetryMax: maxRetries, CheckRetry: retryablehttp.DefaultRetryPolicy, Backoff: backoff, ErrorHandler: retryablehttp.PassthroughErrorHandler, @@ -649,6 +667,9 @@ START: var result *Response resp, err := client.Do(req) + if cancelFunc != nil { + cancelFunc() + } if resp != nil { result = &Response{Response: resp} } diff --git a/api/client_test.go b/api/client_test.go index 970354bab1..5678478ea0 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -7,7 +7,6 @@ import ( "os" "strings" "testing" - "time" ) func init() { @@ -244,22 +243,10 @@ func TestClientTimeoutSetting(t *testing.T) { defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout) config := DefaultConfig() config.ReadEnvironment() - client, err := NewClient(config) + _, err := NewClient(config) if err != nil { t.Fatal(err) } - _ = client.NewRequest("PUT", "/") - if client.config.HttpClient.Timeout != time.Second*10 { - t.Fatalf("error setting client timeout using env variable") - } - - // Setting custom client timeout for a new request - client.SetClientTimeout(time.Second * 20) - _ = client.NewRequest("PUT", "/") - if client.config.HttpClient.Timeout != time.Second*20 { - t.Fatalf("error setting client timeout using SetClientTimeout") - } - } type roundTripperFunc func(*http.Request) (*http.Response, error)