Thomas L. Kula f31092e7ad
Allow auto_auth with templates without specifying a sink (#8812)
For situations where you want the Vault agent to handle one or more templates but do not require the acquired credentials elsewhere.

Modify the logic in SinkServer so that, if there are no sinks, any new credentials are ignored. Since SinkServer is responsible for shutting down the agent, make sure it still shuts down properly in this new situation.

Solves #7988
2020-05-26 13:52:14 -04:00

257 lines
6.4 KiB
Go

package sink
import (
"context"
"errors"
"io/ioutil"
"math/rand"
"os"
"sync/atomic"
"time"
"github.com/hashicorp/errwrap"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/dhutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
)
// Sink is a destination that the SinkServer delivers tokens to.
type Sink interface {
	// WriteToken writes the given token (which may already have been
	// response-wrapped and/or encrypted by the server) to the sink. A non-nil
	// error causes SinkServer.Run to retry the write after a backoff.
	WriteToken(string) error
}

// SinkReader exposes a Token accessor. Nothing in this file calls it;
// NOTE(review): presumably implemented by sinks that can report the token
// they hold — confirm against implementations.
type SinkReader interface {
	Token() string
}
// SinkConfig pairs a Sink with the options controlling how a token is
// transformed before SinkServer.Run writes it to that sink.
type SinkConfig struct {
	// Sink is the underlying destination; its WriteToken is called last.
	Sink
	Logger hclog.Logger
	// Config holds sink-specific settings; it is not read in this file.
	Config map[string]interface{}
	// Client is not read in this file; Run wraps tokens using the
	// server's client instead.
	Client *api.Client
	// WrapTTL, when non-zero, response-wraps the token with this TTL
	// before it is (optionally) encrypted and written.
	WrapTTL time.Duration
	// DHType, when non-empty, causes the token to be encrypted via
	// encryptToken; only "curve25519" is handled there.
	DHType string
	// DHPath is the file from which the remote party's JSON-encoded
	// public key is read on first use.
	DHPath string
	// AAD is passed as additional authenticated data to the AES encryption.
	AAD string
	// Lazily populated by encryptToken and reused on later calls.
	cachedRemotePubKey []byte
	cachedPubKey       []byte
	cachedPriKey       []byte
}
// SinkServerConfig carries the settings used by NewSinkServer.
type SinkServerConfig struct {
	Logger hclog.Logger
	// Client is used by the server when response-wrapping tokens.
	Client *api.Client
	// Context is not copied by NewSinkServer; Run takes its context as
	// an explicit argument.
	Context context.Context
	// ExitAfterAuth makes Run return once the first token has been fully
	// delivered (or immediately on the first token when there are no sinks).
	ExitAfterAuth bool
}
// SinkServer is responsible for pushing tokens to sinks
// (see Run). Construct it with NewSinkServer.
type SinkServer struct {
	// DoneCh is closed when Run returns.
	DoneCh chan struct{}
	logger hclog.Logger
	// client is used to response-wrap tokens for sinks with a WrapTTL.
	client *api.Client
	// random supplies the jitter for retry backoffs in Run.
	random        *rand.Rand
	exitAfterAuth bool
	// remaining counts sink writes currently queued in Run; it is updated
	// with sync/atomic operations.
	remaining *int32
}
// NewSinkServer builds a SinkServer from conf. The returned server has an
// open DoneCh (Run closes it when it returns), its own pseudo-random source
// for retry jitter, and a pending-write counter starting at zero.
func NewSinkServer(conf *SinkServerConfig) *SinkServer {
	jitterSeed := int64(time.Now().Nanosecond())
	pendingWrites := new(int32)

	return &SinkServer{
		DoneCh:        make(chan struct{}),
		logger:        conf.Logger,
		client:        conf.Client,
		random:        rand.New(rand.NewSource(jitterSeed)),
		exitAfterAuth: conf.ExitAfterAuth,
		remaining:     pendingWrites,
	}
}
// Run executes the server's run loop, which is responsible for reading
// in new tokens and pushing them out to the various sinks.
//
// Behavior, as implemented below:
//   - Each distinct token received on incoming is queued once per sink. A
//     token equal to the most recently queued one is ignored. Queueing a new
//     token first discards any writes still queued for the previous one.
//   - Writes are performed one at a time on this goroutine. A queued write
//     whose token has since been superseded is skipped (treated as success).
//   - A failed write is re-queued after a jittered backoff in [1s, 3s).
//   - With exitAfterAuth set, Run returns after a successful (or skipped)
//     write leaves no writes queued. When there are no sinks, it returns
//     as soon as the first token arrives.
//   - Run returns when ctx is cancelled, and closes ss.DoneCh on return.
//
// Run panics if incoming is nil. That check happens before the deferred
// close is registered, so DoneCh is not closed in that case.
func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) {
	// latestToken is the most recently queued token; writeSink skips writes
	// for any other token.
	latestToken := new(string)

	// writeSink delivers currToken to currSink, first response-wrapping it
	// (WrapTTL != 0) and then encrypting it (DHType != ""). It returns nil
	// without writing if currToken is no longer the latest token.
	writeSink := func(currSink *SinkConfig, currToken string) error {
		if currToken != *latestToken {
			return nil
		}

		var err error

		if currSink.WrapTTL != 0 {
			if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil {
				return err
			}
		}

		if currSink.DHType != "" {
			if currToken, err = currSink.encryptToken(currToken); err != nil {
				return err
			}
		}

		return currSink.WriteToken(currToken)
	}

	if incoming == nil {
		panic("incoming channel is nil")
	}

	ss.logger.Info("starting sink server")
	defer func() {
		ss.logger.Info("sink server stopped")
		close(ss.DoneCh)
	}()

	type sinkToken struct {
		sink  *SinkConfig
		token string
	}
	// Buffered to len(sinks). After a drain exactly len(sinks) entries are
	// queued, and a retry re-queues an entry only after one was received.
	// So none of the sends below block this goroutine.
	sinkCh := make(chan sinkToken, len(sinks))
	for {
		select {
		case <-ctx.Done():
			return

		case token := <-incoming:
			if len(sinks) > 0 {
				// A token equal to latestToken is ignored. This includes ""
				// before any token has been queued.
				if token != *latestToken {

					// Drain the existing funcs, i.e. discard writes still
					// queued for the previous token.
				drainLoop:
					for {
						select {
						case <-sinkCh:
							atomic.AddInt32(ss.remaining, -1)
						default:
							break drainLoop
						}
					}

					*latestToken = token

					for _, s := range sinks {
						atomic.AddInt32(ss.remaining, 1)
						sinkCh <- sinkToken{s, token}
					}
				}
			} else {
				ss.logger.Trace("no sinks, ignoring new token")
				if ss.exitAfterAuth {
					ss.logger.Trace("no sinks, exitAfterAuth, bye")
					return
				}
			}

		case st := <-sinkCh:
			atomic.AddInt32(ss.remaining, -1)
			// select chooses randomly among ready cases, so cancellation may
			// have raced with this receive; check it before writing.
			select {
			case <-ctx.Done():
				return
			default:
			}

			if err := writeSink(st.sink, st.token); err != nil {
				// 2s plus a uniform offset in [-1s, 1s): a backoff in [1s, 3s).
				backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second))
				ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
				// Waiting here blocks the whole loop, so incoming is not read
				// during the backoff.
				select {
				case <-ctx.Done():
					return
				case <-time.After(backoff):
					atomic.AddInt32(ss.remaining, 1)
					sinkCh <- st
				}
			} else {
				// No writes remain queued: the latest token has been handled
				// for every sink.
				if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
					return
				}
			}
		}
	}
}
// encryptToken encrypts token for the party whose public key is stored at
// s.DHPath. It returns the JSON-encoded dhutil.Envelope, which carries the
// local public key, the nonce, and the AES ciphertext. s.AAD is used as
// additional authenticated data.
//
// On first use the remote public key is read from s.DHPath (JSON-decoded
// into a dhutil.PublicKeyInfo), and a local key pair is generated. Both are
// cached on s and reused, so later changes to the file are not picked up.
//
// Only the "curve25519" DH type is supported. Any other value is rejected
// up front instead of failing later with an opaque key-derivation error.
func (s *SinkConfig) encryptToken(token string) (string, error) {
	resp := new(dhutil.Envelope)

	switch s.DHType {
	case "curve25519":
		if len(s.cachedRemotePubKey) == 0 {
			if _, err := os.Lstat(s.DHPath); err != nil {
				if !os.IsNotExist(err) {
					return "", errwrap.Wrapf("error stat-ing dh parameters file: {{err}}", err)
				}
				return "", errors.New("no dh parameters file found, and no cached pub key")
			}
			fileBytes, err := ioutil.ReadFile(s.DHPath)
			if err != nil {
				return "", errwrap.Wrapf("error reading file for dh parameters: {{err}}", err)
			}
			theirPubKey := new(dhutil.PublicKeyInfo)
			if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil {
				return "", errwrap.Wrapf("error decoding public key: {{err}}", err)
			}
			if len(theirPubKey.Curve25519PublicKey) == 0 {
				return "", errors.New("public key is nil")
			}
			s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey
		}
		if len(s.cachedPubKey) == 0 {
			pubKey, priKey, err := dhutil.GeneratePublicPrivateKey()
			if err != nil {
				return "", errwrap.Wrapf("error generating pub/pri curve25519 keys: {{err}}", err)
			}
			// Only cache a complete key pair, so a failed generation is
			// retried on the next call instead of leaving a half-filled cache.
			s.cachedPubKey, s.cachedPriKey = pubKey, priKey
		}
		resp.Curve25519PublicKey = s.cachedPubKey
	default:
		return "", errors.New("unsupported dh type: " + s.DHType)
	}

	aesKey, err := dhutil.GenerateSharedKey(s.cachedPriKey, s.cachedRemotePubKey)
	if err != nil {
		return "", errwrap.Wrapf("error deriving shared key: {{err}}", err)
	}
	if len(aesKey) == 0 {
		return "", errors.New("derived AES key is empty")
	}

	resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD))
	if err != nil {
		return "", errwrap.Wrapf("error encrypting with shared key: {{err}}", err)
	}

	m, err := jsonutil.EncodeJSON(resp)
	if err != nil {
		return "", errwrap.Wrapf("error encoding encrypted payload: {{err}}", err)
	}

	return string(m), nil
}
// wrapToken response-wraps token by writing it to Vault's sys/wrapping/wrap
// endpoint with the given wrapTTL. The request authenticates as token itself.
// It returns the JSON-encoded wrap info, which is what gets written to the
// sink in place of the raw token.
//
// The receiver is not used; wrapping settings are passed explicitly.
func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) {
	// Work on a clone so SetToken / SetWrappingLookupFunc below do not
	// mutate the caller's client.
	wrapClient, err := client.Clone()
	if err != nil {
		return "", errwrap.Wrapf("error deriving client for wrapping, not writing out to sink: {{err}}", err)
	}
	wrapClient.SetToken(token)
	// Request a wrapped response with wrapTTL regardless of operation/path.
	wrapClient.SetWrappingLookupFunc(func(string, string) string {
		return wrapTTL.String()
	})

	secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{
		"token": token,
	})
	if err != nil {
		return "", errwrap.Wrapf("error wrapping token, not writing out to sink: {{err}}", err)
	}
	if secret == nil {
		return "", errors.New("nil secret returned, not writing out to sink")
	}
	if secret.WrapInfo == nil {
		return "", errors.New("nil wrap info returned, not writing out to sink")
	}

	m, err := jsonutil.EncodeJSON(secret.WrapInfo)
	if err != nil {
		return "", errwrap.Wrapf("error marshaling token, not writing out to sink: {{err}}", err)
	}

	return string(m), nil
}