vault/command/agentproxyshared/auth/auth.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

554 lines
15 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package auth
import (
"context"
"encoding/json"
"errors"
"math/rand"
"net/http"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
)
const (
// defaultMinBackoff is the retry delay used as the starting/minimum backoff
// when the configured minimum is zero or negative.
defaultMinBackoff = 1 * time.Second
// defaultMaxBackoff is the upper bound on the retry delay used when the
// configured maximum is zero or negative.
defaultMaxBackoff = 5 * time.Minute
)
// AuthMethod is the interface that auto-auth methods implement for the agent/proxy
// to use.
type AuthMethod interface {
// Authenticate returns a mount path, header, request body, and error.
// The header may be nil if no special header is needed.
Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error)
// NewCreds returns the channel AuthHandler.Run receives from to learn that
// it should re-authenticate. Run tolerates a nil channel, and when
// re-authentication on new credentials is disabled it only drains the
// channel without acting on it.
NewCreds() chan struct{}
// CredSuccess is called by AuthHandler.Run after a token (or wrapped
// token) has been obtained and sent to the output channels.
CredSuccess()
// Shutdown is called once, when AuthHandler.Run returns.
Shutdown()
}
// AuthMethodWithClient is an extended interface that can return an API client
// for use during the authentication call.
type AuthMethodWithClient interface {
AuthMethod
// AuthClient receives the handler's client and returns the client that
// AuthHandler.Run should use for the current authentication attempt. It is
// called at the start of every attempt; a non-nil error counts as an auth
// failure.
AuthClient(client *api.Client) (*api.Client, error)
}
// AuthConfig bundles a logger, mount path, wrap TTL and a free-form settings
// map.
// NOTE(review): nothing in this file reads AuthConfig; presumably it is the
// configuration handed to individual auth method constructors — confirm
// against the callers.
type AuthConfig struct {
Logger hclog.Logger
MountPath string
WrapTTL time.Duration
Config map[string]interface{}
}
// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server
type AuthHandler struct {
// OutputCh receives every newly obtained token (or the JSON-encoded wrap
// info when wrapTTL > 0). It is closed when Run returns.
OutputCh chan string
// TemplateTokenCh receives the same value as OutputCh, but only when
// enableTemplateTokenCh is set. It is closed when Run returns.
TemplateTokenCh chan string
// ExecTokenCh receives the same value as OutputCh, but only when
// enableExecTokenCh is set. It is closed when Run returns.
ExecTokenCh chan string
// token is an optional preloaded token; when non-empty, Run validates it
// with a lookup-self on its first attempt instead of authenticating.
token string
// userAgent is set as the User-Agent header on client when Run starts.
userAgent string
// metricsSignifier is the first element of every metrics key emitted by Run.
metricsSignifier string
logger hclog.Logger
client *api.Client
// random is seeded in NewAuthHandler; it is not otherwise referenced in
// this file.
random *rand.Rand
// wrapTTL, when greater than zero, makes Run request a response-wrapped
// token with this TTL and skip the renewal loop.
wrapTTL time.Duration
// maxBackoff and minBackoff bound the retry delay after a failure;
// non-positive values fall back to the package defaults.
maxBackoff time.Duration
minBackoff time.Duration
// enableReauthOnNewCredentials makes Run re-authenticate when the auth
// method signals new credentials; when false those signals are drained
// and ignored.
enableReauthOnNewCredentials bool
enableTemplateTokenCh bool
enableExecTokenCh bool
// exitOnError makes Run return on the first failure instead of backing
// off and retrying.
exitOnError bool
}
// AuthHandlerConfig carries the settings NewAuthHandler copies into a new
// AuthHandler.
type AuthHandlerConfig struct {
Logger hclog.Logger
Client *api.Client
// WrapTTL, when greater than zero, makes the handler request
// response-wrapped tokens with this TTL.
WrapTTL time.Duration
// MaxBackoff and MinBackoff bound the retry delay after a failed
// authentication; non-positive values select the package defaults.
MaxBackoff time.Duration
MinBackoff time.Duration
// Token is an optional preloaded token that is tried before the auth
// method is asked to authenticate.
Token string
// UserAgent is the HTTP UserAgent header auto-auth will use when
// communicating with Vault.
UserAgent string
// MetricsSignifier is the first argument we will give to
// metrics.IncrCounter, signifying what the name of the application is
MetricsSignifier string
// EnableReauthOnNewCredentials makes the handler re-authenticate when the
// auth method reports new credentials.
EnableReauthOnNewCredentials bool
// EnableTemplateTokenCh and EnableExecTokenCh control whether tokens are
// also sent on TemplateTokenCh and ExecTokenCh.
EnableTemplateTokenCh bool
EnableExecTokenCh bool
// ExitOnError makes the handler stop on the first failure instead of
// retrying with backoff.
ExitOnError bool
}
// NewAuthHandler builds an AuthHandler from conf. The three token channels are
// created with a buffer of one, and the handler's private random source is
// seeded from the current time. conf must be non-nil.
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
	ah := &AuthHandler{
		// This is buffered so that if we try to output after the sink server
		// has been shut down, during agent/proxy shutdown, we won't block
		OutputCh:        make(chan string, 1),
		TemplateTokenCh: make(chan string, 1),
		ExecTokenCh:     make(chan string, 1),
		token:           conf.Token,
		logger:          conf.Logger,
		client:          conf.Client,
		// Seed from the full timestamp: Nanosecond() only yields the
		// sub-second component, which limits the seed to 1e9 values.
		random:                       rand.New(rand.NewSource(time.Now().UnixNano())),
		wrapTTL:                      conf.WrapTTL,
		minBackoff:                   conf.MinBackoff,
		maxBackoff:                   conf.MaxBackoff,
		enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials,
		enableTemplateTokenCh:        conf.EnableTemplateTokenCh,
		enableExecTokenCh:            conf.EnableExecTokenCh,
		exitOnError:                  conf.ExitOnError,
		userAgent:                    conf.UserAgent,
		metricsSignifier:             conf.MetricsSignifier,
	}

	return ah
}
// backoff waits out the current backoff delay, or returns early if ctx is
// cancelled, and then advances the delay for the next failure.
//
// It returns false without waiting when the backoff is configured to exit on
// error, telling the caller to give up; otherwise it returns true and the
// caller should retry. Note that it also returns true after a cancellation;
// callers are expected to check ctx themselves before retrying.
func backoff(ctx context.Context, backoff *autoAuthBackoff) bool {
	if backoff.exitOnErr {
		return false
	}

	// An explicit timer (rather than time.After) lets us release it as soon
	// as the context is cancelled instead of keeping it alive until it fires.
	timer := time.NewTimer(backoff.current)
	defer timer.Stop()

	select {
	case <-timer.C:
	case <-ctx.Done():
	}

	// Increase exponential backoff for the next time if we don't
	// successfully auth/renew/etc.
	backoff.next()
	return true
}
// Run drives the auto-auth loop for the given auth method until ctx is
// cancelled.
//
// Each pass of the loop obtains a token, either by validating the preloaded
// token (first pass only) or by asking am to authenticate, then sends it on
// OutputCh (and on TemplateTokenCh/ExecTokenCh when enabled). It then watches
// the token's lifetime until it can no longer be renewed, am signals new
// credentials, or ctx is cancelled. When wrapTTL > 0 the JSON-encoded wrap
// info is sent instead and no renewal is attempted; the loop simply waits for
// new credentials or shutdown.
//
// On any failure the "auth failure" counter is incremented and the loop backs
// off and retries, unless exitOnError is set, in which case Run returns a
// non-nil error describing the failure. Run returns nil when ctx is cancelled,
// and an error immediately if am is nil or the backoff bounds are invalid.
// On return it calls am.Shutdown and closes all three token channels.
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
	if am == nil {
		return errors.New("auth handler: nil auth method")
	}

	if ah.minBackoff <= 0 {
		ah.minBackoff = defaultMinBackoff
	}

	backoffCfg := newAutoAuthBackoff(ah.minBackoff, ah.maxBackoff, ah.exitOnError)
	if backoffCfg.min >= backoffCfg.max {
		return errors.New("auth handler: min_backoff cannot be greater than max_backoff")
	}

	ah.logger.Info("starting auth handler")
	defer func() {
		am.Shutdown()
		close(ah.OutputCh)
		close(ah.TemplateTokenCh)
		close(ah.ExecTokenCh)
		ah.logger.Info("auth handler stopped")
	}()

	credCh := am.NewCreds()
	if !ah.enableReauthOnNewCredentials {
		// Re-auth on new credentials is disabled: keep draining the method's
		// channel so the method never blocks on send, but ignore the signals.
		realCredCh := credCh
		credCh = nil
		if realCredCh != nil {
			go func() {
				for {
					select {
					case <-ctx.Done():
						return
					case <-realCredCh:
					}
				}
			}()
		}
	}
	if credCh == nil {
		// A channel nobody sends on, so the selects below never fire on it.
		credCh = make(chan struct{})
	}

	if ah.client != nil {
		headers := ah.client.Headers()
		if headers == nil {
			headers = make(http.Header)
		}
		headers.Set("User-Agent", ah.userAgent)
		ah.client.SetHeaders(headers)
	}

	var watcher *api.LifetimeWatcher
	first := true

	for {
		select {
		case <-ctx.Done():
			return nil

		default:
		}

		var clientToUse *api.Client
		var err error
		var path string
		var data map[string]interface{}
		var header http.Header
		var isTokenFileMethod bool

		switch m := am.(type) {
		case AuthMethodWithClient:
			clientToUse, err = m.AuthClient(ah.client)
			if err != nil {
				ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}

				return err
			}
		default:
			clientToUse = ah.client
		}

		// Disable retry on the client to ensure our backoff function is
		// the only source of retry/backoff.
		clientToUse.SetMaxRetries(0)

		secret := new(api.Secret)
		if first && ah.token != "" {
			ah.logger.Debug("using preloaded token")

			first = false
			ah.logger.Debug("lookup-self with preloaded token")
			clientToUse.SetToken(ah.token)

			secret, err = clientToUse.Auth().Token().LookupSelfWithContext(ctx)
			if err == nil && secret == nil {
				// Guard against a nil secret with a nil error so the field
				// accesses below cannot dereference nil.
				err = errors.New("token lookup returned an empty response")
			}
			if err != nil {
				ah.logger.Error("could not look up token", "err", err, "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}
				return err
			}

			// Checked assertions: a missing or oddly typed field yields the
			// zero value instead of a panic. An empty client token is then
			// rejected by the sanity checks further down.
			var ttl int64
			if n, ok := secret.Data["ttl"].(json.Number); ok {
				ttl, _ = n.Int64() // an unparsable TTL is treated as zero
			}
			clientToken, _ := secret.Data["id"].(string)
			renewable, _ := secret.Data["renewable"].(bool)
			secret.Auth = &api.SecretAuth{
				ClientToken:   clientToken,
				LeaseDuration: int(ttl),
				Renewable:     renewable,
			}
		} else {
			ah.logger.Info("authenticating")

			path, header, data, err = am.Authenticate(ctx, ah.client)
			if err != nil {
				ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}
				return err
			}
		}

		if ah.wrapTTL > 0 {
			wrapClient, err := clientToUse.Clone()
			if err != nil {
				ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}
				return err
			}
			wrapClient.SetWrappingLookupFunc(func(string, string) string {
				return ah.wrapTTL.String()
			})
			clientToUse = wrapClient
		}
		for key, values := range header {
			for _, value := range values {
				clientToUse.AddHeader(key, value)
			}
		}

		// This should only happen if there's no preloaded token (regular auto-auth login)
		// or if a preloaded token has expired and is now switching to auto-auth.
		if secret.Auth == nil {
			isTokenFileMethod = path == "auth/token/lookup-self"
			if isTokenFileMethod {
				token, _ := data["token"].(string)

				// Assign to the loop-level err (no :=) so that a failed
				// lookup below is seen by the error check after this branch.
				var lookupSelfClient *api.Client
				lookupSelfClient, err = clientToUse.Clone()
				if err != nil {
					ah.logger.Error("failed to clone client to perform token lookup", "error", err)
					return err
				}
				lookupSelfClient.SetToken(token)
				secret, err = lookupSelfClient.Auth().Token().LookupSelfWithContext(ctx)
			} else {
				secret, err = clientToUse.Logical().WriteWithContext(ctx, path, data)
			}

			// Check errors/sanity
			if err != nil {
				ah.logger.Error("error authenticating", "error", err, "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}
				return err
			}
		}

		var leaseDuration int

		switch {
		case ah.wrapTTL > 0:
			// secret may be nil when the write succeeded with an empty body.
			if secret == nil || secret.WrapInfo == nil {
				ah.logger.Error("authentication returned nil wrap info", "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}
				return errors.New("authentication returned nil wrap info")
			}
			if secret.WrapInfo.Token == "" {
				ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}
				return errors.New("authentication returned empty wrapped client token")
			}
			wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
			if err != nil {
				ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoffCfg)
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

				if backoff(ctx, backoffCfg) {
					continue
				}
				return err
			}
			ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
			ah.OutputCh <- string(wrappedResp)
			if ah.enableTemplateTokenCh {
				ah.TemplateTokenCh <- string(wrappedResp)
			}
			if ah.enableExecTokenCh {
				ah.ExecTokenCh <- string(wrappedResp)
			}

			am.CredSuccess()
			backoffCfg.reset()

			// A wrapped token is not renewed here; wait for shutdown or for
			// the method to report new credentials.
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered")
				continue

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				continue
			}

		default:
			// We handle the token_file method specially, as it's the only
			// auth method that isn't actually authenticating, i.e. the secret
			// returned does not have an Auth struct attached
			if isTokenFileMethod {
				// We still check the response of the request to ensure the token is valid
				// i.e. if the token is invalid, we will fail in the authentication step
				if secret == nil || secret.Data == nil {
					ah.logger.Error("token file validation failed, token may be invalid", "backoff", backoffCfg)
					metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

					if backoff(ctx, backoffCfg) {
						continue
					}
					return errors.New("token file validation failed, token may be invalid")
				}
				token, ok := secret.Data["id"].(string)
				if !ok || token == "" {
					ah.logger.Error("token file validation returned empty client token", "backoff", backoffCfg)
					metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

					if backoff(ctx, backoffCfg) {
						continue
					}
					return errors.New("token file validation returned empty client token")
				}

				// Checked assertion: a missing or oddly typed TTL is treated
				// as zero rather than panicking.
				var ttl int64
				if n, ok := secret.Data["ttl"].(json.Number); ok {
					ttl, _ = n.Int64() // an unparsable TTL is treated as zero
				}
				leaseDuration = int(ttl)
				renewable, _ := secret.Data["renewable"].(bool)
				secret.Auth = &api.SecretAuth{
					ClientToken:   token,
					LeaseDuration: leaseDuration,
					Renewable:     renewable,
				}
				ah.logger.Info("authentication successful, sending token to sinks")
				ah.OutputCh <- token
				if ah.enableTemplateTokenCh {
					ah.TemplateTokenCh <- token
				}
				if ah.enableExecTokenCh {
					ah.ExecTokenCh <- token
				}

				tokenType, _ := secret.Data["type"].(string)
				if tokenType == "batch" {
					ah.logger.Info("note that this token type is batch, and batch tokens cannot be renewed", "ttl", leaseDuration)
				}
			} else {
				if secret == nil || secret.Auth == nil {
					ah.logger.Error("authentication returned nil auth info", "backoff", backoffCfg)
					metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

					if backoff(ctx, backoffCfg) {
						continue
					}
					return errors.New("authentication returned nil auth info")
				}
				if secret.Auth.ClientToken == "" {
					ah.logger.Error("authentication returned empty client token", "backoff", backoffCfg)
					metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

					if backoff(ctx, backoffCfg) {
						continue
					}
					return errors.New("authentication returned empty client token")
				}

				leaseDuration = secret.LeaseDuration
				ah.logger.Info("authentication successful, sending token to sinks")
				ah.OutputCh <- secret.Auth.ClientToken
				if ah.enableTemplateTokenCh {
					ah.TemplateTokenCh <- secret.Auth.ClientToken
				}
				if ah.enableExecTokenCh {
					ah.ExecTokenCh <- secret.Auth.ClientToken
				}
			}

			am.CredSuccess()
			backoffCfg.reset()
		}

		// Replace the watcher left over from the previous successful pass.
		if watcher != nil {
			watcher.Stop()
		}

		watcher, err = clientToUse.NewLifetimeWatcher(&api.LifetimeWatcherInput{
			Secret: secret,
		})
		if err != nil {
			ah.logger.Error("error creating lifetime watcher", "error", err, "backoff", backoffCfg)
			metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)

			if backoff(ctx, backoffCfg) {
				continue
			}
			return err
		}

		metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "success"}, 1)
		// We don't want to trigger the renewal process for tokens with
		// unlimited TTL, such as the root token.
		if leaseDuration == 0 && isTokenFileMethod {
			ah.logger.Info("not starting token renewal process, as token has unlimited TTL")
		} else {
			ah.logger.Info("starting renewal process")
			go watcher.Renew()
		}

	LifetimeWatcherLoop:
		for {
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered, stopping lifetime watcher")
				watcher.Stop()
				break LifetimeWatcherLoop

			case err := <-watcher.DoneCh():
				ah.logger.Info("lifetime watcher done channel triggered")
				if err != nil {
					metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
					ah.logger.Error("error renewing token", "error", err)
				}
				break LifetimeWatcherLoop

			case <-watcher.RenewCh():
				metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "success"}, 1)
				ah.logger.Info("renewed auth token")

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				break LifetimeWatcherLoop
			}
		}
	}
}
// autoAuthBackoff tracks exponential backoff state.
type autoAuthBackoff struct {
// min is the delay that reset() restores and the starting value of current.
min time.Duration
// max is the cap applied to the doubled delay in next().
max time.Duration
// current is the delay the next backoff wait will use.
current time.Duration
// exitOnErr makes the backoff helper report "do not retry" without waiting.
exitOnErr bool
}
// newAutoAuthBackoff returns backoff state that starts at min and is capped at
// max. A non-positive min or max is replaced by defaultMinBackoff or
// defaultMaxBackoff respectively. exitErr is stored as the exit-on-error flag.
func newAutoAuthBackoff(min, max time.Duration, exitErr bool) *autoAuthBackoff {
	state := &autoAuthBackoff{
		min:       min,
		max:       max,
		exitOnErr: exitErr,
	}

	if state.max <= 0 {
		state.max = defaultMaxBackoff
	}
	if state.min <= 0 {
		state.min = defaultMinBackoff
	}

	// The first wait always uses the (possibly defaulted) minimum.
	state.current = state.min
	return state
}
// next determines the next backoff duration that is roughly twice
// the current value, capped to a max value, with a measure of randomness.
//
// The doubled value is capped at b.max, and a random amount of up to (but not
// including) 25% of it is then trimmed off. Delays too small to have a
// non-zero quarter (under 4ns) are left untrimmed, because rand.Int63n panics
// on a non-positive argument.
func (b *autoAuthBackoff) next() {
	maxBackoff := 2 * b.current
	// The second condition catches overflow of the doubling, which would
	// otherwise produce a negative delay.
	if maxBackoff > b.max || maxBackoff < b.current {
		maxBackoff = b.max
	}

	quarter := int64(maxBackoff) / 4
	if quarter <= 0 {
		b.current = maxBackoff
		return
	}

	// Trim a random amount (0-25%) off the doubled duration
	trim := rand.Int63n(quarter)
	b.current = maxBackoff - time.Duration(trim)
}
// reset drops the current delay back to the configured minimum, so the next
// failure starts the exponential sequence over.
func (b *autoAuthBackoff) reset() {
	b.current = b.min
}
// String formats the current delay for log output, truncated (not rounded) to
// a multiple of 10 milliseconds.
func (b autoAuthBackoff) String() string {
	const precision = 10 * time.Millisecond

	shown := b.current.Truncate(precision)
	return shown.String()
}