mirror of
https://github.com/hashicorp/vault.git
synced 2025-11-23 19:51:09 +01:00
Supports situations where you want the Vault agent to handle one or more templates but do not require the acquired credentials elsewhere. Modifies the logic in SinkServer so that, if there are no sinks, any new credentials are ignored. Since SinkServer is responsible for shutting down the agent, this also makes sure it still shuts down properly in this new situation. Solves #7988
257 lines
6.4 KiB
Go
257 lines
6.4 KiB
Go
package sink
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"os"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/hashicorp/errwrap"
|
|
hclog "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/helper/dhutil"
|
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
|
)
|
|
|
|
// Sink is a destination that a token can be written to.
type Sink interface {
	// WriteToken hands token to the sink and reports any failure to
	// store it.
	WriteToken(token string) error
}
|
|
|
|
// SinkReader is implemented by types that can report a token as a string.
type SinkReader interface {
	// Token returns the token held by the implementation.
	Token() string
}
|
|
|
|
type SinkConfig struct {
|
|
Sink
|
|
Logger hclog.Logger
|
|
Config map[string]interface{}
|
|
Client *api.Client
|
|
WrapTTL time.Duration
|
|
DHType string
|
|
DHPath string
|
|
AAD string
|
|
cachedRemotePubKey []byte
|
|
cachedPubKey []byte
|
|
cachedPriKey []byte
|
|
}
|
|
|
|
// SinkServerConfig carries the settings NewSinkServer uses to build a
// SinkServer.
type SinkServerConfig struct {
	// Logger becomes the server's logger.
	Logger hclog.Logger
	// Client becomes the server's API client, used when wrapping tokens.
	Client *api.Client
	// Context is not read by NewSinkServer; Run takes its context as an
	// argument instead.
	Context context.Context
	// ExitAfterAuth makes Run return once there is nothing left to write
	// rather than waiting for further tokens.
	ExitAfterAuth bool
}
|
|
|
|
// SinkServer is responsible for pushing tokens to sinks
type SinkServer struct {
	// DoneCh is closed by Run when its loop returns.
	DoneCh chan struct{}
	// logger receives the server's start/stop, trace and error messages.
	logger hclog.Logger
	// client is handed to wrapToken when a sink has a non-zero WrapTTL.
	client *api.Client
	// random supplies the jitter applied to the retry backoff in Run.
	random *rand.Rand
	// exitAfterAuth makes Run return once no sink writes remain queued
	// (or immediately on a token when there are no sinks).
	exitAfterAuth bool
	// remaining counts the sink writes currently queued in Run; it is
	// updated and read with sync/atomic.
	remaining *int32
}
|
|
|
|
func NewSinkServer(conf *SinkServerConfig) *SinkServer {
|
|
ss := &SinkServer{
|
|
DoneCh: make(chan struct{}),
|
|
logger: conf.Logger,
|
|
client: conf.Client,
|
|
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
|
|
exitAfterAuth: conf.ExitAfterAuth,
|
|
remaining: new(int32),
|
|
}
|
|
|
|
return ss
|
|
}
|
|
|
|
// Run executes the server's run loop, which is responsible for reading
|
|
// in new tokens and pushing them out to the various sinks.
|
|
func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) {
|
|
latestToken := new(string)
|
|
writeSink := func(currSink *SinkConfig, currToken string) error {
|
|
if currToken != *latestToken {
|
|
return nil
|
|
}
|
|
var err error
|
|
|
|
if currSink.WrapTTL != 0 {
|
|
if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if currSink.DHType != "" {
|
|
if currToken, err = currSink.encryptToken(currToken); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return currSink.WriteToken(currToken)
|
|
}
|
|
|
|
if incoming == nil {
|
|
panic("incoming channel is nil")
|
|
}
|
|
|
|
ss.logger.Info("starting sink server")
|
|
defer func() {
|
|
ss.logger.Info("sink server stopped")
|
|
close(ss.DoneCh)
|
|
}()
|
|
|
|
type sinkToken struct {
|
|
sink *SinkConfig
|
|
token string
|
|
}
|
|
sinkCh := make(chan sinkToken, len(sinks))
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
|
|
case token := <-incoming:
|
|
if len(sinks) > 0 {
|
|
if token != *latestToken {
|
|
|
|
// Drain the existing funcs
|
|
drainLoop:
|
|
for {
|
|
select {
|
|
case <-sinkCh:
|
|
atomic.AddInt32(ss.remaining, -1)
|
|
default:
|
|
break drainLoop
|
|
}
|
|
}
|
|
|
|
*latestToken = token
|
|
|
|
for _, s := range sinks {
|
|
atomic.AddInt32(ss.remaining, 1)
|
|
sinkCh <- sinkToken{s, token}
|
|
}
|
|
}
|
|
} else {
|
|
ss.logger.Trace("no sinks, ignoring new token")
|
|
if ss.exitAfterAuth {
|
|
ss.logger.Trace("no sinks, exitAfterAuth, bye")
|
|
return
|
|
}
|
|
}
|
|
case st := <-sinkCh:
|
|
atomic.AddInt32(ss.remaining, -1)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
if err := writeSink(st.sink, st.token); err != nil {
|
|
backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second))
|
|
ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(backoff):
|
|
atomic.AddInt32(ss.remaining, 1)
|
|
sinkCh <- st
|
|
}
|
|
} else {
|
|
if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SinkConfig) encryptToken(token string) (string, error) {
|
|
var aesKey []byte
|
|
var err error
|
|
resp := new(dhutil.Envelope)
|
|
switch s.DHType {
|
|
case "curve25519":
|
|
if len(s.cachedRemotePubKey) == 0 {
|
|
_, err = os.Lstat(s.DHPath)
|
|
if err != nil {
|
|
if !os.IsNotExist(err) {
|
|
return "", errwrap.Wrapf("error stat-ing dh parameters file: {{err}}", err)
|
|
}
|
|
return "", errors.New("no dh parameters file found, and no cached pub key")
|
|
}
|
|
fileBytes, err := ioutil.ReadFile(s.DHPath)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error reading file for dh parameters: {{err}}", err)
|
|
}
|
|
theirPubKey := new(dhutil.PublicKeyInfo)
|
|
if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil {
|
|
return "", errwrap.Wrapf("error decoding public key: {{err}}", err)
|
|
}
|
|
if len(theirPubKey.Curve25519PublicKey) == 0 {
|
|
return "", errors.New("public key is nil")
|
|
}
|
|
s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey
|
|
}
|
|
if len(s.cachedPubKey) == 0 {
|
|
s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey()
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error generating pub/pri curve25519 keys: {{err}}", err)
|
|
}
|
|
}
|
|
resp.Curve25519PublicKey = s.cachedPubKey
|
|
}
|
|
|
|
aesKey, err = dhutil.GenerateSharedKey(s.cachedPriKey, s.cachedRemotePubKey)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error deriving shared key: {{err}}", err)
|
|
}
|
|
if len(aesKey) == 0 {
|
|
return "", errors.New("derived AES key is empty")
|
|
}
|
|
|
|
resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD))
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error encrypting with shared key: {{err}}", err)
|
|
}
|
|
m, err := jsonutil.EncodeJSON(resp)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error encoding encrypted payload: {{err}}", err)
|
|
}
|
|
|
|
return string(m), nil
|
|
}
|
|
|
|
func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) {
|
|
wrapClient, err := client.Clone()
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error deriving client for wrapping, not writing out to sink: {{err}})", err)
|
|
}
|
|
wrapClient.SetToken(token)
|
|
wrapClient.SetWrappingLookupFunc(func(string, string) string {
|
|
return wrapTTL.String()
|
|
})
|
|
secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{
|
|
"token": token,
|
|
})
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error wrapping token, not writing out to sink: {{err}})", err)
|
|
}
|
|
if secret == nil {
|
|
return "", errors.New("nil secret returned, not writing out to sink")
|
|
}
|
|
if secret.WrapInfo == nil {
|
|
return "", errors.New("nil wrap info returned, not writing out to sink")
|
|
}
|
|
|
|
m, err := jsonutil.EncodeJSON(secret.WrapInfo)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error marshaling token, not writing out to sink: {{err}})", err)
|
|
}
|
|
|
|
return string(m), nil
|
|
}
|