mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-05 04:16:31 +02:00
Reset agent backoff on successful auth (#11033)
The existing code would retain the previous backoff value even after the system had recovered. This PR fixes that issue and improves the structure of the backoff code.
This commit is contained in:
parent
ebcdae1f34
commit
ae49dde172
@ -87,11 +87,15 @@ func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
|
||||
return ah
|
||||
}
|
||||
|
||||
func backoffOrQuit(ctx context.Context, backoff time.Duration) {
|
||||
func backoffOrQuit(ctx context.Context, backoff *agentBackoff) {
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-time.After(backoff.current):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
// Increase exponential backoff for the next time if we don't
|
||||
// successfully auth/renew/etc.
|
||||
backoff.next()
|
||||
}
|
||||
|
||||
func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
@ -99,12 +103,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
return errors.New("auth handler: nil auth method")
|
||||
}
|
||||
|
||||
backoff := initialBackoff
|
||||
maxBackoff := defaultMaxBackoff
|
||||
|
||||
if ah.maxBackoff > 0 {
|
||||
maxBackoff = ah.maxBackoff
|
||||
}
|
||||
backoff := newAgentBackoff(ah.maxBackoff)
|
||||
|
||||
ah.logger.Info("starting auth handler")
|
||||
defer func() {
|
||||
@ -145,8 +144,6 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
default:
|
||||
}
|
||||
|
||||
backoff = calculateBackoff(backoff, maxBackoff)
|
||||
|
||||
var clientToUse *api.Client
|
||||
var err error
|
||||
var path string
|
||||
@ -157,7 +154,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
case AuthMethodWithClient:
|
||||
clientToUse, err = am.(AuthMethodWithClient).AuthClient(ah.client)
|
||||
if err != nil {
|
||||
ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff.Seconds())
|
||||
ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -175,7 +172,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
|
||||
secret, err = clientToUse.Logical().Read("auth/token/lookup-self")
|
||||
if err != nil {
|
||||
ah.logger.Error("could not look up token", "err", err, "backoff", backoff.Seconds())
|
||||
ah.logger.Error("could not look up token", "err", err, "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -191,7 +188,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
|
||||
path, header, data, err = am.Authenticate(ctx, ah.client)
|
||||
if err != nil {
|
||||
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
|
||||
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -200,7 +197,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
if ah.wrapTTL > 0 {
|
||||
wrapClient, err := clientToUse.Clone()
|
||||
if err != nil {
|
||||
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds())
|
||||
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -221,7 +218,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
secret, err = clientToUse.Logical().Write(path, data)
|
||||
// Check errors/sanity
|
||||
if err != nil {
|
||||
ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds())
|
||||
ah.logger.Error("error authenticating", "error", err, "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -230,18 +227,18 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
switch {
|
||||
case ah.wrapTTL > 0:
|
||||
if secret.WrapInfo == nil {
|
||||
ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds())
|
||||
ah.logger.Error("authentication returned nil wrap info", "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
if secret.WrapInfo.Token == "" {
|
||||
ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds())
|
||||
ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo)
|
||||
if err != nil {
|
||||
ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds())
|
||||
ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -252,6 +249,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
}
|
||||
|
||||
am.CredSuccess()
|
||||
backoff.reset()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -265,12 +263,12 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
|
||||
default:
|
||||
if secret == nil || secret.Auth == nil {
|
||||
ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds())
|
||||
ah.logger.Error("authentication returned nil auth info", "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
if secret.Auth.ClientToken == "" {
|
||||
ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds())
|
||||
ah.logger.Error("authentication returned empty client token", "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -281,6 +279,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
}
|
||||
|
||||
am.CredSuccess()
|
||||
backoff.reset()
|
||||
}
|
||||
|
||||
if watcher != nil {
|
||||
@ -291,7 +290,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
Secret: secret,
|
||||
})
|
||||
if err != nil {
|
||||
ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff.Seconds())
|
||||
ah.logger.Error("error creating lifetime watcher, backing off and retrying", "error", err, "backoff", backoff)
|
||||
backoffOrQuit(ctx, backoff)
|
||||
continue
|
||||
}
|
||||
@ -326,15 +325,41 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
||||
}
|
||||
}
|
||||
|
||||
// calculateBackoff determines a new backoff duration that is roughly twice
|
||||
// the previous value, capped to a max value, with a measure of randomness.
|
||||
func calculateBackoff(previous, max time.Duration) time.Duration {
|
||||
maxBackoff := 2 * previous
|
||||
if maxBackoff > max {
|
||||
maxBackoff = max
|
||||
// agentBackoff tracks exponential backoff state.
|
||||
type agentBackoff struct {
|
||||
max time.Duration
|
||||
current time.Duration
|
||||
}
|
||||
|
||||
func newAgentBackoff(max time.Duration) *agentBackoff {
|
||||
if max <= 0 {
|
||||
max = defaultMaxBackoff
|
||||
}
|
||||
|
||||
return &agentBackoff{
|
||||
max: max,
|
||||
current: initialBackoff,
|
||||
}
|
||||
}
|
||||
|
||||
// next determines the next backoff duration that is roughly twice
|
||||
// the current value, capped to a max value, with a measure of randomness.
|
||||
func (b *agentBackoff) next() {
|
||||
maxBackoff := 2 * b.current
|
||||
|
||||
if maxBackoff > b.max {
|
||||
maxBackoff = b.max
|
||||
}
|
||||
|
||||
// Trim a random amount (0-25%) off the doubled duration
|
||||
trim := rand.Int63n(int64(maxBackoff) / 4)
|
||||
return maxBackoff - time.Duration(trim)
|
||||
b.current = maxBackoff - time.Duration(trim)
|
||||
}
|
||||
|
||||
func (b *agentBackoff) reset() {
|
||||
b.current = initialBackoff
|
||||
}
|
||||
|
||||
func (b agentBackoff) String() string {
|
||||
return b.current.Truncate(10 * time.Millisecond).String()
|
||||
}
|
||||
|
||||
@ -107,41 +107,39 @@ consumption:
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff(t *testing.T) {
|
||||
tests := []struct {
|
||||
previous time.Duration
|
||||
max time.Duration
|
||||
expMin time.Duration
|
||||
expMax time.Duration
|
||||
}{
|
||||
{
|
||||
1000 * time.Millisecond,
|
||||
60000 * time.Millisecond,
|
||||
1500 * time.Millisecond,
|
||||
2000 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
1000 * time.Millisecond,
|
||||
5000 * time.Millisecond,
|
||||
1500 * time.Millisecond,
|
||||
2000 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
4000 * time.Millisecond,
|
||||
5000 * time.Millisecond,
|
||||
3750 * time.Millisecond,
|
||||
5000 * time.Millisecond,
|
||||
},
|
||||
func TestAgentBackoff(t *testing.T) {
|
||||
max := 1024 * time.Second
|
||||
backoff := newAgentBackoff(max)
|
||||
|
||||
// Test initial value
|
||||
if backoff.current != initialBackoff {
|
||||
t.Fatalf("expected 1s initial backoff, got: %v", backoff.current)
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
for i := 0; i < 100; i++ {
|
||||
backoff := calculateBackoff(test.previous, test.max)
|
||||
// Test that backoff values are in expected range (75-100% of 2*previous)
|
||||
for i := 0; i < 9; i++ {
|
||||
old := backoff.current
|
||||
backoff.next()
|
||||
|
||||
// Verify that the new backoff is 75-100% of 2*previous, but <= than the max
|
||||
if backoff < test.expMin || backoff > test.expMax {
|
||||
t.Fatalf("expected backoff in range %v to %v, got: %v", test.expMin, test.expMax, backoff)
|
||||
}
|
||||
expMax := 2 * old
|
||||
expMin := 3 * expMax / 4
|
||||
|
||||
if backoff.current < expMin || backoff.current > expMax {
|
||||
t.Fatalf("expected backoff in range %v to %v, got: %v", expMin, expMax, backoff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that backoff is capped
|
||||
for i := 0; i < 100; i++ {
|
||||
backoff.next()
|
||||
if backoff.current > max {
|
||||
t.Fatalf("backoff exceeded max of 100s: %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test reset
|
||||
backoff.reset()
|
||||
if backoff.current != initialBackoff {
|
||||
t.Fatalf("expected 1s backoff after reset, got: %v", backoff.current)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user