vault/command/agent/auth/auth.go
Calvin Leung Huang d54164f9e2
agent: return a non-zero exit code on error (#9670)
* agent: return a non-zero exit code on error

* agent/template: always return on template server error, add case for error_on_missing_key

* agent: fix tests by updating Run params to use an errCh

* agent/template: add permission denied test case, clean up test var

* agent: use unbuffered errCh, emit fatal errors directly to the UI output

* agent: use oklog's run.Group to schedule subsystem runners (#9761)

* agent: use oklog's run.Group to schedule subsystem runners

* agent: clean up unused DoneCh, clean up agent's main Run func

* agent/template: use ts.stopped.CAS to atomically swap value

* fix tests

* fix tests

* agent/template: add timeout on TestRunServer

* agent: output error via logs and return a generic error on non-zero exit

* fix TestAgent_ExitAfterAuth

* agent/template: do not restart ct runner on new incoming token if exit_after_auth is set to true

* agent: drain ah.OutputCh after sink exits to avoid blocking on the channel

* use context.WithTimeout, expand comments around ordering of defer cancel()
2020-09-29 18:03:09 -07:00

278 lines
7.2 KiB
Go

package auth
import (
"context"
"errors"
"math/rand"
"net/http"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
)
// AuthMethod is the interface that auto-auth methods implement for the agent
// to use.
type AuthMethod interface {
// Authenticate returns a mount path, header, request body, and error.
// The header may be nil if no special header is needed.
// AuthHandler.Run writes the request body to the returned path and adds
// every returned header value to the client it logs in with.
Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error)
// NewCreds returns a channel that AuthHandler.Run fetches once at startup.
// A receive on it is treated as "new credentials available" and triggers
// re-authentication; when re-authentication on new credentials is disabled
// the channel is only drained. A nil channel is tolerated by Run.
NewCreds() chan struct{}
// CredSuccess is called by AuthHandler.Run after a login succeeded and the
// resulting token (or wrapped response) has been sent to the output
// channel(s).
CredSuccess()
// Shutdown is called from a deferred function when AuthHandler.Run returns.
Shutdown()
}
// AuthMethodWithClient is an extended interface that can return an API client
// for use during the authentication call.
type AuthMethodWithClient interface {
AuthMethod
// AuthClient receives the handler's own client and returns the client that
// AuthHandler.Run should use for the login request. If it returns an error,
// Run logs it, backs off and retries the whole authentication attempt.
AuthClient(client *api.Client) (*api.Client, error)
}
// AuthConfig bundles a logger, a mount path, a wrap TTL and a free-form
// configuration map. It is not referenced anywhere else in this file.
// NOTE(review): presumably handed to AuthMethod implementations when they are
// constructed — confirm against the callers.
type AuthConfig struct {
Logger hclog.Logger
MountPath string
WrapTTL time.Duration
// Config holds untyped settings; its keys are not defined in this file.
Config map[string]interface{}
}
// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server
type AuthHandler struct {
// OutputCh receives the client token (or the JSON-encoded wrap info when a
// wrap TTL is set) after every successful authentication. It is created
// with a buffer of one and closed when Run returns.
OutputCh chan string
// TemplateTokenCh receives the same value as OutputCh, but only when
// enableTemplateTokenCh is true. It is closed when Run returns.
TemplateTokenCh chan string
logger hclog.Logger
// client is passed to AuthMethod.Authenticate and is used for the login
// request unless the method implements AuthMethodWithClient.
client *api.Client
// random supplies the jitter for the retry backoff computed in Run.
random *rand.Rand
// wrapTTL, when greater than zero, makes Run request a wrapped response
// and emit the wrap info instead of a plain client token.
wrapTTL time.Duration
// enableReauthOnNewCredentials controls whether signals from
// AuthMethod.NewCreds trigger re-authentication (true) or are merely
// drained and ignored (false).
enableReauthOnNewCredentials bool
// enableTemplateTokenCh controls whether tokens are also sent on
// TemplateTokenCh.
enableTemplateTokenCh bool
}
// AuthHandlerConfig carries the settings that NewAuthHandler copies into a new
// AuthHandler.
type AuthHandlerConfig struct {
// Logger becomes the handler's logger.
Logger hclog.Logger
// Client becomes the handler's API client.
Client *api.Client
// WrapTTL becomes the handler's wrap TTL; Run requests response wrapping
// when it is greater than zero.
WrapTTL time.Duration
// EnableReauthOnNewCredentials makes Run re-authenticate when the auth
// method signals new credentials.
EnableReauthOnNewCredentials bool
// EnableTemplateTokenCh makes Run also publish tokens on TemplateTokenCh.
EnableTemplateTokenCh bool
}
// NewAuthHandler builds an AuthHandler from conf. The logger, client, wrap TTL
// and feature flags are copied from conf; the output channels and the random
// source used for backoff jitter are created here.
func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
	// The output channel carries a one-slot buffer so that an output attempted
	// after the sink server has been shut down, during agent shutdown, does
	// not block. The template token channel is sized the same way.
	outputCh := make(chan string, 1)
	templateTokenCh := make(chan string, 1)

	// The jitter source is seeded from the sub-second part of the current time.
	seed := int64(time.Now().Nanosecond())

	return &AuthHandler{
		OutputCh:                     outputCh,
		TemplateTokenCh:              templateTokenCh,
		logger:                       conf.Logger,
		client:                       conf.Client,
		random:                       rand.New(rand.NewSource(seed)),
		wrapTTL:                      conf.WrapTTL,
		enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials,
		enableTemplateTokenCh:        conf.EnableTemplateTokenCh,
	}
}
// backoffOrQuit blocks until either backoff has elapsed or ctx is done,
// whichever happens first. It returns nothing; callers re-check ctx themselves.
func backoffOrQuit(ctx context.Context, backoff time.Duration) {
	// An explicit timer (rather than time.After) lets us release it right away
	// when the context wins the race instead of leaving it pending until it
	// fires.
	timer := time.NewTimer(backoff)
	defer timer.Stop()

	select {
	case <-timer.C:
	case <-ctx.Done():
	}
}
// Run drives the auto-auth loop for am until ctx is cancelled.
//
// Each iteration asks am for login parameters, performs the login write and,
// on success, publishes the resulting token on OutputCh (and on
// TemplateTokenCh when enabled) before calling am.CredSuccess. Any failure is
// logged and retried after a jittered backoff of between one and three
// seconds.
//
// When a wrap TTL is configured the JSON-encoded wrap info is published
// instead of a token and the handler then pauses until ctx is done or the
// auth method signals new credentials. Otherwise a lifetime watcher renews
// the token until it reports done, new credentials arrive, or ctx is done.
//
// On return the auth method is shut down and both output channels are closed.
// Run returns an error only when am is nil; cancellation yields nil.
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
	if am == nil {
		return errors.New("auth handler: nil auth method")
	}

	ah.logger.Info("starting auth handler")

	// Declared before the deferred cleanup so that the cleanup can stop
	// whichever watcher is current when Run exits.
	var watcher *api.LifetimeWatcher

	defer func() {
		// A watcher left running by a "new credentials" break would otherwise
		// keep renewing in the background after Run has returned.
		if watcher != nil {
			watcher.Stop()
		}
		am.Shutdown()
		close(ah.OutputCh)
		close(ah.TemplateTokenCh)
		ah.logger.Info("auth handler stopped")
	}()

	credCh := am.NewCreds()
	if !ah.enableReauthOnNewCredentials {
		// Re-auth on new credentials is disabled: keep draining the method's
		// channel so that receives always succeed, but never act on it.
		realCredCh := credCh
		credCh = nil
		if realCredCh != nil {
			go func() {
				for {
					select {
					case <-ctx.Done():
						return
					case <-realCredCh:
					}
				}
			}()
		}
	}
	if credCh == nil {
		// A channel that never fires keeps the select statements below uniform.
		credCh = make(chan struct{})
	}

	for {
		select {
		case <-ctx.Done():
			return nil

		default:
		}

		// Create a fresh backoff value
		backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))

		ah.logger.Info("authenticating")
		path, header, data, err := am.Authenticate(ctx, ah.client)
		if err != nil {
			ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		// Let the auth method substitute its own client for the login call.
		clientToUse := ah.client
		if amc, ok := am.(AuthMethodWithClient); ok {
			clientToUse, err = amc.AuthClient(ah.client)
			if err != nil {
				ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
		}

		if ah.wrapTTL > 0 {
			// Wrapping is configured on a clone so the base client is not
			// affected by the wrapping lookup function.
			wrapClient, err := clientToUse.Clone()
			if err != nil {
				ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrapClient.SetWrappingLookupFunc(func(string, string) string {
				return ah.wrapTTL.String()
			})
			clientToUse = wrapClient
		}

		for key, values := range header {
			for _, value := range values {
				clientToUse.AddHeader(key, value)
			}
		}

		secret, err := clientToUse.Logical().Write(path, data)
		// Check errors/sanity
		if err != nil {
			ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		switch {
		case ah.wrapTTL > 0:
			// The write can succeed with a nil secret (empty response body),
			// so guard before touching WrapInfo.
			if secret == nil || secret.WrapInfo == nil {
				ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.WrapInfo.Token == "" {
				ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
			if err != nil {
				ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
			ah.OutputCh <- string(wrappedResp)
			if ah.enableTemplateTokenCh {
				ah.TemplateTokenCh <- string(wrappedResp)
			}

			am.CredSuccess()

			// No renewal happens in wrap mode: wait for shutdown or for new
			// credentials, then go back to the top of the loop.
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered")
				continue

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				continue
			}

		default:
			if secret == nil || secret.Auth == nil {
				ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			if secret.Auth.ClientToken == "" {
				ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds())
				backoffOrQuit(ctx, backoff)
				continue
			}
			ah.logger.Info("authentication successful, sending token to sinks")
			ah.OutputCh <- secret.Auth.ClientToken
			if ah.enableTemplateTokenCh {
				ah.TemplateTokenCh <- secret.Auth.ClientToken
			}

			am.CredSuccess()
		}

		// Replace any watcher from a previous login with one for the new secret.
		if watcher != nil {
			watcher.Stop()
		}

		watcher, err = clientToUse.NewLifetimeWatcher(&api.LifetimeWatcherInput{
			Secret: secret,
		})
		if err != nil {
			ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff.Seconds())
			backoffOrQuit(ctx, backoff)
			continue
		}

		// Start the renewal process
		ah.logger.Info("starting renewal process")
		go watcher.Renew()

	LifetimeWatcherLoop:
		for {
			select {
			case <-ctx.Done():
				ah.logger.Info("shutdown triggered, stopping lifetime watcher")
				watcher.Stop()
				break LifetimeWatcherLoop

			case err := <-watcher.DoneCh():
				ah.logger.Info("lifetime watcher done channel triggered")
				if err != nil {
					ah.logger.Error("error renewing token", "error", err)
				}
				break LifetimeWatcherLoop

			case <-watcher.RenewCh():
				ah.logger.Info("renewed auth token")

			case <-credCh:
				ah.logger.Info("auth method found new credentials, re-authenticating")
				break LifetimeWatcherLoop
			}
		}
	}
}