vault/command/agentproxyshared/cache/api_proxy.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

175 lines
5.7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cache
import (
"context"
"fmt"
gohttp "net/http"
"sync"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
)
// EnforceConsistency selects whether APIProxy tracks the index state that
// Vault returns and requires it on subsequent forwarded requests.
type EnforceConsistency int

const (
	// EnforceConsistencyNever is the zero value. Send performs no index
	// state tracking in this mode.
	EnforceConsistencyNever EnforceConsistency = iota

	// EnforceConsistencyAlways makes Send record the state returned by Vault
	// and require previously seen states on later requests, unless the
	// incoming request already carries its own index, forward, or
	// inconsistent header.
	EnforceConsistencyAlways
)
// WhenInconsistentAction selects how Send reacts when consistency is being
// enforced and previously recorded index states are required on a request.
type WhenInconsistentAction int

const (
	// WhenInconsistentFail delegates handling of inconsistency failures to
	// the external client: the forwarding client is switched to
	// retryablehttp.DefaultRetryPolicy.
	WhenInconsistentFail WhenInconsistentAction = iota

	// WhenInconsistentRetry leaves the api.Client retry behaviour untouched
	// so that retries due to inconsistency are handled internally.
	WhenInconsistentRetry

	// WhenInconsistentForward sets the Vault inconsistent header to its
	// "forward" value on the forwarded request.
	WhenInconsistentForward
)
// APIProxy is the Proxier implementation that forwards a request to Vault
// and hands back the response.
type APIProxy struct {
	// client is the template API client; Send clones it for every request.
	client *api.Client
	// logger logs forwarded requests and is the parent of the per-request
	// client logger.
	logger hclog.Logger

	// enforceConsistency and whenInconsistentAction control index state
	// tracking in Send.
	enforceConsistency     EnforceConsistency
	whenInconsistentAction WhenInconsistentAction

	// l guards lastIndexStates.
	l sync.RWMutex
	// lastIndexStates holds the states recorded from earlier responses,
	// merged with api.MergeReplicationStates.
	lastIndexStates []string

	// userAgentString is sent when the incoming request has no User-Agent.
	userAgentString string
	// userAgentStringFunction rewrites the incoming request's User-Agent
	// when it has one.
	userAgentStringFunction func(string) string
}
// Compile-time check that *APIProxy satisfies Proxier.
var _ Proxier = (*APIProxy)(nil)
// APIProxyConfig carries the settings NewAPIProxy copies into an APIProxy.
type APIProxyConfig struct {
	// Client is the API client used to reach Vault. NewAPIProxy rejects a
	// nil Client.
	Client *api.Client
	// Logger receives the proxy's log output.
	Logger hclog.Logger

	// EnforceConsistency and WhenInconsistentAction configure index state
	// tracking for forwarded requests.
	EnforceConsistency     EnforceConsistency
	WhenInconsistentAction WhenInconsistentAction

	// UserAgentString is the User-Agent to send when the proxied client did
	// not supply one.
	UserAgentString string

	// UserAgentStringFunction turns the proxied client's User-Agent into one
	// that also carries Vault-specific information.
	UserAgentStringFunction func(string) string
}
// NewAPIProxy builds a Proxier that forwards requests to Vault through
// config.Client.
//
// It returns an error when config or config.Client is nil. A nil
// config.Logger is replaced with a no-op logger, because Send calls methods
// on the logger unconditionally and would otherwise panic.
func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {
	if config == nil {
		return nil, fmt.Errorf("nil API proxy config")
	}
	if config.Client == nil {
		return nil, fmt.Errorf("nil API client")
	}

	logger := config.Logger
	if logger == nil {
		logger = hclog.NewNullLogger()
	}

	return &APIProxy{
		client:                  config.Client,
		logger:                  logger,
		enforceConsistency:      config.EnforceConsistency,
		whenInconsistentAction:  config.WhenInconsistentAction,
		userAgentString:         config.UserAgentString,
		userAgentStringFunction: config.UserAgentStringFunction,
	}, nil
}
// Send forwards req to Vault on a per-request clone of the proxy's API
// client and wraps the result in a SendResponse.
//
// The incoming request's headers are modified in place: Accept-Encoding is
// removed and User-Agent is set or rewritten. When consistency enforcement is
// enabled and the request carries none of the Vault index, forward, or
// inconsistent headers, the state returned by Vault is recorded and the
// previously recorded states are required on the forwarded request.
//
// It returns (nil, err) if the client cannot be cloned, if Vault returned no
// response at all, or if the SendResponse cannot be built. Otherwise it
// returns the SendResponse together with any error from the request, so the
// handler layer can inspect both.
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
	client, err := ap.client.Clone()
	if err != nil {
		return nil, err
	}
	client.SetToken(req.Token)

	// Derive and set a logger for the client
	clientLogger := ap.logger.Named("client")
	client.SetLogger(clientLogger)

	// Make sure there is a header map before reading from or writing to it.
	if req.Request.Header == nil {
		req.Request.Header = make(gohttp.Header)
	}

	// http.Transport will transparently request gzip and decompress the response, but only if
	// the client doesn't manually set the header. Removing any Accept-Encoding header allows the
	// transparent compression to occur.
	req.Request.Header.Del("Accept-Encoding")

	// Set our User-Agent to be one indicating we are Vault Agent's API proxy.
	// If the sending client had one, preserve it.
	if initialUserAgent := req.Request.Header.Get("User-Agent"); initialUserAgent == "" {
		req.Request.Header.Set("User-Agent", ap.userAgentString)
	} else if ap.userAgentStringFunction != nil {
		// Without a transform function the client's User-Agent is forwarded
		// unchanged rather than calling a nil function.
		req.Request.Header.Set("User-Agent", ap.userAgentStringFunction(initialUserAgent))
	}

	client.SetHeaders(req.Request.Header)

	fwReq := client.NewRequest(req.Request.Method, req.Request.URL.Path)
	fwReq.BodyBytes = req.RequestBody

	query := req.Request.URL.Query()
	if len(query) != 0 {
		fwReq.Params = query
	}

	var newState string
	manageState := ap.enforceConsistency == EnforceConsistencyAlways &&
		req.Request.Header.Get(http.VaultIndexHeaderName) == "" &&
		req.Request.Header.Get(http.VaultForwardHeaderName) == "" &&
		req.Request.Header.Get(http.VaultInconsistentHeaderName) == ""

	if manageState {
		client = client.WithResponseCallbacks(api.RecordState(&newState))

		ap.l.RLock()
		lastStates := ap.lastIndexStates
		ap.l.RUnlock()

		if len(lastStates) != 0 {
			client = client.WithRequestCallbacks(api.RequireState(lastStates...))

			switch ap.whenInconsistentAction {
			case WhenInconsistentFail:
				// In this mode we want to delegate handling of inconsistency
				// failures to the external client talking to Agent.
				client.SetCheckRetry(retryablehttp.DefaultRetryPolicy)
			case WhenInconsistentRetry:
				// In this mode we want to handle retries due to inconsistency
				// internally. This is the default api.Client behaviour so
				// we needn't do anything.
			case WhenInconsistentForward:
				fwReq.Headers.Set(http.VaultInconsistentHeaderName, http.VaultInconsistentForward)
			}
		}
	}

	// Make the request to Vault and get the response
	ap.logger.Info("forwarding request to Vault", "method", req.Request.Method, "path", req.Request.URL.Path)

	resp, err := client.RawRequestWithContext(ctx, fwReq)
	if resp == nil && err != nil {
		// We don't want to cache nil responses, so we simply return the error
		return nil, err
	}

	if newState != "" {
		ap.l.Lock()
		// We want to be using the "newest" states seen, but newer isn't well
		// defined here. There can be two states S1 and S2 which aren't strictly ordered:
		// S1 could have a newer localindex and S2 could have a newer replicatedindex. So
		// we need to merge them. But we can't merge them because we wouldn't be able to
		// "sign" the resulting header because we don't have access to the HMAC key that
		// Vault uses to do so. So instead we compare any of the 0-2 saved states
		// we have to the new header, keeping the newest 1-2 of these, and sending
		// them to Vault to evaluate.
		ap.lastIndexStates = api.MergeReplicationStates(ap.lastIndexStates, newState)
		ap.l.Unlock()
	}

	// Before error checking from the request call, we'd want to initialize a SendResponse to
	// potentially return
	sendResponse, newErr := NewSendResponse(resp, nil)
	if newErr != nil {
		return nil, newErr
	}

	// Bubble back the api.Response as well for error checking/handling at the handler layer.
	return sendResponse, err
}